unix: add RecvmsgBuffers and SendmsgBuffers

Fixes golang/go#52885

Change-Id: I04b5be1ac9543a3791ebc4cd59b9e35e958e0ba2
Reviewed-on: https://go-review.googlesource.com/c/sys/+/412497
Auto-Submit: Tobias Klauser <tobias.klauser@gmail.com>
Reviewed-by: Tobias Klauser <tobias.klauser@gmail.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Auto-Submit: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Ian Lance Taylor <iant@google.com>
diff --git a/unix/syscall_aix.go b/unix/syscall_aix.go
index ad22c33..ac579c6 100644
--- a/unix/syscall_aix.go
+++ b/unix/syscall_aix.go
@@ -217,12 +217,12 @@
 	return
 }
 
-func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
+func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
 	// Recvmsg not implemented on AIX
 	return -1, -1, -1, ENOSYS
 }
 
-func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
+func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
 	// SendmsgN not implemented on AIX
 	return -1, ENOSYS
 }
diff --git a/unix/syscall_bsd.go b/unix/syscall_bsd.go
index 9c87c5f..c437fc5 100644
--- a/unix/syscall_bsd.go
+++ b/unix/syscall_bsd.go
@@ -325,27 +325,26 @@
 //sys	sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen _Socklen) (err error)
 //sys	recvmsg(s int, msg *Msghdr, flags int) (n int, err error)
 
-func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
+func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
 	var msg Msghdr
 	msg.Name = (*byte)(unsafe.Pointer(rsa))
 	msg.Namelen = uint32(SizeofSockaddrAny)
