unix: add RecvmsgBuffers and SendmsgBuffers
Fixes golang/go#52885
Change-Id: I04b5be1ac9543a3791ebc4cd59b9e35e958e0ba2
Reviewed-on: https://go-review.googlesource.com/c/sys/+/412497
Auto-Submit: Tobias Klauser <tobias.klauser@gmail.com>
Reviewed-by: Tobias Klauser <tobias.klauser@gmail.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Auto-Submit: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Ian Lance Taylor <iant@google.com>
diff --git a/unix/syscall_aix.go b/unix/syscall_aix.go
index ad22c33..ac579c6 100644
--- a/unix/syscall_aix.go
+++ b/unix/syscall_aix.go
@@ -217,12 +217,12 @@
return
}
-func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
+func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
// Recvmsg not implemented on AIX
return -1, -1, -1, ENOSYS
}
-func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
+func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
// SendmsgN not implemented on AIX
return -1, ENOSYS
}
diff --git a/unix/syscall_bsd.go b/unix/syscall_bsd.go
index 9c87c5f..c437fc5 100644
--- a/unix/syscall_bsd.go
+++ b/unix/syscall_bsd.go
@@ -325,27 +325,26 @@
//sys sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen _Socklen) (err error)
//sys recvmsg(s int, msg *Msghdr, flags int) (n int, err error)
-func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
+func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
var msg Msghdr
msg.Name = (*byte)(unsafe.Pointer(rsa))
msg.Namelen = uint32(SizeofSockaddrAny)
- var iov Iovec
- if len(p) > 0 {
- iov.Base = (*byte)(unsafe.Pointer(&p[0]))
- iov.SetLen(len(p))
- }
var dummy byte
if len(oob) > 0 {
// receive at least one normal byte
- if len(p) == 0 {
- iov.Base = &dummy
- iov.SetLen(1)
+ if emptyIovecs(iov) {
+ var iova [1]Iovec
+ iova[0].Base = &dummy
+ iova[0].SetLen(1)
+ iov = iova[:]
}
msg.Control = (*byte)(unsafe.Pointer(&oob[0]))
msg.SetControllen(len(oob))
}
- msg.Iov = &iov
- msg.Iovlen = 1
+ if len(iov) > 0 {
+ msg.Iov = &iov[0]
+ msg.SetIovlen(len(iov))
+ }
if n, err = recvmsg(fd, &msg, flags); err != nil {
return
}
@@ -356,31 +355,32 @@
//sys sendmsg(s int, msg *Msghdr, flags int) (n int, err error)
-func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
+func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
var msg Msghdr
msg.Name = (*byte)(unsafe.Pointer(ptr))
msg.Namelen = uint32(salen)
- var iov Iovec
- if len(p) > 0 {
- iov.Base = (*byte)(unsafe.Pointer(&p[0]))
- iov.SetLen(len(p))
- }
var dummy byte
+ var empty bool
if len(oob) > 0 {
// send at least one normal byte
- if len(p) == 0 {
- iov.Base = &dummy
- iov.SetLen(1)
+ empty := emptyIovecs(iov)
+ if empty {
+ var iova [1]Iovec
+ iova[0].Base = &dummy
+ iova[0].SetLen(1)
+ iov = iova[:]
}
msg.Control = (*byte)(unsafe.Pointer(&oob[0]))
msg.SetControllen(len(oob))
}
- msg.Iov = &iov
- msg.Iovlen = 1
+ if len(iov) > 0 {
+ msg.Iov = &iov[0]
+ msg.SetIovlen(len(iov))
+ }
if n, err = sendmsg(fd, &msg, flags); err != nil {
return 0, err
}
- if len(oob) > 0 && len(p) == 0 {
+ if len(oob) > 0 && empty {
n = 0
}
return n, nil
diff --git a/unix/syscall_linux.go b/unix/syscall_linux.go
index c8d2032..5e4a94f 100644
--- a/unix/syscall_linux.go
+++ b/unix/syscall_linux.go
@@ -1499,18 +1499,13 @@
//sys keyctlRestrictKeyringByType(cmd int, arg2 int, keyType string, restriction string) (err error) = SYS_KEYCTL
//sys keyctlRestrictKeyring(cmd int, arg2 int) (err error) = SYS_KEYCTL
-func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
+func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
var msg Msghdr
msg.Name = (*byte)(unsafe.Pointer(rsa))
msg.Namelen = uint32(SizeofSockaddrAny)
- var iov Iovec
- if len(p) > 0 {
- iov.Base = &p[0]
- iov.SetLen(len(p))
- }
var dummy byte
if len(oob) > 0 {
- if len(p) == 0 {
+ if emptyIovecs(iov) {
var sockType int
sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE)
if err != nil {
@@ -1518,15 +1513,19 @@
}
// receive at least one normal byte
if sockType != SOCK_DGRAM {
- iov.Base = &dummy
- iov.SetLen(1)
+ var iova [1]Iovec
+ iova[0].Base = &dummy
+ iova[0].SetLen(1)
+ iov = iova[:]
}
}
msg.Control = &oob[0]
msg.SetControllen(len(oob))
}
- msg.Iov = &iov
- msg.Iovlen = 1
+ if len(iov) > 0 {
+ msg.Iov = &iov[0]
+ msg.SetIovlen(len(iov))
+ }
if n, err = recvmsg(fd, &msg, flags); err != nil {
return
}
@@ -1535,18 +1534,15 @@
return
}
-func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
+func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
var msg Msghdr
msg.Name = (*byte)(ptr)
msg.Namelen = uint32(salen)
- var iov Iovec
- if len(p) > 0 {
- iov.Base = &p[0]
- iov.SetLen(len(p))
- }
var dummy byte
+ var empty bool
if len(oob) > 0 {
- if len(p) == 0 {
+ empty := emptyIovecs(iov)
+ if empty {
var sockType int
sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE)
if err != nil {
@@ -1554,19 +1550,22 @@
}
// send at least one normal byte
if sockType != SOCK_DGRAM {
- iov.Base = &dummy
- iov.SetLen(1)
+ var iova [1]Iovec
+ iova[0].Base = &dummy
+ iova[0].SetLen(1)
}
}
msg.Control = &oob[0]
msg.SetControllen(len(oob))
}
- msg.Iov = &iov
- msg.Iovlen = 1
+ if len(iov) > 0 {
+ msg.Iov = &iov[0]
+ msg.SetIovlen(len(iov))
+ }
if n, err = sendmsg(fd, &msg, flags); err != nil {
return 0, err
}
- if len(oob) > 0 && len(p) == 0 {
+ if len(oob) > 0 && empty {
n = 0
}
return n, nil
diff --git a/unix/syscall_solaris.go b/unix/syscall_solaris.go
index cd492d7..b5ec457 100644
--- a/unix/syscall_solaris.go
+++ b/unix/syscall_solaris.go
@@ -451,26 +451,25 @@
//sys recvmsg(s int, msg *Msghdr, flags int) (n int, err error) = libsocket.__xnet_recvmsg
-func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
+func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
var msg Msghdr
msg.Name = (*byte)(unsafe.Pointer(rsa))
msg.Namelen = uint32(SizeofSockaddrAny)
- var iov Iovec
- if len(p) > 0 {
- iov.Base = &p[0]
- iov.SetLen(len(p))
- }
var dummy byte
if len(oob) > 0 {
// receive at least one normal byte
- if len(p) == 0 {
- iov.Base = &dummy
- iov.SetLen(1)
+ if emptyIovecs(iov) {
+ var iova [1]Iovec
+ iova[0].Base = &dummy
+ iova[0].SetLen(1)
+ iov = iova[:]
}
msg.Accrightslen = int32(len(oob))
}
- msg.Iov = &iov
- msg.Iovlen = 1
+ if len(iov) > 0 {
+ msg.Iov = &iov[0]
+ msg.SetIovlen(len(iov))
+ }
if n, err = recvmsg(fd, &msg, flags); n == -1 {
return
}
@@ -480,30 +479,31 @@
//sys sendmsg(s int, msg *Msghdr, flags int) (n int, err error) = libsocket.__xnet_sendmsg
-func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
+func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
var msg Msghdr
msg.Name = (*byte)(unsafe.Pointer(ptr))
msg.Namelen = uint32(salen)
- var iov Iovec
- if len(p) > 0 {
- iov.Base = &p[0]
- iov.SetLen(len(p))
- }
var dummy byte
+ var empty bool
if len(oob) > 0 {
// send at least one normal byte
- if len(p) == 0 {
- iov.Base = &dummy
- iov.SetLen(1)
+ empty = emptyIovecs(iov)
+ if empty {
+ var iova [1]Iovec
+ iova[0].Base = &dummy
+ iova[0].SetLen(1)
+ iov = iova[:]
}
msg.Accrightslen = int32(len(oob))
}
- msg.Iov = &iov
- msg.Iovlen = 1
+ if len(iov) > 0 {
+ msg.Iov = &iov[0]
+ msg.SetIovlen(len(iov))
+ }
if n, err = sendmsg(fd, &msg, flags); err != nil {
return 0, err
}
- if len(oob) > 0 && len(p) == 0 {
+ if len(oob) > 0 && empty {
n = 0
}
return n, nil
diff --git a/unix/syscall_unix.go b/unix/syscall_unix.go
index 70508af..1ff5060 100644
--- a/unix/syscall_unix.go
+++ b/unix/syscall_unix.go
@@ -338,8 +338,13 @@
}
func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) {
+ var iov [1]Iovec
+ if len(p) > 0 {
+ iov[0].Base = &p[0]
+ iov[0].SetLen(len(p))
+ }
var rsa RawSockaddrAny
- n, oobn, recvflags, err = recvmsgRaw(fd, p, oob, flags, &rsa)
+ n, oobn, recvflags, err = recvmsgRaw(fd, iov[:], oob, flags, &rsa)
// source address is only specified if the socket is unconnected
if rsa.Addr.Family != AF_UNSPEC {
from, err = anyToSockaddr(fd, &rsa)
@@ -347,12 +352,42 @@
return
}
+// RecvmsgBuffers receives a message from a socket using the recvmsg
+// system call. The flags are passed to recvmsg. Any non-control data
+// read is scattered into the buffers slices. The results are:
+// - n is the number of non-control data read into bufs
+// - oobn is the number of control data read into oob; this may be interpreted using [ParseSocketControlMessage]
+// - recvflags is flags returned by recvmsg
+// - from is the address of the sender
+func RecvmsgBuffers(fd int, buffers [][]byte, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) {
+ iov := make([]Iovec, len(buffers))
+ for i := range buffers {
+ if len(buffers[i]) > 0 {
+ iov[i].Base = &buffers[i][0]
+ iov[i].SetLen(len(buffers[i]))
+ } else {
+ iov[i].Base = (*byte)(unsafe.Pointer(&_zero))
+ }
+ }
+ var rsa RawSockaddrAny
+ n, oobn, recvflags, err = recvmsgRaw(fd, iov, oob, flags, &rsa)
+ if err == nil && rsa.Addr.Family != AF_UNSPEC {
+ from, err = anyToSockaddr(fd, &rsa)
+ }
+ return
+}
+
func Sendmsg(fd int, p, oob []byte, to Sockaddr, flags int) (err error) {
_, err = SendmsgN(fd, p, oob, to, flags)
return
}
func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error) {
+ var iov [1]Iovec
+ if len(p) > 0 {
+ iov[0].Base = &p[0]
+ iov[0].SetLen(len(p))
+ }
var ptr unsafe.Pointer
var salen _Socklen
if to != nil {
@@ -361,7 +396,32 @@
return 0, err
}
}
- return sendmsgN(fd, p, oob, ptr, salen, flags)
+ return sendmsgN(fd, iov[:], oob, ptr, salen, flags)
+}
+
+// SendmsgBuffers sends a message on a socket to an address using the sendmsg
+// system call. The flags are passed to sendmsg. Any non-control data written
+// is gathered from buffers. The function returns the number of bytes written
+// to the socket.
+func SendmsgBuffers(fd int, buffers [][]byte, oob []byte, to Sockaddr, flags int) (n int, err error) {
+ iov := make([]Iovec, len(buffers))
+ for i := range buffers {
+ if len(buffers[i]) > 0 {
+ iov[i].Base = &buffers[i][0]
+ iov[i].SetLen(len(buffers[i]))
+ } else {
+ iov[i].Base = (*byte)(unsafe.Pointer(&_zero))
+ }
+ }
+ var ptr unsafe.Pointer
+ var salen _Socklen
+ if to != nil {
+ ptr, salen, err = to.sockaddr()
+ if err != nil {
+ return 0, err
+ }
+ }
+ return sendmsgN(fd, iov, oob, ptr, salen, flags)
}
func Send(s int, buf []byte, flags int) (err error) {
@@ -484,3 +544,13 @@
}
return UtimesNanoAt(AT_FDCWD, path, ts, AT_SYMLINK_NOFOLLOW)
}
+
+// emptyIovec reports whether there are no bytes in the slice of Iovec.
+func emptyIovecs(iov []Iovec) bool {
+ for i := range iov {
+ if iov[i].Len > 0 {
+ return false
+ }
+ }
+ return true
+}
diff --git a/unix/syscall_unix_test.go b/unix/syscall_unix_test.go
index 81db934..c20afbe 100644
--- a/unix/syscall_unix_test.go
+++ b/unix/syscall_unix_test.go
@@ -16,8 +16,10 @@
"os"
"os/exec"
"path/filepath"
+ "reflect"
"runtime"
"strconv"
+ "sync"
"syscall"
"testing"
"time"
@@ -954,6 +956,68 @@
}
}
+func TestSendmsgBuffers(t *testing.T) {
+ if runtime.GOOS == "aix" {
+ t.Skipf("SendmsgBuffers not supported on %s", runtime.GOOS)
+ }
+
+ fds, err := unix.Socketpair(unix.AF_LOCAL, unix.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer unix.Close(fds[0])
+ defer unix.Close(fds[1])
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ bufs := [][]byte{
+ make([]byte, 5),
+ nil,
+ make([]byte, 5),
+ }
+ n, oobn, recvflags, _, err := unix.RecvmsgBuffers(fds[1], bufs, nil, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 10 {
+ t.Errorf("got %d bytes, want 10", n)
+ }
+ if oobn != 0 {
+ t.Errorf("got %d OOB bytes, want 0", oobn)
+ }
+ if recvflags != 0 {
+ t.Errorf("got flags %#x, want %#x", recvflags, 0)
+ }
+ want := [][]byte{
+ []byte("01234"),
+ nil,
+ []byte("56789"),
+ }
+ if !reflect.DeepEqual(bufs, want) {
+ t.Errorf("got data %q, want %q", bufs, want)
+ }
+ }()
+
+ defer wg.Wait()
+
+ bufs := [][]byte{
+ []byte("012"),
+ []byte("34"),
+ nil,
+ []byte("5678"),
+ []byte("9"),
+ }
+ n, err := unix.SendmsgBuffers(fds[0], bufs, nil, nil, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 10 {
+ t.Errorf("sent %d bytes, want 10", n)
+ }
+}
+
// mktmpfifo creates a temporary FIFO and provides a cleanup function.
func mktmpfifo(t *testing.T) (*os.File, func()) {
err := unix.Mkfifo("fifo", 0666)