quic: version negotiation

Servers respond to packets containing an unrecognized version
with a Version Negotiation packet.

Clients respond to Version Negotiation packets by aborting
the connection attempt, since we support only one version.

RFC 9000, Section 6

For golang/go#58547

Change-Id: I3f3a66a4d69950cc7dc22146ad2eddb93cbe34f7
Reviewed-on: https://go-review.googlesource.com/c/net/+/529739
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index 6347dda..19c4385 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -7,6 +7,9 @@
 package quic
 
 import (
+	"bytes"
+	"encoding/binary"
+	"errors"
 	"time"
 )
 
@@ -31,6 +34,9 @@
 			n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf)
 		case packetType1RTT:
 			n = c.handle1RTT(now, buf)
+		case packetTypeVersionNegotiation:
+			c.handleVersionNegotiation(now, buf)
+			return
 		default:
 			return
 		}
@@ -59,6 +65,11 @@
 		c.abort(now, localTransportError(errProtocolViolation))
 		return -1
 	}
+	if p.version != quicVersion1 {
+		// The peer has changed versions on us mid-handshake?
+		c.abort(now, localTransportError(errProtocolViolation))
+		return -1
+	}
 
 	if !c.acks[space].shouldProcess(p.num) {
 		return n
@@ -117,6 +128,42 @@
 	return len(buf)
 }
 
