quic: validate connection id transport parameters
Validate the original_destination_connection_id and
initial_source_connection_id transport parameters.
RFC 9000, Section 7.3
For golang/go#58547
Change-Id: I8343fd53c5cc946f15d3410c632b3895205fd597
Reviewed-on: https://go-review.googlesource.com/c/net/+/530036
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index 6097912..9db00fe 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -86,6 +86,7 @@
// non-blocking operation.
c.msgc = make(chan any, 1)
+ var originalDstConnID []byte
if c.side == clientSide {
if err := c.connIDState.initClient(c); err != nil {
return nil, err
@@ -95,6 +96,7 @@
if err := c.connIDState.initServer(c, initialConnID); err != nil {
return nil, err
}
+ originalDstConnID = initialConnID
}
// The smallest allowed maximum QUIC datagram size is 1200 bytes.
@@ -105,9 +107,10 @@
c.streamsInit()
c.lifetimeInit()
- // TODO: initial_source_connection_id, retry_source_connection_id
+ // TODO: retry_source_connection_id
if err := c.startTLS(now, initialConnID, transportParameters{
initialSrcConnID: c.connIDState.srcConnID(),
+ originalDstConnID: originalDstConnID,
ackDelayExponent: ackDelayExponent,
maxUDPPayloadSize: maxUDPPayloadSize,
maxAckDelay: maxAckDelay,
@@ -171,6 +174,9 @@
// receiveTransportParameters applies transport parameters sent by the peer.
func (c *Conn) receiveTransportParameters(p transportParameters) error {
+ if err := c.connIDState.validateTransportParameters(c.side, p); err != nil {
+ return err
+ }
c.streams.outflow.setMaxData(p.initialMaxData)
c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi)
c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni)
diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go
index eb2f3ec..045e646 100644
--- a/internal/quic/conn_id.go
+++ b/internal/quic/conn_id.go
@@ -161,6 +161,39 @@
return nil
}
+// validateTransportParameters verifies the original_destination_connection_id and
+// initial_source_connection_id transport parameters match the expected values.
+func (s *connIDState) validateTransportParameters(side connSide, p transportParameters) error {
+ // TODO: Consider returning more detailed errors, for debugging.
+ switch side {
+ case clientSide:
+ // Verify original_destination_connection_id matches
+ // the transient remote connection ID we chose.
+ if len(s.remote) == 0 || s.remote[0].seq != -1 {
+ return localTransportError(errInternal)
+ }
+ if !bytes.Equal(s.remote[0].cid, p.originalDstConnID) {
+ return localTransportError(errTransportParameter)
+ }
+ // Remove the transient remote connection ID.
+ // We have no further need for it.
+ s.remote = append(s.remote[:0], s.remote[1:]...)
+ case serverSide:
+ if p.originalDstConnID != nil {
+ // Clients do not send original_destination_connection_id.
+ return localTransportError(errTransportParameter)
+ }
+ }
+ // Verify initial_source_connection_id matches the first remote connection ID.
+ if len(s.remote) == 0 || s.remote[0].seq != 0 {
+ return localTransportError(errInternal)
+ }
+ if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
+ return localTransportError(errTransportParameter)
+ }
+ return nil
+}
+
// handlePacket updates the connection ID state during the handshake
// (Initial and Handshake packets).
func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
@@ -170,10 +203,13 @@
// We're a client connection processing the first Initial packet
// from the server. Replace the transient remote connection ID
// with the Source Connection ID from the packet.
- s.remote[0] = connID{
+ // Leave the transient ID the list for now, since we'll need it when
+ // processing the transport parameters.
+ s.remote[0].retired = true
+ s.remote = append(s.remote, connID{
seq: 0,
cid: cloneBytes(srcConnID),
- }
+ })
}
case ptype == packetTypeInitial && c.side == serverSide:
if len(s.remote) == 0 {
@@ -185,7 +221,7 @@
})
}
case ptype == packetTypeHandshake && c.side == serverSide:
- if len(s.local) > 0 && s.local[0].seq == -1 {
+ if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired {
// We're a server connection processing the first Handshake packet from
// the client. Discard the transient, client-chosen connection ID used
// for Initial packets; the client will never send it again.
@@ -213,7 +249,7 @@
active := 0
for i := range s.remote {
rcid := &s.remote[i]
- if !rcid.retired && rcid.seq < s.retireRemotePriorTo {
+ if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo {
s.retireRemote(rcid)
}
if !rcid.retired {
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
index c528958..44755ec 100644
--- a/internal/quic/conn_id_test.go
+++ b/internal/quic/conn_id_test.go
@@ -48,6 +48,9 @@
t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
wantRemote := []connID{{
+ cid: testLocalConnID(-1),
+ seq: -1,
+ }, {
cid: testPeerConnID(0),
seq: 0,
}}
@@ -261,10 +264,12 @@
}
func TestConnIDPeerWithZeroLengthConnIDSendsNewConnectionID(t *testing.T) {
- // An endpoint that selects a zero-length connection ID during the handshake
+ // "An endpoint that selects a zero-length connection ID during the handshake
// cannot issue a new connection ID."
// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-8
- tc := newTestConn(t, clientSide)
+ tc := newTestConn(t, clientSide, func(p *transportParameters) {
+ p.initialSrcConnID = []byte{}
+ })
tc.peerConnID = []byte{}
tc.ignoreFrame(frameTypeAck)
tc.uncheckedHandshake()
@@ -536,6 +541,7 @@
// Peer gives us more conn ids than our advertised limit,
// including a conn id in the preferred address transport parameter.
tc := newTestConn(t, serverSide, func(p *transportParameters) {
+ p.initialSrcConnID = []byte{}
p.preferredAddrV4 = netip.MustParseAddrPort("0.0.0.0:0")
p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0")
p.preferredAddrConnID = testPeerConnID(1)
@@ -552,3 +558,31 @@
code: errProtocolViolation,
})
}
+
+func TestConnIDInitialSrcConnIDMismatch(t *testing.T) {
+ // "Endpoints MUST validate that received [initial_source_connection_id]
+ // parameters match received connection ID values."
+ // https://www.rfc-editor.org/rfc/rfc9000#section-7.3-3
+ testSides(t, "", func(t *testing.T, side connSide) {
+ tc := newTestConn(t, side, func(p *transportParameters) {
+ p.initialSrcConnID = []byte("invalid")
+ })
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeCrypto)
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ if side == clientSide {
+ // Server transport parameters are carried in the Handshake packet.
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ }
+ tc.wantFrame("initial_source_connection_id transport parameter mismatch",
+ packetTypeInitial, debugFrameConnectionCloseTransport{
+ code: errTransportParameter,
+ })
+ })
+}
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index fd9e6e4..6a359e8 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -201,6 +201,10 @@
TLSConfig: newTestTLSConfig(side),
}
peerProvidedParams := defaultTransportParameters()
+ peerProvidedParams.initialSrcConnID = testPeerConnID(0)
+ if side == clientSide {
+ peerProvidedParams.originalDstConnID = testLocalConnID(-1)
+ }
for _, o := range opts {
switch o := o.(type) {
case func(*Config):