unix: add functions to encode Inet4Pktinfo and Inet6Pktinfo

It's possible to control the source address of a UDP packet by
passing a socket control message of type IP_PKTINFO or IPV6_PKTINFO.
This is a somewhat esoteric feature of the network stack, but it's
extremely useful feature when you really need it.

Change-Id: I8300575f975679f6689d6f1282af253ba62e8f9d
Reviewed-on: https://go-review.googlesource.com/c/sys/+/355610
Run-TryBot: Tobias Klauser <tobias.klauser@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Trust: Tobias Klauser <tobias.klauser@gmail.com>
diff --git a/unix/creds_test.go b/unix/creds_test.go
index 483bfb1..1d4c091 100644
--- a/unix/creds_test.go
+++ b/unix/creds_test.go
@@ -9,9 +9,11 @@
 
 import (
 	"bytes"
+	"errors"
 	"net"
 	"os"
 	"testing"
+	"time"
 
 	"golang.org/x/sys/unix"
 )
@@ -118,3 +120,78 @@
 		}
 	}
 }
+
+func TestPktInfo(t *testing.T) {
+	testcases := []struct {
+		network string
+		address *net.UDPAddr
+	}{
+		{"udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}},
+		{"udp6", &net.UDPAddr{IP: net.ParseIP("::1")}},
+	}
+	for _, test := range testcases {
+		t.Run(test.network, func(t *testing.T) {
+			conn, err := net.ListenUDP(test.network, test.address)
+			if errors.Is(err, unix.EADDRNOTAVAIL) {
+				t.Skipf("%v is not available", test.address)
+			}
+			if err != nil {
+				t.Fatal("Listen:", err)
+			}
+			defer conn.Close()
+
+			var pktInfo []byte
+			var src net.IP
+			switch test.network {
+			case "udp4":
+				var info4 unix.Inet4Pktinfo
+				src = net.ParseIP("127.0.0.2").To4()
+				copy(info4.Spec_dst[:], src)
+				pktInfo = unix.PktInfo4(&info4)
+
+			case "udp6":
+				var info6 unix.Inet6Pktinfo
+				src = net.ParseIP("2001:0DB8::1")
+				copy(info6.Addr[:], src)
+				pktInfo = unix.PktInfo6(&info6)
+
+				raw, err := conn.SyscallConn()
+				if err != nil {
+					t.Fatal("SyscallConn:", err)
+				}
+				var opErr error
+				err = raw.Control(func(fd uintptr) {
+					opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_FREEBIND, 1)
+				})
+				if err != nil {
+					t.Fatal("Control:", err)
+				}
+				if errors.Is(opErr, unix.ENOPROTOOPT) {
+					// Happens on android-amd64-emu, maybe Android has disabled
+					// IPV6_FREEBIND?
+					t.Skip("IPV6_FREEBIND not supported")
+				}
+				if opErr != nil {
+					t.Fatal("Can't enable IPV6_FREEBIND:", opErr)
+				}
+			}
+
+			msg := []byte{1}
+			addr := conn.LocalAddr().(*net.UDPAddr)
+			_, _, err = conn.WriteMsgUDP(msg, pktInfo, addr)
+			if err != nil {
+				t.Fatal("WriteMsgUDP:", err)
+			}
+
+			conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+			_, _, _, remote, err := conn.ReadMsgUDP(msg, nil)
+			if err != nil {
+				t.Fatal("ReadMsgUDP:", err)
+			}
+
+			if !remote.IP.Equal(src) {
+				t.Errorf("Got packet from %v, want %v", remote.IP, src)
+			}
+		})
+	}
+}
diff --git a/unix/sockcmsg_linux.go b/unix/sockcmsg_linux.go
index 8bf4570..e86d543 100644
--- a/unix/sockcmsg_linux.go
+++ b/unix/sockcmsg_linux.go
@@ -34,3 +34,25 @@
 	ucred := *(*Ucred)(unsafe.Pointer(&m.Data[0]))
 	return &ucred, nil
 }
+
+// PktInfo4 encodes Inet4Pktinfo into a socket control message of type IP_PKTINFO.
+func PktInfo4(info *Inet4Pktinfo) []byte {
+	b := make([]byte, CmsgSpace(SizeofInet4Pktinfo))
+	h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
+	h.Level = SOL_IP
+	h.Type = IP_PKTINFO
+	h.SetLen(CmsgLen(SizeofInet4Pktinfo))
+	*(*Inet4Pktinfo)(h.data(0)) = *info
+	return b
+}
+
+// PktInfo6 encodes Inet6Pktinfo into a socket control message of type IPV6_PKTINFO.
+func PktInfo6(info *Inet6Pktinfo) []byte {
+	b := make([]byte, CmsgSpace(SizeofInet6Pktinfo))
+	h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
+	h.Level = SOL_IPV6
+	h.Type = IPV6_PKTINFO
+	h.SetLen(CmsgLen(SizeofInet6Pktinfo))
+	*(*Inet6Pktinfo)(h.data(0)) = *info
+	return b
+}