+var errVersionNegotiation = errors.New("server does not support QUIC version 1")
+
+func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) {
+	if c.side != clientSide {
+		return // servers don't handle Version Negotiation packets
+	}
+	// "A client MUST discard any Version Negotiation packet if it has
+	// received and successfully processed any other packet [...]"
+	// https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2
+	if !c.keysInitial.canRead() {
+		return // discarded Initial keys, connection is already established
+	}
+	if c.acks[initialSpace].seen.numRanges() != 0 {
+		return // processed at least one packet
+	}
+	_, srcConnID, versions := parseVersionNegotiation(pkt)
+	if len(c.connIDState.remote) < 1 || !bytes.Equal(c.connIDState.remote[0].cid, srcConnID) {
+		return // Source Connection ID doesn't match what we sent
+	}
+	for len(versions) >= 4 {
+		ver := binary.BigEndian.Uint32(versions)
+		if ver == 1 {
+			// "A client MUST discard a Version Negotiation packet that lists
+			// the QUIC version selected by the client."
+			// https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2
+			return
+		}
+		versions = versions[4:]
+	}
+	// "A client that supports only this version of QUIC MUST
+	// abandon the current connection attempt if it receives
+	// a Version Negotiation packet, [with the two exceptions handled above]."
+	// https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2
+	c.abortImmediately(now, errVersionNegotiation)
+}
+
 func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) {
 	if len(payload) == 0 {
 		// "An endpoint MUST treat receipt of a packet containing no frames
diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go
index 63f65b5..00b02c2 100644
--- a/internal/quic/conn_send.go
+++ b/internal/quic/conn_send.go
@@ -64,7 +64,7 @@
 			pnum := c.loss.nextNumber(initialSpace)
 			p := longPacket{
 				ptype:     packetTypeInitial,
-				version:   1,
+				version:   quicVersion1,
 				num:       pnum,
 				dstConnID: dstConnID,
 				srcConnID: c.connIDState.srcConnID(),
@@ -91,7 +91,7 @@
 			pnum := c.loss.nextNumber(handshakeSpace)
 			p := longPacket{
 				ptype:     packetTypeHandshake,
-				version:   1,
+				version:   quicVersion1,
 				num:       pnum,
 				dstConnID: dstConnID,
 				srcConnID: c.connIDState.srcConnID(),
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index d75b2eb..fd9e6e4 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -409,7 +409,7 @@
 			keyNumber:   tc.sendKeyNumber,
 			keyPhaseBit: tc.sendKeyPhaseBit,
 			frames:      frames,
-			version:     1,
+			version:     quicVersion1,
 			dstConnID:   dstConnID,
 			srcConnID:   tc.peerConnID,
 		}},
diff --git a/internal/quic/listener.go b/internal/quic/listener.go
index 03d8ec6..96b1e45 100644
--- a/internal/quic/listener.go
+++ b/internal/quic/listener.go
@@ -239,32 +239,15 @@
 func (l *Listener) handleDatagram(m *datagram, conns map[string]*Conn) {
 	dstConnID, ok := dstConnIDForDatagram(m.b)
 	if !ok {
+		m.recycle()
 		return
 	}
 	c := conns[string(dstConnID)]
 	if c == nil {
-		if getPacketType(m.b) != packetTypeInitial {
-			// This packet isn't trying to create a new connection.
-			// It might be associated with some connection we've lost state for.
-			// TODO: Send a stateless reset when appropriate.
-			// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.3
-			return
-		}
-		var now time.Time
-		if l.testHooks != nil {
-			now = l.testHooks.timeNow()
-		} else {
-			now = time.Now()
-		}
-		var err error
-		c, err = l.newConn(now, serverSide, dstConnID, m.addr)
-		if err != nil {
-			// The accept queue is probably full.
-			// We could send a CONNECTION_CLOSE to the peer to reject the connection.
-			// Currently, we just drop the datagram.
-			// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5
-			return
-		}
+		// TODO: Move this branch into a separate goroutine to avoid blocking
+		// the listener while processing packets.
+		l.handleUnknownDestinationDatagram(m)
+		return
 	}
 
 	// TODO: This can block the listener while waiting for the conn to accept the dgram.
@@ -272,6 +255,67 @@
 	c.sendMsg(m)
 }
 
+func (l *Listener) handleUnknownDestinationDatagram(m *datagram) {
+	defer func() {
+		if m != nil {
+			m.recycle()
+		}
+	}()
+	if len(m.b) < minimumClientInitialDatagramSize {
+		return
+	}
+	p, ok := parseGenericLongHeaderPacket(m.b)
+	if !ok {
+		// Not a long header packet, or not parseable.
+		// Short header (1-RTT) packets don't contain enough information
+		// to do anything useful with if we don't recognize the
+		// connection ID.
+		return
+	}
+
+	switch p.version {
+	case quicVersion1:
+	case 0:
+		// Version Negotiation for an unknown connection.
+		return
+	default:
+		// Unknown version.
+		l.sendVersionNegotiation(p, m.addr)
+		return
+	}
+	if getPacketType(m.b) != packetTypeInitial {
+		// This packet isn't trying to create a new connection.
+		// It might be associated with some connection we've lost state for.
+		// TODO: Send a stateless reset when appropriate.
+		// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.3
+		return
+	}
+	var now time.Time
+	if l.testHooks != nil {
+		now = l.testHooks.timeNow()
+	} else {
+		now = time.Now()
+	}
+	var err error
+	c, err := l.newConn(now, serverSide, p.dstConnID, m.addr)
+	if err != nil {
+		// The accept queue is probably full.
+		// We could send a CONNECTION_CLOSE to the peer to reject the connection.
+		// Currently, we just drop the datagram.
+		// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5
+		return
+	}
+	c.sendMsg(m)
+	m = nil // don't recycle, sendMsg takes ownership
+}
+
+func (l *Listener) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) {
+	m := newDatagram()
+	m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
+	l.sendDatagram(m.b, addr)
+	m.recycle()
+}
+
 func (l *Listener) sendDatagram(p []byte, addr netip.AddrPort) error {
 	_, err := l.udpConn.WriteToUDPAddrPort(p, addr)
 	return err
diff --git a/internal/quic/packet.go b/internal/quic/packet.go
index 8242bd0..7d69f96 100644
--- a/internal/quic/packet.go
+++ b/internal/quic/packet.go
@@ -6,7 +6,10 @@
 
 package quic
 
-import "fmt"
+import (
+	"encoding/binary"
+	"fmt"
+)
 
 // packetType is a QUIC packet type.
 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17
@@ -157,6 +160,33 @@
 	return b[:n], true
 }
 
+// parseVersionNegotiation parses a Version Negotiation packet.
+// The returned versions is a slice of big-endian uint32s.
+// It returns (nil, nil, nil) for an invalid packet.
+func parseVersionNegotiation(pkt []byte) (dstConnID, srcConnID, versions []byte) {
+	p, ok := parseGenericLongHeaderPacket(pkt)
+	if !ok {
+		return nil, nil, nil
+	}
+	if len(p.data)%4 != 0 {
+		return nil, nil, nil
+	}
+	return p.dstConnID, p.srcConnID, p.data
+}
+
+// appendVersionNegotiation appends a Version Negotiation packet to pkt,
+// returning the result.
+func appendVersionNegotiation(pkt, dstConnID, srcConnID []byte, versions ...uint32) []byte {
+	pkt = append(pkt, headerFormLong|fixedBit) // header byte
+	pkt = append(pkt, 0, 0, 0, 0)              // Version (0 for Version Negotiation)
+	pkt = appendUint8Bytes(pkt, dstConnID)     // Destination Connection ID
+	pkt = appendUint8Bytes(pkt, srcConnID)     // Source Connection ID
+	for _, v := range versions {
+		pkt = binary.BigEndian.AppendUint32(pkt, v) // Supported Version
+	}
+	return pkt
+}
+
 // A longPacket is a long header packet.
 type longPacket struct {
 	ptype     packetType
@@ -177,3 +207,42 @@
 	num     packetNumber
 	payload []byte
 }
+
+// A genericLongPacket is a long header packet of an arbitrary QUIC version.
+// https://www.rfc-editor.org/rfc/rfc8999#section-5.1
+type genericLongPacket struct {
+	version   uint32
+	dstConnID []byte
+	srcConnID []byte
+	data      []byte
+}
+
+func parseGenericLongHeaderPacket(b []byte) (p genericLongPacket, ok bool) {
+	if len(b) < 5 || !isLongHeader(b[0]) {
+		return genericLongPacket{}, false
+	}
+	b = b[1:]
+	// Version (32),
+	var n int
+	p.version, n = consumeUint32(b)
+	if n < 0 {
+		return genericLongPacket{}, false
+	}
+	b = b[n:]
+	// Destination Connection ID Length (8),
+	// Destination Connection ID (0..2048),
+	p.dstConnID, n = consumeUint8Bytes(b)
+	if n < 0 || len(p.dstConnID) > 2048/8 {
+		return genericLongPacket{}, false
+	}
+	b = b[n:]
+	// Source Connection ID Length (8),
+	// Source Connection ID (0..2048),
+	p.srcConnID, n = consumeUint8Bytes(b)
+	if n < 0 || len(p.dstConnID) > 2048/8 {
+		return genericLongPacket{}, false
+	}
+	b = b[n:]
+	p.data = b
+	return p, true
+}
diff --git a/internal/quic/packet_test.go b/internal/quic/packet_test.go
index b13a587..58c584e 100644
--- a/internal/quic/packet_test.go
+++ b/internal/quic/packet_test.go
@@ -8,7 +8,9 @@
 
 import (
 	"bytes"
+	"encoding/binary"
 	"encoding/hex"
+	"reflect"
 	"strings"
 	"testing"
 )
@@ -112,6 +114,124 @@
 	}
 }
 
