quic: validate connection id transport parameters

Validate the original_destination_connection_id and
initial_source_connection_id transport parameters.

RFC 9000, Section 7.3

For golang/go#58547

Change-Id: I8343fd53c5cc946f15d3410c632b3895205fd597
Reviewed-on: https://go-review.googlesource.com/c/net/+/530036
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index 6097912..9db00fe 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -86,6 +86,7 @@
 	// non-blocking operation.
 	c.msgc = make(chan any, 1)
 
+	var originalDstConnID []byte
 	if c.side == clientSide {
 		if err := c.connIDState.initClient(c); err != nil {
 			return nil, err
@@ -95,6 +96,7 @@
 		if err := c.connIDState.initServer(c, initialConnID); err != nil {
 			return nil, err
 		}
+		originalDstConnID = initialConnID
 	}
 
 	// The smallest allowed maximum QUIC datagram size is 1200 bytes.
@@ -105,9 +107,10 @@
 	c.streamsInit()
 	c.lifetimeInit()
 
-	// TODO: initial_source_connection_id, retry_source_connection_id
+	// TODO: retry_source_connection_id
 	if err := c.startTLS(now, initialConnID, transportParameters{
 		initialSrcConnID:               c.connIDState.srcConnID(),
+		originalDstConnID:              originalDstConnID,
 		ackDelayExponent:               ackDelayExponent,
 		maxUDPPayloadSize:              maxUDPPayloadSize,
 		maxAckDelay:                    maxAckDelay,
@@ -171,6 +174,9 @@
 
 // receiveTransportParameters applies transport parameters sent by the peer.
 func (c *Conn) receiveTransportParameters(p transportParameters) error {
+	if err := c.connIDState.validateTransportParameters(c.side, p); err != nil {
+		return err
+	}
 	c.streams.outflow.setMaxData(p.initialMaxData)
 	c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi)
 	c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni)
diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go
index eb2f3ec..045e646 100644
--- a/internal/quic/conn_id.go
+++ b/internal/quic/conn_id.go
@@ -161,6 +161,39 @@
 	return nil
 }
 
+// validateTransportParameters verifies the original_destination_connection_id and
+// initial_source_connection_id transport parameters match the expected values.
+func (s *connIDState) validateTransportParameters(side connSide, p transportParameters) error {
+	// TODO: Consider returning more detailed errors, for debugging.
+	switch side {
+	case clientSide:
+		// Verify original_destination_connection_id matches
+		// the transient remote connection ID we chose.
+		if len(s.remote) == 0 || s.remote[0].seq != -1 {
+			return localTransportError(errInternal)
+		}
+		if !bytes.Equal(s.remote[0].cid, p.originalDstConnID) {
+			return localTransportError(errTransportParameter)
+		}
+		// Remove the transient remote connection ID.
+		// We have no further need for it.
+		s.remote = append(s.remote[:0], s.remote[1:]...)
+	case serverSide:
+		if p.originalDstConnID != nil {
+			// Clients do not send original_destination_connection_id.
+			return localTransportError(errTransportParameter)
+		}
+	}
+	// Verify initial_source_connection_id matches the first remote connection ID.
+	if len(s.remote) == 0 || s.remote[0].seq != 0 {
+		return localTransportError(errInternal)
+	}
+	if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
+		return localTransportError(errTransportParameter)
+	}
+	return nil
+}
+
 // handlePacket updates the connection ID state during the handshake
 // (Initial and Handshake packets).
 func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
@@ -170,10 +203,13 @@
 			// We're a client connection processing the first Initial packet
 			// from the server. Replace the transient remote connection ID
 			// with the Source Connection ID from the packet.
-			s.remote[0] = connID{
+			// Leave the transient ID the list for now, since we'll need it when
+			// processing the transport parameters.
+			s.remote[0].retired = true
+			s.remote = append(s.remote, connID{
 				seq: 0,
 				cid: cloneBytes(srcConnID),
-			}
+			})
 		}
 	case ptype == packetTypeInitial && c.side == serverSide:
 		if len(s.remote) == 0 {
@@ -185,7 +221,7 @@
 			})
 		}
 	case ptype == packetTypeHandshake && c.side == serverSide:
-		if len(s.local) > 0 && s.local[0].seq == -1 {
+		if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired {
 			// We're a server connection processing the first Handshake packet from
 			// the client. Discard the transient, client-chosen connection ID used
 			// for Initial packets; the client will never send it again.
@@ -213,7 +249,7 @@
 	active := 0
 	for i := range s.remote {
 		rcid := &s.remote[i]
-		if !rcid.retired && rcid.seq < s.retireRemotePriorTo {
+		if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo {
 			s.retireRemote(rcid)
 		}
 		if !rcid.retired {
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
index c528958..44755ec 100644
--- a/internal/quic/conn_id_test.go
+++ b/internal/quic/conn_id_test.go
@@ -48,6 +48,9 @@
 		t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
 	}
 	wantRemote := []connID{{
+		cid: testLocalConnID(-1),
+		seq: -1,
+	}, {
 		cid: testPeerConnID(0),
 		seq: 0,
 	}}
@@ -261,10 +264,12 @@
 }
 
 func TestConnIDPeerWithZeroLengthConnIDSendsNewConnectionID(t *testing.T) {
-	// An endpoint that selects a zero-length connection ID during the handshake
+	// "An endpoint that selects a zero-length connection ID during the handshake
 	// cannot issue a new connection ID."
 	// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-8
-	tc := newTestConn(t, clientSide)
+	tc := newTestConn(t, clientSide, func(p *transportParameters) {
+		p.initialSrcConnID = []byte{}
+	})
 	tc.peerConnID = []byte{}
 	tc.ignoreFrame(frameTypeAck)
 	tc.uncheckedHandshake()
@@ -536,6 +541,7 @@
 	// Peer gives us more conn ids than our advertised limit,
 	// including a conn id in the preferred address transport parameter.
 	tc := newTestConn(t, serverSide, func(p *transportParameters) {
+		p.initialSrcConnID = []byte{}
 		p.preferredAddrV4 = netip.MustParseAddrPort("0.0.0.0:0")
 		p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0")
 		p.preferredAddrConnID = testPeerConnID(1)
@@ -552,3 +558,31 @@
 			code: errProtocolViolation,
 		})
 }
+
+func TestConnIDInitialSrcConnIDMismatch(t *testing.T) {
+	// "Endpoints MUST validate that received [initial_source_connection_id]
+	// parameters match received connection ID values."
+	// https://www.rfc-editor.org/rfc/rfc9000#section-7.3-3
+	testSides(t, "", func(t *testing.T, side connSide) {
+		tc := newTestConn(t, side, func(p *transportParameters) {
+			p.initialSrcConnID = []byte("invalid")
+		})
+		tc.ignoreFrame(frameTypeAck)
+		tc.ignoreFrame(frameTypeCrypto)
+		tc.writeFrames(packetTypeInitial,
+			debugFrameCrypto{
+				data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+			})
+		if side == clientSide {
+			// Server transport parameters are carried in the Handshake packet.
+			tc.writeFrames(packetTypeHandshake,
+				debugFrameCrypto{
+					data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+				})
+		}
+		tc.wantFrame("initial_source_connection_id transport parameter mismatch",
+			packetTypeInitial, debugFrameConnectionCloseTransport{
+				code: errTransportParameter,
+			})
+	})
+}
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index fd9e6e4..6a359e8 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -201,6 +201,10 @@
 		TLSConfig: newTestTLSConfig(side),
 	}
 	peerProvidedParams := defaultTransportParameters()
+	peerProvidedParams.initialSrcConnID = testPeerConnID(0)
+	if side == clientSide {
+		peerProvidedParams.originalDstConnID = testLocalConnID(-1)
+	}
 	for _, o := range opts {
 		switch o := o.(type) {
 		case func(*Config):