-	var iov Iovec
-	if len(p) > 0 {
-		iov.Base = (*byte)(unsafe.Pointer(&p[0]))
-		iov.SetLen(len(p))
-	}
 	var dummy byte
 	if len(oob) > 0 {
 		// receive at least one normal byte
-		if len(p) == 0 {
-			iov.Base = &dummy
-			iov.SetLen(1)
+		if emptyIovecs(iov) {
+			var iova [1]Iovec
+			iova[0].Base = &dummy
+			iova[0].SetLen(1)
+			iov = iova[:]
 		}
 		msg.Control = (*byte)(unsafe.Pointer(&oob[0]))
 		msg.SetControllen(len(oob))
 	}
-	msg.Iov = &iov
-	msg.Iovlen = 1
+	if len(iov) > 0 {
+		msg.Iov = &iov[0]
+		msg.SetIovlen(len(iov))
+	}
 	if n, err = recvmsg(fd, &msg, flags); err != nil {
 		return
 	}
@@ -356,31 +355,32 @@
 
 //sys	sendmsg(s int, msg *Msghdr, flags int) (n int, err error)
 
-func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
+func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
 	var msg Msghdr
 	msg.Name = (*byte)(unsafe.Pointer(ptr))
 	msg.Namelen = uint32(salen)
-	var iov Iovec
-	if len(p) > 0 {
-		iov.Base = (*byte)(unsafe.Pointer(&p[0]))
-		iov.SetLen(len(p))
-	}
 	var dummy byte
+	var empty bool
 	if len(oob) > 0 {
 		// send at least one normal byte
-		if len(p) == 0 {
-			iov.Base = &dummy
-			iov.SetLen(1)
+		empty := emptyIovecs(iov)
+		if empty {
+			var iova [1]Iovec
+			iova[0].Base = &dummy
+			iova[0].SetLen(1)
+			iov = iova[:]
 		}
 		msg.Control = (*byte)(unsafe.Pointer(&oob[0]))
 		msg.SetControllen(len(oob))
 	}
-	msg.Iov = &iov
-	msg.Iovlen = 1
+	if len(iov) > 0 {
+		msg.Iov = &iov[0]
+		msg.SetIovlen(len(iov))
+	}
 	if n, err = sendmsg(fd, &msg, flags); err != nil {
 		return 0, err
 	}
-	if len(oob) > 0 && len(p) == 0 {
+	if len(oob) > 0 && empty {
 		n = 0
 	}
 	return n, nil
diff --git a/unix/syscall_linux.go b/unix/syscall_linux.go
index c8d2032..5e4a94f 100644
--- a/unix/syscall_linux.go
+++ b/unix/syscall_linux.go
@@ -1499,18 +1499,13 @@
 //sys	keyctlRestrictKeyringByType(cmd int, arg2 int, keyType string, restriction string) (err error) = SYS_KEYCTL
 //sys	keyctlRestrictKeyring(cmd int, arg2 int) (err error) = SYS_KEYCTL
 
-func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
+func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
 	var msg Msghdr
 	msg.Name = (*byte)(unsafe.Pointer(rsa))
 	msg.Namelen = uint32(SizeofSockaddrAny)
-	var iov Iovec
-	if len(p) > 0 {
-		iov.Base = &p[0]
-		iov.SetLen(len(p))
-	}
 	var dummy byte
 	if len(oob) > 0 {
-		if len(p) == 0 {
+		if emptyIovecs(iov) {
 			var sockType int
 			sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE)
 			if err != nil {
@@ -1518,15 +1513,19 @@
 			}
 			// receive at least one normal byte
 			if sockType != SOCK_DGRAM {
-				iov.Base = &dummy
-				iov.SetLen(1)
+				var iova [1]Iovec
+				iova[0].Base = &dummy
+				iova[0].SetLen(1)
+				iov = iova[:]
 			}
 		}
 		msg.Control = &oob[0]
 		msg.SetControllen(len(oob))
 	}
-	msg.Iov = &iov
-	msg.Iovlen = 1
+	if len(iov) > 0 {
+		msg.Iov = &iov[0]
+		msg.SetIovlen(len(iov))
+	}
 	if n, err = recvmsg(fd, &msg, flags); err != nil {
 		return
 	}
@@ -1535,18 +1534,15 @@
 	return
 }
 
-func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
+func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
 	var msg Msghdr
 	msg.Name = (*byte)(ptr)
 	msg.Namelen = uint32(salen)
-	var iov Iovec
-	if len(p) > 0 {
-		iov.Base = &p[0]
-		iov.SetLen(len(p))
-	}
 	var dummy byte
+	var empty bool
 	if len(oob) > 0 {
-		if len(p) == 0 {
+		empty := emptyIovecs(iov)
+		if empty {
 			var sockType int
 			sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE)
 			if err != nil {
@@ -1554,19 +1550,22 @@
 			}
 			// send at least one normal byte
 			if sockType != SOCK_DGRAM {
-				iov.Base = &dummy
-				iov.SetLen(1)
+				var iova [1]Iovec
+				iova[0].Base = &dummy
+				iova[0].SetLen(1)
 			}
 		}
 		msg.Control = &oob[0]
 		msg.SetControllen(len(oob))
 	}
-	msg.Iov = &iov
-	msg.Iovlen = 1
+	if len(iov) > 0 {
+		msg.Iov = &iov[0]
+		msg.SetIovlen(len(iov))
+	}
 	if n, err = sendmsg(fd, &msg, flags); err != nil {
 		return 0, err
 	}
-	if len(oob) > 0 && len(p) == 0 {
+	if len(oob) > 0 && empty {
 		n = 0
 	}
 	return n, nil
diff --git a/unix/syscall_solaris.go b/unix/syscall_solaris.go
index cd492d7..b5ec457 100644
--- a/unix/syscall_solaris.go
+++ b/unix/syscall_solaris.go
@@ -451,26 +451,25 @@
 
 //sys	recvmsg(s int, msg *Msghdr, flags int) (n int, err error) = libsocket.__xnet_recvmsg
 
-func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
+func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
 	var msg Msghdr
 	msg.Name = (*byte)(unsafe.Pointer(rsa))
 	msg.Namelen = uint32(SizeofSockaddrAny)
-	var iov Iovec
-	if len(p) > 0 {
-		iov.Base = &p[0]
-		iov.SetLen(len(p))
-	}
 	var dummy byte
 	if len(oob) > 0 {
 		// receive at least one normal byte
-		if len(p) == 0 {
-			iov.Base = &dummy
-			iov.SetLen(1)
+		if emptyIovecs(iov) {
+			var iova [1]Iovec
+			iova[0].Base = &dummy
+			iova[0].SetLen(1)
+			iov = iova[:]
 		}
 		msg.Accrightslen = int32(len(oob))
 	}
-	msg.Iov = &iov
-	msg.Iovlen = 1
+	if len(iov) > 0 {
+		msg.Iov = &iov[0]
+		msg.SetIovlen(len(iov))
+	}
 	if n, err = recvmsg(fd, &msg, flags); n == -1 {
 		return
 	}
@@ -480,30 +479,31 @@
 
 //sys	sendmsg(s int, msg *Msghdr, flags int) (n int, err error) = libsocket.__xnet_sendmsg
 
-func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
+func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
 	var msg Msghdr
 	msg.Name = (*byte)(unsafe.Pointer(ptr))
 	msg.Namelen = uint32(salen)
-	var iov Iovec
-	if len(p) > 0 {
-		iov.Base = &p[0]
-		iov.SetLen(len(p))
-	}
 	var dummy byte
+	var empty bool
 	if len(oob) > 0 {
 		// send at least one normal byte
-		if len(p) == 0 {
-			iov.Base = &dummy
-			iov.SetLen(1)
+		empty = emptyIovecs(iov)
+		if empty {
+			var iova [1]Iovec
+			iova[0].Base = &dummy
+			iova[0].SetLen(1)
+			iov = iova[:]
 		}
 		msg.Accrightslen = int32(len(oob))
 	}
-	msg.Iov = &iov
-	msg.Iovlen = 1
+	if len(iov) > 0 {
+		msg.Iov = &iov[0]
+		msg.SetIovlen(len(iov))
+	}
 	if n, err = sendmsg(fd, &msg, flags); err != nil {
 		return 0, err
 	}
-	if len(oob) > 0 && len(p) == 0 {
+	if len(oob) > 0 && empty {
 		n = 0
 	}
 	return n, nil
diff --git a/unix/syscall_unix.go b/unix/syscall_unix.go
index 70508af..1ff5060 100644
--- a/unix/syscall_unix.go
+++ b/unix/syscall_unix.go
@@ -338,8 +338,13 @@
 }
 
 func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) {
+	var iov [1]Iovec
+	if len(p) > 0 {
+		iov[0].Base = &p[0]
+		iov[0].SetLen(len(p))
+	}
 	var rsa RawSockaddrAny
-	n, oobn, recvflags, err = recvmsgRaw(fd, p, oob, flags, &rsa)
+	n, oobn, recvflags, err = recvmsgRaw(fd, iov[:], oob, flags, &rsa)
 	// source address is only specified if the socket is unconnected
 	if rsa.Addr.Family != AF_UNSPEC {
 		from, err = anyToSockaddr(fd, &rsa)
@@ -347,12 +352,42 @@
 	return
 }
 
+// RecvmsgBuffers receives a message from a socket using the recvmsg
+// system call. The flags are passed to recvmsg. Any non-control data
+// read is scattered into the buffers slices. The results are:
+//   - n is the number of non-control data read into bufs
+//   - oobn is the number of control data read into oob; this may be interpreted using [ParseSocketControlMessage]
+//   - recvflags is flags returned by recvmsg
+//   - from is the address of the sender
+func RecvmsgBuffers(fd int, buffers [][]byte, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) {
+	iov := make([]Iovec, len(buffers))
+	for i := range buffers {
+		if len(buffers[i]) > 0 {
+			iov[i].Base = &buffers[i][0]
+			iov[i].SetLen(len(buffers[i]))
+		} else {
+			iov[i].Base = (*byte)(unsafe.Pointer(&_zero))
+		}
+	}
+	var rsa RawSockaddrAny
+	n, oobn, recvflags, err = recvmsgRaw(fd, iov, oob, flags, &rsa)
+	if err == nil && rsa.Addr.Family != AF_UNSPEC {
+		from, err = anyToSockaddr(fd, &rsa)
+	}
+	return
+}
+
 func Sendmsg(fd int, p, oob []byte, to Sockaddr, flags int) (err error) {
 	_, err = SendmsgN(fd, p, oob, to, flags)
 	return
 }
 
 func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error) {
+	var iov [1]Iovec
+	if len(p) > 0 {
+		iov[0].Base = &p[0]
+		iov[0].SetLen(len(p))
+	}
 	var ptr unsafe.Pointer
 	var salen _Socklen
 	if to != nil {
@@ -361,7 +396,32 @@
 			return 0, err
 		}
 	}
-	return sendmsgN(fd, p, oob, ptr, salen, flags)
+	return sendmsgN(fd, iov[:], oob, ptr, salen, flags)
+}
+
+// SendmsgBuffers sends a message on a socket to an address using the sendmsg
+// system call. The flags are passed to sendmsg. Any non-control data written
+// is gathered from buffers. The function returns the number of bytes written
+// to the socket.
+func SendmsgBuffers(fd int, buffers [][]byte, oob []byte, to Sockaddr, flags int) (n int, err error) {
+	iov := make([]Iovec, len(buffers))
+	for i := range buffers {
+		if len(buffers[i]) > 0 {
+			iov[i].Base = &buffers[i][0]
+			iov[i].SetLen(len(buffers[i]))
+		} else {
+			iov[i].Base = (*byte)(unsafe.Pointer(&_zero))
+		}
+	}
+	var ptr unsafe.Pointer
+	var salen _Socklen
+	if to != nil {
+		ptr, salen, err = to.sockaddr()
+		if err != nil {
+			return 0, err
+		}
+	}
+	return sendmsgN(fd, iov, oob, ptr, salen, flags)
 }
 
 func Send(s int, buf []byte, flags int) (err error) {
@@ -484,3 +544,13 @@
 	}
 	return UtimesNanoAt(AT_FDCWD, path, ts, AT_SYMLINK_NOFOLLOW)
 }
+
+// emptyIovec reports whether there are no bytes in the slice of Iovec.
+func emptyIovecs(iov []Iovec) bool {
+	for i := range iov {
+		if iov[i].Len > 0 {
+			return false
+		}
+	}
+	return true
+}
diff --git a/unix/syscall_unix_test.go b/unix/syscall_unix_test.go
index 81db934..c20afbe 100644
--- a/unix/syscall_unix_test.go
+++ b/unix/syscall_unix_test.go
@@ -16,8 +16,10 @@
 	"os"
 	"os/exec"
 	"path/filepath"
+	"reflect"
 	"runtime"
 	"strconv"
+	"sync"
 	"syscall"
 	"testing"
 	"time"
@@ -954,6 +956,68 @@
 	}
 }
 