+func TestEncodeDecodeVersionNegotiation(t *testing.T) {
+	dstConnID := []byte("this is a very long destination connection id")
+	srcConnID := []byte("this is a very long source connection id")
+	versions := []uint32{1, 0xffffffff}
+	got := appendVersionNegotiation([]byte{}, dstConnID, srcConnID, versions...)
+	want := bytes.Join([][]byte{{
+		0b1100_0000, // header byte
+		0, 0, 0, 0,  // Version
+		byte(len(dstConnID)),
+	}, dstConnID, {
+		byte(len(srcConnID)),
+	}, srcConnID, {
+		0x00, 0x00, 0x00, 0x01,
+		0xff, 0xff, 0xff, 0xff,
+	}}, nil)
+	if !bytes.Equal(got, want) {
+		t.Fatalf("appendVersionNegotiation(nil, %x, %x, %v):\ngot  %x\nwant %x",
+			dstConnID, srcConnID, versions, got, want)
+	}
+	gotDst, gotSrc, gotVersionBytes := parseVersionNegotiation(got)
+	if got, want := gotDst, dstConnID; !bytes.Equal(got, want) {
+		t.Errorf("parseVersionNegotiation: got dstConnID = %x, want %x", got, want)
+	}
+	if got, want := gotSrc, srcConnID; !bytes.Equal(got, want) {
+		t.Errorf("parseVersionNegotiation: got srcConnID = %x, want %x", got, want)
+	}
+	var gotVersions []uint32
+	for len(gotVersionBytes) >= 4 {
+		gotVersions = append(gotVersions, binary.BigEndian.Uint32(gotVersionBytes))
+		gotVersionBytes = gotVersionBytes[4:]
+	}
+	if got, want := gotVersions, versions; !reflect.DeepEqual(got, want) {
+		t.Errorf("parseVersionNegotiation: got versions = %v, want %v", got, want)
+	}
+}
+
+func TestParseGenericLongHeaderPacket(t *testing.T) {
+	for _, test := range []struct {
+		name      string
+		packet    []byte
+		version   uint32
+		dstConnID []byte
+		srcConnID []byte
+		data      []byte
+	}{{
+		name: "long header packet",
+		packet: unhex(`
+			80 01020304 04a1a2a3a4 05b1b2b3b4b5 c1
+		`),
+		version:   0x01020304,
+		dstConnID: unhex(`a1a2a3a4`),
+		srcConnID: unhex(`b1b2b3b4b5`),
+		data:      unhex(`c1`),
+	}, {
+		name: "zero everything",
+		packet: unhex(`
+			80 00000000 00 00
+		`),
+		version:   0,
+		dstConnID: []byte{},
+		srcConnID: []byte{},
+		data:      []byte{},
+	}} {
+		t.Run(test.name, func(t *testing.T) {
+			p, ok := parseGenericLongHeaderPacket(test.packet)
+			if !ok {
+				t.Fatalf("parseGenericLongHeaderPacket() = _, false; want true")
+			}
+			if got, want := p.version, test.version; got != want {
+				t.Errorf("version = %v, want %v", got, want)
+			}
+			if got, want := p.dstConnID, test.dstConnID; !bytes.Equal(got, want) {
+				t.Errorf("Destination Connection ID = {%x}, want {%x}", got, want)
+			}
+			if got, want := p.srcConnID, test.srcConnID; !bytes.Equal(got, want) {
+				t.Errorf("Source Connection ID = {%x}, want {%x}", got, want)
+			}
+			if got, want := p.data, test.data; !bytes.Equal(got, want) {
+				t.Errorf("Data = {%x}, want {%x}", got, want)
+			}
+		})
+	}
+}
+
+func TestParseGenericLongHeaderPacketErrors(t *testing.T) {
+	for _, test := range []struct {
+		name   string
+		packet []byte
+	}{{
+		name: "short header packet",
+		packet: unhex(`
+			00 01020304 04a1a2a3a4 05b1b2b3b4b5 c1
+		`),
+	}, {
+		name: "packet too short",
+		packet: unhex(`
+			80 000000
+		`),
+	}, {
+		name: "destination id too long",
+		packet: unhex(`
+			80 00000000 02 00
+		`),
+	}, {
+		name: "source id too long",
+		packet: unhex(`
+			80 00000000 00 01
+		`),
+	}} {
+		t.Run(test.name, func(t *testing.T) {
+			_, ok := parseGenericLongHeaderPacket(test.packet)
+			if ok {
+				t.Fatalf("parseGenericLongHeaderPacket() = _, true; want false")
+			}
+		})
+	}
+}
+
 func unhex(s string) []byte {
 	b, err := hex.DecodeString(strings.Map(func(c rune) rune {
 		switch c {
diff --git a/internal/quic/quic.go b/internal/quic/quic.go
index cf4137e..9de97b6 100644
--- a/internal/quic/quic.go
+++ b/internal/quic/quic.go
@@ -10,6 +10,13 @@
 	"time"
 )
 
+// QUIC versions.
+// We only support v1 at this time.
+const (
+	quicVersion1 = 1
+	quicVersion2 = 0x6b3343cf // https://www.rfc-editor.org/rfc/rfc9369
+)
+
 // connIDLen is the length in bytes of connection IDs chosen by this package.
 // Since 1-RTT packets don't include a connection ID length field,
 // we use a consistent length for all our IDs.
diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go
index 4167076..81d17b8 100644
--- a/internal/quic/tls_test.go
+++ b/internal/quic/tls_test.go
@@ -97,7 +97,7 @@
 		packets: []*testPacket{{
 			ptype:     packetTypeInitial,
 			num:       0,
-			version:   1,
+			version:   quicVersion1,
 			srcConnID: clientConnIDs[0],
 			dstConnID: transientConnID,
 			frames: []debugFrame{
@@ -110,7 +110,7 @@
 		packets: []*testPacket{{
 			ptype:     packetTypeInitial,
 			num:       0,
-			version:   1,
+			version:   quicVersion1,
 			srcConnID: serverConnIDs[0],
 			dstConnID: clientConnIDs[0],
 			frames: []debugFrame{
@@ -122,7 +122,7 @@
 		}, {
 			ptype:     packetTypeHandshake,
 			num:       0,
-			version:   1,
+			version:   quicVersion1,
 			srcConnID: serverConnIDs[0],
 			dstConnID: clientConnIDs[0],
 			frames: []debugFrame{
@@ -144,7 +144,7 @@
 		packets: []*testPacket{{
 			ptype:     packetTypeInitial,
 			num:       1,
-			version:   1,
+			version:   quicVersion1,
 			srcConnID: clientConnIDs[0],
 			dstConnID: serverConnIDs[0],
 			frames: []debugFrame{
@@ -155,7 +155,7 @@
 		}, {
 			ptype:     packetTypeHandshake,
 			num:       0,
-			version:   1,
+			version:   quicVersion1,
 			srcConnID: clientConnIDs[0],
 			dstConnID: serverConnIDs[0],
 			frames: []debugFrame{
@@ -568,7 +568,7 @@
 		ptype:     packetType1RTT,
 		num:       1000,
 		frames:    []debugFrame{debugFramePing{}},
-		version:   1,
+		version:   quicVersion1,
 		dstConnID: dstConnID,
 		srcConnID: tc.peerConnID,
 	}, 0)
diff --git a/internal/quic/version_test.go b/internal/quic/version_test.go
new file mode 100644
index 0000000..cfb7ce4
--- /dev/null
+++ b/internal/quic/version_test.go
@@ -0,0 +1,110 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"bytes"
+	"context"
+	"crypto/tls"
+	"testing"
+)
+
+func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) {
+	config := &Config{
+		TLSConfig: newTestTLSConfig(serverSide),
+	}
+	tl := newTestListener(t, config, nil)
+
+	// Packet of unknown contents for some unrecognized QUIC version.
+	dstConnID := []byte{1, 2, 3, 4}
+	srcConnID := []byte{5, 6, 7, 8}
+	pkt := []byte{
+		0b1000_0000,
+		0x00, 0x00, 0x00, 0x0f,
+	}
+	pkt = append(pkt, byte(len(dstConnID)))
+	pkt = append(pkt, dstConnID...)
+	pkt = append(pkt, byte(len(srcConnID)))
+	pkt = append(pkt, srcConnID...)
+	for len(pkt) < minimumClientInitialDatagramSize {
+		pkt = append(pkt, 0)
+	}
+
+	tl.write(&datagram{
+		b: pkt,
+	})
+	gotPkt := tl.read()
+	if gotPkt == nil {
+		t.Fatalf("got no response; want Version Negotiaion")
+	}
+	if got := getPacketType(gotPkt); got != packetTypeVersionNegotiation {
+		t.Fatalf("got packet type %v; want Version Negotiaion", got)
+	}
+	gotDst, gotSrc, versions := parseVersionNegotiation(gotPkt)
+	if got, want := gotDst, srcConnID; !bytes.Equal(got, want) {
+		t.Errorf("got Destination Connection ID %x, want %x", got, want)
+	}
+	if got, want := gotSrc, dstConnID; !bytes.Equal(got, want) {
+		t.Errorf("got Source Connection ID %x, want %x", got, want)
+	}
+	if got, want := versions, []byte{0, 0, 0, 1}; !bytes.Equal(got, want) {
+		t.Errorf("got Supported Version %x, want %x", got, want)
+	}
+}
+
+func TestVersionNegotiationClientAborts(t *testing.T) {
+	tc := newTestConn(t, clientSide)
+	p := tc.readPacket() // client Initial packet
+	tc.listener.write(&datagram{
+		b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10),
+	})
+	tc.wantIdle("connection does not send a CONNECTION_CLOSE")
+	if err := tc.conn.waitReady(canceledContext()); err != errVersionNegotiation {
+		t.Errorf("conn.waitReady() = %v, want errVersionNegotiation", err)
+	}
+}
+
+func TestVersionNegotiationClientIgnoresAfterProcessingPacket(t *testing.T) {
+	tc := newTestConn(t, clientSide)
+	tc.ignoreFrame(frameTypeAck)
+	p := tc.readPacket() // client Initial packet
+	tc.writeFrames(packetTypeInitial,
+		debugFrameCrypto{
+			data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+		})
+	tc.listener.write(&datagram{
+		b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10),
+	})
+	if err := tc.conn.waitReady(canceledContext()); err != context.Canceled {
+		t.Errorf("conn.waitReady() = %v, want context.Canceled", err)
+	}
+	tc.writeFrames(packetTypeHandshake,
+		debugFrameCrypto{
+			data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+		})
+	tc.wantFrameType("conn ignores Version Negotiation and continues with handshake",
+		packetTypeHandshake, debugFrameCrypto{})
+}
+
+func TestVersionNegotiationClientIgnoresMismatchingSourceConnID(t *testing.T) {
+	tc := newTestConn(t, clientSide)
+	tc.ignoreFrame(frameTypeAck)
+	p := tc.readPacket() // client Initial packet
+	tc.listener.write(&datagram{
+		b: appendVersionNegotiation(nil, p.srcConnID, []byte("mismatch"), 10),
+	})
+	tc.writeFrames(packetTypeInitial,
+		debugFrameCrypto{
+			data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+		})
+	tc.writeFrames(packetTypeHandshake,
+		debugFrameCrypto{
+			data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+		})
+	tc.wantFrameType("conn ignores Version Negotiation and continues with handshake",
+		packetTypeHandshake, debugFrameCrypto{})
+}