blob: 0b600a2b46f9da40c8b38c073d8b9b82cc836927 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21 && !quicbasicnet && (darwin || linux)
package quic
import (
"encoding/binary"
"net"
"net/netip"
"sync"
"unsafe"
"golang.org/x/sys/unix"
)
// Network interface for platforms using sendmsg/recvmsg with cmsgs.
//
// netUDPConn wraps a *net.UDPConn and exchanges per-datagram metadata
// (ECN bits and local address) through socket control messages.
type netUDPConn struct {
// c is the underlying UDP socket that all reads, writes, and Close go through.
c *net.UDPConn
// localAddr is the address the socket was bound to when the conn was created.
// Its Addr is the invalid zero netip.Addr when the socket is bound to an
// unspecified (wildcard) address; the port is always preserved.
localAddr netip.AddrPort
}
// newNetUDPConn wraps uc in a netUDPConn and enables the socket options
// needed to receive ECN bits and, for wildcard-bound sockets, the
// destination address of each inbound datagram.
//
// It returns an error if the raw connection for uc cannot be obtained,
// or if the socket's file descriptor cannot be accessed to configure it.
// Failure of an individual socket option is not an error.
func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) {
	// AddrPort is safe to call on a nil *net.UDPAddr and returns the zero value.
	a, _ := uc.LocalAddr().(*net.UDPAddr)
	localAddr := a.AddrPort()
	if localAddr.Addr().IsUnspecified() {
		// If the conn is not bound to a specified (non-wildcard) address,
		// then set localAddr.Addr to an invalid netip.Addr.
		// This better conveys that this is not an address we should be using,
		// and is a bit more efficient to test against.
		localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port())
	}

	sc, err := uc.SyscallConn()
	if err != nil {
		return nil, err
	}
	err = sc.Control(func(fd uintptr) {
		// Ask for ECN info and (when we aren't bound to a fixed local address)
		// destination info.
		//
		// If any of these calls fail, we won't get the requested information.
		// That's fine, we'll gracefully handle the lack.
		unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
		unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
		if !localAddr.IsValid() {
			unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
			unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
		}
	})
	if err != nil {
		// Unlike the best-effort setsockopt calls above, a Control error means
		// the function could not be run against the descriptor at all, so
		// report it instead of handing back a conn on an unusable socket.
		return nil, err
	}

	return &netUDPConn{
		c:         uc,
		localAddr: localAddr,
	}, nil
}
// Close closes the underlying UDP socket and returns its error, if any.
func (c *netUDPConn) Close() error {
	err := c.c.Close()
	return err
}
// LocalAddr reports the local address of the underlying socket.
//
// This is the address as reported by the socket itself, not the stored
// localAddr field. The zero netip.AddrPort is returned when the socket
// does not report a *net.UDPAddr.
func (c *netUDPConn) LocalAddr() netip.AddrPort {
	udpAddr, ok := c.c.LocalAddr().(*net.UDPAddr)
	if !ok {
		return netip.AddrPort{}
	}
	return udpAddr.AddrPort()
}
// Read receives datagrams from the socket in a loop, handing each
// non-empty one to f. It returns when a read from the socket fails.
//
// Every datagram passed to f carries the peer address, the local address
// (refined by any received packet-info control message), and any ECN bits
// found in the control data.
func (c *netUDPConn) Read(f func(*datagram)) {
	// Sizes of the control message payloads we may be handed.
	const (
		inPktinfoSize  = 12 // int + in_addr + in_addr
		in6PktinfoSize = 20 // in6_addr + int
		ipTOSSize      = 4
		ipv6TclassSize = 4
	)
	// The kernel should never deliver every one of these at once, but the sum
	// is tiny, so reserve room for all of them rather than reasoning about
	// which combinations can occur.
	controlSize := unix.CmsgSpace(inPktinfoSize) +
		unix.CmsgSpace(in6PktinfoSize) +
		unix.CmsgSpace(ipTOSSize) +
		unix.CmsgSpace(ipv6TclassSize)
	control := make([]byte, controlSize)

	for {
		d := newDatagram()
		n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control)
		switch {
		case err != nil:
			return
		case n == 0:
			// Nothing to deliver; wait for the next datagram.
			continue
		}
		d.b = d.b[:n]
		d.peerAddr = unmapAddrPort(peerAddr)
		// Start from the bound address; parseControl overrides the IP when
		// the kernel told us the actual destination.
		d.localAddr = c.localAddr
		parseControl(d, control[:controlLen])
		f(d)
	}
}
// cmsgPool recycles the scratch buffers used to build outbound control
// messages. Entries are pointers to byte slices so that a buffer grown by
// one user can be stored back without allocating.
var cmsgPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
// Write sends dgram.b to dgram.peerAddr.
//
// When dgram.localAddr holds a valid IP, a packet-info control message asks
// the kernel to use it as the source address. When dgram.ecn is set, a
// control message carrying the ECN bits is attached, chosen by the address
// family of the peer. The error from the underlying send is returned.
func (c *netUDPConn) Write(dgram datagram) error {
	bufp := cmsgPool.Get().(*[]byte)
	cmsgs := *bufp
	// Hand the buffer (including any growth from the appends below) back to
	// the pool, emptied, once the send has completed.
	defer func() {
		*bufp = cmsgs[:0]
		cmsgPool.Put(bufp)
	}()

	switch src := dgram.localAddr.Addr(); {
	case !src.IsValid():
		// No particular source address requested.
	case src.Is4():
		cmsgs = appendCmsgIPSourceAddrV4(cmsgs, src)
	default:
		cmsgs = appendCmsgIPSourceAddrV6(cmsgs, src)
	}

	switch {
	case dgram.ecn == ecnNotECT:
		// Nothing to mark.
	case dgram.peerAddr.Addr().Is4():
		cmsgs = appendCmsgECNv4(cmsgs, dgram.ecn)
	default:
		cmsgs = appendCmsgECNv6(cmsgs, dgram.ecn)
	}

	_, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, cmsgs, dgram.peerAddr)
	return err
}
// parseControl walks the socket control messages in control and records
// what it finds in d: ECN bits go to d.ecn, and a destination address from
// a packet-info message replaces the IP of d.localAddr (the port is kept).
//
// Parsing stops silently at the first malformed message. Messages of other
// levels or types, or with unexpected payloads, are skipped.
func parseControl(d *datagram, control []byte) {
	for rest := control; len(rest) > 0; {
		hdr, payload, next, err := unix.ParseOneSocketControlMessage(rest)
		if err != nil {
			return
		}
		rest = next

		switch level, typ := hdr.Level, hdr.Type; {
		case level == unix.IPPROTO_IP && (typ == unix.IP_TOS || typ == unix.IP_RECVTOS):
			// Linux reports the type as IP_TOS and Darwin as IP_RECVTOS,
			// so accept either.
			if ecn, ok := parseIPTOS(payload); ok {
				d.ecn = ecn
			}
		case level == unix.IPPROTO_IP && typ == unix.IP_PKTINFO:
			if dst, ok := parseInPktinfo(payload); ok {
				d.localAddr = netip.AddrPortFrom(dst, d.localAddr.Port())
			}
		case level == unix.IPPROTO_IPV6 && typ == unix.IPV6_TCLASS:
			// A 32-bit integer holding the traffic class field;
			// its low two bits are the ECN field.
			if ecn, ok := parseIPv6TCLASS(payload); ok {
				d.ecn = ecn
			}
		case level == unix.IPPROTO_IPV6 && typ == unix.IPV6_PKTINFO:
			if dst, ok := parseIn6Pktinfo(payload); ok {
				d.localAddr = netip.AddrPortFrom(dst, d.localAddr.Port())
			}
		}
	}
}
// parseIPv6TCLASS extracts the ECN bits from the payload of an IPV6_TCLASS
// control message. RFC 3542 specifies the payload as an int, so anything
// other than exactly four bytes is rejected with ok == false.
func parseIPv6TCLASS(b []byte) (ecnBits, bool) {
	const tclassLen = 4
	if len(b) == tclassLen {
		tclass := binary.NativeEndian.Uint32(b)
		return ecnBits(tclass & ecnMask), true
	}
	return 0, false
}
// appendCmsgECNv6 appends an IPV6_TCLASS control message carrying ecn to b
// and returns the extended buffer. The value is stored as a native-endian
// 32-bit integer.
func appendCmsgECNv6(b []byte, ecn ecnBits) []byte {
	const tclassLen = 4
	out, payload := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_TCLASS, tclassLen)
	binary.NativeEndian.PutUint32(payload, uint32(ecn))
	return out
}
// parseInPktinfo returns the destination address from an IP_PKTINFO
// control message payload, which has this layout:
//
//	struct in_pktinfo {
//		unsigned int   ipi_ifindex;  /* send/recv interface index */
//		struct in_addr ipi_spec_dst; /* Local address */
//		struct in_addr ipi_addr;     /* IP Header dst address */
//	};
//
// ok is false, and dst is the zero Addr, unless b is exactly 12 bytes long.
func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) {
	const (
		pktinfoLen = 12
		addrOffset = 8 // ipi_addr follows ipi_ifindex and ipi_spec_dst
	)
	if len(b) != pktinfoLen {
		return netip.Addr{}, false
	}
	var raw [4]byte
	copy(raw[:], b[addrOffset:addrOffset+4])
	return netip.AddrFrom4(raw), true
}
// appendCmsgIPSourceAddrV4 appends to b an IP_PKTINFO control message that
// sets the source address of an outbound datagram to src, and returns the
// extended buffer.
//
// The payload is a struct in_pktinfo:
//
//	struct in_pktinfo {
//		unsigned int   ipi_ifindex;  /* send/recv interface index */
//		struct in_addr ipi_spec_dst; /* Local address */
//		struct in_addr ipi_addr;     /* IP Header dst address */
//	};
//
// Only ipi_spec_dst is filled in; the other fields are left zero.
func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte {
	const (
		pktinfoLen    = 12
		specDstOffset = 4 // ipi_spec_dst follows ipi_ifindex
	)
	out, payload := appendCmsg(b, unix.IPPROTO_IP, unix.IP_PKTINFO, pktinfoLen)
	addr := src.As4()
	copy(payload[specDstOffset:], addr[:])
	return out
}
// parseIn6Pktinfo returns the destination address from an IPV6_PKTINFO
// control message payload, which has this layout:
//
//	struct in6_pktinfo {
//		struct in6_addr ipi6_addr;    /* src/dst IPv6 address */
//		unsigned int    ipi6_ifindex; /* send/recv interface index */
//	};
//
// An IPv4-mapped IPv6 address is returned in its plain IPv4 form.
// The bool result is false, and the address is the zero Addr, unless b is
// exactly 20 bytes long.
func parseIn6Pktinfo(b []byte) (netip.Addr, bool) {
	const (
		pktinfoLen = 20
		addrLen    = 16
	)
	if len(b) != pktinfoLen {
		return netip.Addr{}, false
	}
	var raw [addrLen]byte
	copy(raw[:], b[:addrLen])
	return netip.AddrFrom16(raw).Unmap(), true
}
// appendCmsgIPSourceAddrV6 appends to b an IPV6_PKTINFO control message that
// sets the source address of an outbound datagram to src, and returns the
// extended buffer.
//
// The 20-byte payload is a struct in6_pktinfo: the 16-byte address comes
// first, and the trailing interface index is left zero.
func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte {
	const pktinfoLen = 20
	out, payload := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_PKTINFO, pktinfoLen)
	addr := src.As16()
	copy(payload, addr[:])
	return out
}
// appendCmsg grows b by one zero-filled control message with the given
// level and type whose data section is size bytes long.
//
// It returns the extended buffer together with a slice of exactly size
// bytes aliasing the message's data section, for the caller to fill in.
func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) {
	start := len(b)
	space := unix.CmsgSpace(size)
	b = append(b, make([]byte, space)...)

	// Write the header in place at the start of the newly appended region.
	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&b[start]))
	hdr.Level, hdr.Type = level, typ
	hdr.SetLen(unix.CmsgLen(size))

	// The data section begins after the (aligned) header.
	dataStart := start + unix.CmsgSpace(0)
	return b, b[dataStart : dataStart+size]
}