unix: add ParseOneSocketControlMessage to parse control messages without allocating
Fixes golang/go#54714.
Change-Id: If711272937078b6c696756823aa4dfcec358b719
Reviewed-on: https://go-review.googlesource.com/c/sys/+/425917
Reviewed-by: Matt Layher <mdlayher@gmail.com>
Reviewed-by: Michael Pratt <mpratt@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Run-TryBot: Matt Layher <mdlayher@gmail.com>
diff --git a/unix/sockcmsg_unix.go b/unix/sockcmsg_unix.go
index 453a942..3865943 100644
--- a/unix/sockcmsg_unix.go
+++ b/unix/sockcmsg_unix.go
@@ -52,6 +52,20 @@
return msgs, nil
}
+// ParseOneSocketControlMessage parses a single socket control message from b, returning the message header,
+// message data (a slice of b), and the remainder of b after that single message.
+// When there are no remaining messages, len(remainder) == 0.
+func ParseOneSocketControlMessage(b []byte) (hdr Cmsghdr, data []byte, remainder []byte, err error) {
+ h, dbuf, err := socketControlMessageHeaderAndData(b)
+ if err != nil {
+ return Cmsghdr{}, nil, nil, err
+ }
+ if i := cmsgAlignOf(int(h.Len)); i < len(b) {
+ remainder = b[i:]
+ }
+ return *h, dbuf, remainder, nil
+}
+
func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) {
h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) {
diff --git a/unix/syscall_unix_test.go b/unix/syscall_unix_test.go
index baff92e..0517689 100644
--- a/unix/syscall_unix_test.go
+++ b/unix/syscall_unix_test.go
@@ -322,7 +322,7 @@
}
}
-// TestUnixRightsRoundtrip tests that UnixRights, ParseSocketControlMessage,
+// TestUnixRightsRoundtrip tests that UnixRights, ParseSocketControlMessage, ParseOneSocketControlMessage,
// and ParseUnixRights are able to successfully round-trip lists of file descriptors.
func TestUnixRightsRoundtrip(t *testing.T) {
testCases := [...][][]int{
@@ -350,6 +350,23 @@
if len(scms) != len(testCase) {
t.Fatalf("expected %v SocketControlMessage; got scms = %#v", len(testCase), scms)
}
+
+ var c int
+ for len(b) > 0 {
+ hdr, data, remainder, err := unix.ParseOneSocketControlMessage(b)
+ if err != nil {
+ t.Fatalf("ParseOneSocketControlMessage: %v", err)
+ }
+ if scms[c].Header != hdr || !bytes.Equal(scms[c].Data, data) {
+ t.Fatal("expected SocketControlMessage header and data to match")
+ }
+ b = remainder
+ c++
+ }
+ if c != len(scms) {
+ t.Fatalf("expected %d SocketControlMessages; got %d", len(scms), c)
+ }
+
for i, scm := range scms {
gotFds, err := unix.ParseUnixRights(&scm)
if err != nil {