+func TestSendmsgBuffers(t *testing.T) {
+	if runtime.GOOS == "aix" {
+		t.Skipf("SendmsgBuffers not supported on %s", runtime.GOOS)
+	}
+
+	fds, err := unix.Socketpair(unix.AF_LOCAL, unix.SOCK_STREAM, 0)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer unix.Close(fds[0])
+	defer unix.Close(fds[1])
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		bufs := [][]byte{
+			make([]byte, 5),
+			nil,
+			make([]byte, 5),
+		}
+		n, oobn, recvflags, _, err := unix.RecvmsgBuffers(fds[1], bufs, nil, 0)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if n != 10 {
+			t.Errorf("got %d bytes, want 10", n)
+		}
+		if oobn != 0 {
+			t.Errorf("got %d OOB bytes, want 0", oobn)
+		}
+		if recvflags != 0 {
+			t.Errorf("got flags %#x, want %#x", recvflags, 0)
+		}
+		want := [][]byte{
+			[]byte("01234"),
+			nil,
+			[]byte("56789"),
+		}
+		if !reflect.DeepEqual(bufs, want) {
+			t.Errorf("got data %q, want %q", bufs, want)
+		}
+	}()
+
+	defer wg.Wait()
+
+	bufs := [][]byte{
+		[]byte("012"),
+		[]byte("34"),
+		nil,
+		[]byte("5678"),
+		[]byte("9"),
+	}
+	n, err := unix.SendmsgBuffers(fds[0], bufs, nil, nil, 0)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != 10 {
+		t.Errorf("sent %d bytes, want 10", n)
+	}
+}
+
 // mktmpfifo creates a temporary FIFO and provides a cleanup function.
 func mktmpfifo(t *testing.T) (*os.File, func()) {
 	err := unix.Mkfifo("fifo", 0666)