quic: connection ids
Each side of a QUIC connection chooses the connection IDs used by
its peer. In our case, we use 8-byte random IDs.
A connection has a list of connection IDs that it may receive
packets on, and a list that it may send packets to. Add a minimal
data structure for tracking these lists, and handling of the
connection IDs tracked across Initial and Handshake packets.
This does not yet handle post-handshake connection ID changes
made in NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames.
RFC 9000, Section 5.1.
For golang/go#58547
Change-Id: I3e059393cacafbcea04a1b4131c0c7dc28acad5e
Reviewed-on: https://go-review.googlesource.com/c/net/+/506675
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go
new file mode 100644
index 0000000..deea70d
--- /dev/null
+++ b/internal/quic/conn_id.go
@@ -0,0 +1,147 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/rand"
+)
+
+// connIDState is a conn's connection IDs.
+type connIDState struct {
+ // The destination connection IDs of packets we receive are local.
+ // The destination connection IDs of packets we send are remote.
+ //
+ // Local IDs are usually issued by us, and remote IDs by the peer.
+ // The exception is the transient destination connection ID sent in
+ // a client's Initial packets, which is chosen by the client.
+ local []connID
+ remote []connID
+}
+
+// A connID is a connection ID and associated metadata.
+type connID struct {
+ // cid is the connection ID itself.
+ cid []byte
+
+ // seq is the connection ID's sequence number:
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-1
+ //
+ // For the transient destination ID in a client's Initial packet, this is -1.
+ seq int64
+}
+
+func (s *connIDState) initClient(newID newConnIDFunc) error {
+ // Client chooses its initial connection ID, and sends it
+ // in the Source Connection ID field of the first Initial packet.
+ locid, err := newID()
+ if err != nil {
+ return err
+ }
+ s.local = append(s.local, connID{
+ seq: 0,
+ cid: locid,
+ })
+
+ // Client chooses an initial, transient connection ID for the server,
+ // and sends it in the Destination Connection ID field of the first Initial packet.
+ remid, err := newID()
+ if err != nil {
+ return err
+ }
+ s.remote = append(s.remote, connID{
+ seq: -1,
+ cid: remid,
+ })
+ return nil
+}
+
+func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error {
+ // Client-chosen, transient connection ID received in the first Initial packet.
+ // The server will not use this as the Source Connection ID of packets it sends,
+ // but remembers it because it may receive packets sent to this destination.
+ s.local = append(s.local, connID{
+ seq: -1,
+ cid: cloneBytes(dstConnID),
+ })
+
+ // Server chooses a connection ID, and sends it in the Source Connection ID of
+ // the response to the clent.
+ locid, err := newID()
+ if err != nil {
+ return err
+ }
+ s.local = append(s.local, connID{
+ seq: 0,
+ cid: locid,
+ })
+ return nil
+}
+
+// srcConnID is the Source Connection ID to use in a sent packet.
+func (s *connIDState) srcConnID() []byte {
+ if s.local[0].seq == -1 && len(s.local) > 1 {
+ // Don't use the transient connection ID if another is available.
+ return s.local[1].cid
+ }
+ return s.local[0].cid
+}
+
+// dstConnID is the Destination Connection ID to use in a sent packet.
+func (s *connIDState) dstConnID() []byte {
+ return s.remote[0].cid
+}
+
+// handlePacket updates the connection ID state during the handshake
+// (Initial and Handshake packets).
+func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID []byte) {
+ switch {
+ case ptype == packetTypeInitial && side == clientSide:
+ if len(s.remote) == 1 && s.remote[0].seq == -1 {
+ // We're a client connection processing the first Initial packet
+ // from the server. Replace the transient remote connection ID
+ // with the Source Connection ID from the packet.
+ s.remote[0] = connID{
+ seq: 0,
+ cid: cloneBytes(srcConnID),
+ }
+ }
+ case ptype == packetTypeInitial && side == serverSide:
+ if len(s.remote) == 0 {
+ // We're a server connection processing the first Initial packet
+ // from the client. Set the client's connection ID.
+ s.remote = append(s.remote, connID{
+ seq: 0,
+ cid: cloneBytes(srcConnID),
+ })
+ }
+ case ptype == packetTypeHandshake && side == serverSide:
+ if len(s.local) > 0 && s.local[0].seq == -1 {
+ // We're a server connection processing the first Handshake packet from
+ // the client. Discard the transient, client-chosen connection ID used
+ // for Initial packets; the client will never send it again.
+ s.local = append(s.local[:0], s.local[1:]...)
+ }
+ }
+}
+
+func cloneBytes(b []byte) []byte {
+ n := make([]byte, len(b))
+ copy(n, b)
+ return n
+}
+
+type newConnIDFunc func() ([]byte, error)
+
+func newRandomConnID() ([]byte, error) {
+ // It is not necessary for connection IDs to be cryptographically secure,
+ // but it doesn't hurt.
+ id := make([]byte, connIDLen)
+ if _, err := rand.Read(id); err != nil {
+ return nil, err
+ }
+ return id, nil
+}
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
new file mode 100644
index 0000000..7c31e9d
--- /dev/null
+++ b/internal/quic/conn_id_test.go
@@ -0,0 +1,109 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+)
+
+func TestConnIDClientHandshake(t *testing.T) {
+ // On initialization, the client chooses local and remote IDs.
+ //
+ // The order in which we allocate the two isn't actually important,
+ // but test is a lot simpler if we assume.
+ var s connIDState
+ s.initClient(newConnIDSequence())
+ if got, want := string(s.srcConnID()), "local-1"; got != want {
+ t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
+ }
+ if got, want := string(s.dstConnID()), "local-2"; got != want {
+ t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
+ }
+
+ // The server's first Initial packet provides the client with a
+ // non-transient remote connection ID.
+ s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1"))
+ if got, want := string(s.dstConnID()), "remote-1"; got != want {
+ t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want)
+ }
+
+ wantLocal := []connID{{
+ cid: []byte("local-1"),
+ seq: 0,
+ }}
+ if !reflect.DeepEqual(s.local, wantLocal) {
+ t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+ }
+ wantRemote := []connID{{
+ cid: []byte("remote-1"),
+ seq: 0,
+ }}
+ if !reflect.DeepEqual(s.remote, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+ }
+}
+
+func TestConnIDServerHandshake(t *testing.T) {
+ // On initialization, the server is provided with the client-chosen
+ // transient connection ID, and allocates an ID of its own.
+ // The Initial packet sets the remote connection ID.
+ var s connIDState
+ s.initServer(newConnIDSequence(), []byte("transient"))
+ s.handlePacket(serverSide, packetTypeInitial, []byte("remote-1"))
+ if got, want := string(s.srcConnID()), "local-1"; got != want {
+ t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
+ }
+ if got, want := string(s.dstConnID()), "remote-1"; got != want {
+ t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
+ }
+
+ wantLocal := []connID{{
+ cid: []byte("transient"),
+ seq: -1,
+ }, {
+ cid: []byte("local-1"),
+ seq: 0,
+ }}
+ if !reflect.DeepEqual(s.local, wantLocal) {
+ t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+ }
+ wantRemote := []connID{{
+ cid: []byte("remote-1"),
+ seq: 0,
+ }}
+ if !reflect.DeepEqual(s.remote, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+ }
+
+ // The client's first Handshake packet permits the server to discard the
+ // transient connection ID.
+ s.handlePacket(serverSide, packetTypeHandshake, []byte("remote-1"))
+ wantLocal = []connID{{
+ cid: []byte("local-1"),
+ seq: 0,
+ }}
+ if !reflect.DeepEqual(s.local, wantLocal) {
+ t.Errorf("after handshake local ids: %v, want %v", s.local, wantLocal)
+ }
+}
+
+func newConnIDSequence() newConnIDFunc {
+ var n uint64
+ return func() ([]byte, error) {
+ n++
+ return []byte(fmt.Sprintf("local-%v", n)), nil
+ }
+}
+
+func TestNewRandomConnID(t *testing.T) {
+ cid, err := newRandomConnID()
+ if len(cid) != connIDLen || err != nil {
+ t.Fatalf("newConnID() = %x, %v; want %v bytes", cid, connIDLen, err)
+ }
+}