unix: add ParseOrigDstAddr

Add a function which turns a SocketControlMessage into a Sockaddr.
This can be used with IP(V6)_RECVORIGDSTADDR to retrieve the original
destination address of a packet.

Change-Id: Ib2d80cd01be6642e8b918cbc1584d4a49c3c6f1e
Reviewed-on: https://go-review.googlesource.com/c/sys/+/355609
Reviewed-by: Tobias Klauser <tobias.klauser@gmail.com>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Tobias Klauser <tobias.klauser@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
diff --git a/unix/creds_test.go b/unix/creds_test.go
index 1d4c091..9ab57ec 100644
--- a/unix/creds_test.go
+++ b/unix/creds_test.go
@@ -195,3 +195,102 @@
 		})
 	}
 }
+
+func TestParseOrigDstAddr(t *testing.T) {
+	testcases := []struct {
+		network string
+		address *net.UDPAddr
+	}{
+		{"udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}},
+		{"udp6", &net.UDPAddr{IP: net.IPv6loopback}},
+	}
+
+	for _, test := range testcases {
+		t.Run(test.network, func(t *testing.T) {
+			conn, err := net.ListenUDP(test.network, test.address)
+			if errors.Is(err, unix.EADDRNOTAVAIL) {
+				t.Skipf("%v is not available", test.address)
+			}
+			if err != nil {
+				t.Fatal("Listen:", err)
+			}
+			defer conn.Close()
+
+			raw, err := conn.SyscallConn()
+			if err != nil {
+				t.Fatal("SyscallConn:", err)
+			}
+
+			var opErr error
+			err = raw.Control(func(fd uintptr) {
+				switch test.network {
+				case "udp4":
+					opErr = unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1)
+				case "udp6":
+					opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
+				}
+			})
+			if err != nil {
+				t.Fatal("Control:", err)
+			}
+			if opErr != nil {
+				t.Fatal("Can't enable RECVORIGDSTADDR:", err)
+			}
+
+			msg := []byte{1}
+			addr := conn.LocalAddr().(*net.UDPAddr)
+			_, err = conn.WriteToUDP(msg, addr)
+			if err != nil {
+				t.Fatal("WriteToUDP:", err)
+			}
+
+			conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+			oob := make([]byte, unix.CmsgSpace(unix.SizeofSockaddrInet6))
+			_, oobn, _, _, err := conn.ReadMsgUDP(msg, oob)
+			if err != nil {
+				t.Fatal("ReadMsgUDP:", err)
+			}
+
+			scms, err := unix.ParseSocketControlMessage(oob[:oobn])
+			if err != nil {
+				t.Fatal("ParseSocketControlMessage:", err)
+			}
+
+			sa, err := unix.ParseOrigDstAddr(&scms[0])
+			if err != nil {
+				t.Fatal("ParseOrigDstAddr:", err)
+			}
+
+			switch test.network {
+			case "udp4":
+				sa4, ok := sa.(*unix.SockaddrInet4)
+				if !ok {
+					t.Fatalf("Got %T not *SockaddrInet4", sa)
+				}
+
+				lo := net.IPv4(127, 0, 0, 1)
+				if addr := net.IP(sa4.Addr[:]); !lo.Equal(addr) {
+					t.Errorf("Got address %v, want %v", addr, lo)
+				}
+
+				if sa4.Port != addr.Port {
+					t.Errorf("Got port %d, want %d", sa4.Port, addr.Port)
+				}
+
+			case "udp6":
+				sa6, ok := sa.(*unix.SockaddrInet6)
+				if !ok {
+					t.Fatalf("Got %T, want *SockaddrInet6", sa)
+				}
+
+				if addr := net.IP(sa6.Addr[:]); !net.IPv6loopback.Equal(addr) {
+					t.Errorf("Got address %v, want %v", addr, net.IPv6loopback)
+				}
+
+				if sa6.Port != addr.Port {
+					t.Errorf("Got port %d, want %d", sa6.Port, addr.Port)
+				}
+			}
+		})
+	}
+}
diff --git a/unix/sockcmsg_linux.go b/unix/sockcmsg_linux.go
index e86d543..326fb04 100644
--- a/unix/sockcmsg_linux.go
+++ b/unix/sockcmsg_linux.go
@@ -56,3 +56,34 @@
 	*(*Inet6Pktinfo)(h.data(0)) = *info
 	return b
 }
+
+// ParseOrigDstAddr decodes a socket control message containing the original
+// destination address. To receive such a message the IP_RECVORIGDSTADDR or
+// IPV6_RECVORIGDSTADDR option must be enabled on the socket.
+func ParseOrigDstAddr(m *SocketControlMessage) (Sockaddr, error) {
+	switch {
+	case m.Header.Level == SOL_IP && m.Header.Type == IP_ORIGDSTADDR:
+		pp := (*RawSockaddrInet4)(unsafe.Pointer(&m.Data[0]))
+		sa := new(SockaddrInet4)
+		p := (*[2]byte)(unsafe.Pointer(&pp.Port))
+		sa.Port = int(p[0])<<8 + int(p[1])
+		for i := 0; i < len(sa.Addr); i++ {
+			sa.Addr[i] = pp.Addr[i]
+		}
+		return sa, nil
+
+	case m.Header.Level == SOL_IPV6 && m.Header.Type == IPV6_ORIGDSTADDR:
+		pp := (*RawSockaddrInet6)(unsafe.Pointer(&m.Data[0]))
+		sa := new(SockaddrInet6)
+		p := (*[2]byte)(unsafe.Pointer(&pp.Port))
+		sa.Port = int(p[0])<<8 + int(p[1])
+		sa.ZoneId = pp.Scope_id
+		for i := 0; i < len(sa.Addr); i++ {
+			sa.Addr[i] = pp.Addr[i]
+		}
+		return sa, nil
+
+	default:
+		return nil, EINVAL
+	}
+}