quic: connection ids

Each side of a QUIC connection chooses the connection IDs used by
its peer. In our case, we use 8-byte random IDs.

A connection has a list of connection IDs that it may receive
packets on, and a list that it may send packets to. Add a minimal
data structure for tracking these lists, and handling of the
connection IDs tracked across Initial and Handshake packets.

This does not yet handle post-handshake connection ID changes
made in NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames.

RFC 9000, Section 5.1.

For golang/go#58547

Change-Id: I3e059393cacafbcea04a1b4131c0c7dc28acad5e
Reviewed-on: https://go-review.googlesource.com/c/net/+/506675
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go
new file mode 100644
index 0000000..deea70d
--- /dev/null
+++ b/internal/quic/conn_id.go
@@ -0,0 +1,147 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"crypto/rand"
+)
+
+// connIDState is a conn's connection IDs.
+type connIDState struct {
+	// The destination connection IDs of packets we receive are local.
+	// The destination connection IDs of packets we send are remote.
+	//
+	// Local IDs are usually issued by us, and remote IDs by the peer.
+	// The exception is the transient destination connection ID sent in
+	// a client's Initial packets, which is chosen by the client.
+	local  []connID
+	remote []connID
+}
+
+// A connID is a connection ID and associated metadata.
+type connID struct {
+	// cid is the connection ID itself.
+	cid []byte
+
+	// seq is the connection ID's sequence number:
+	// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-1
+	//
+	// For the transient destination ID in a client's Initial packet, this is -1.
+	seq int64
+}
+
+func (s *connIDState) initClient(newID newConnIDFunc) error {
+	// Client chooses its initial connection ID, and sends it
+	// in the Source Connection ID field of the first Initial packet.
+	locid, err := newID()
+	if err != nil {
+		return err
+	}
+	s.local = append(s.local, connID{
+		seq: 0,
+		cid: locid,
+	})
+
+	// Client chooses an initial, transient connection ID for the server,
+	// and sends it in the Destination Connection ID field of the first Initial packet.
+	remid, err := newID()
+	if err != nil {
+		return err
+	}
+	s.remote = append(s.remote, connID{
+		seq: -1,
+		cid: remid,
+	})
+	return nil
+}
+
+func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error {
+	// Client-chosen, transient connection ID received in the first Initial packet.
+	// The server will not use this as the Source Connection ID of packets it sends,
+	// but remembers it because it may receive packets sent to this destination.
+	s.local = append(s.local, connID{
+		seq: -1,
+		cid: cloneBytes(dstConnID),
+	})
+
+	// Server chooses a connection ID, and sends it in the Source Connection ID of
+	// the response to the clent.
+	locid, err := newID()
+	if err != nil {
+		return err
+	}
+	s.local = append(s.local, connID{
+		seq: 0,
+		cid: locid,
+	})
+	return nil
+}
+
+// srcConnID is the Source Connection ID to use in a sent packet.
+func (s *connIDState) srcConnID() []byte {
+	if s.local[0].seq == -1 && len(s.local) > 1 {
+		// Don't use the transient connection ID if another is available.
+		return s.local[1].cid
+	}
+	return s.local[0].cid
+}
+
+// dstConnID is the Destination Connection ID to use in a sent packet.
+func (s *connIDState) dstConnID() []byte {
+	return s.remote[0].cid
+}
+
+// handlePacket updates the connection ID state during the handshake
+// (Initial and Handshake packets).
+func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID []byte) {
+	switch {
+	case ptype == packetTypeInitial && side == clientSide:
+		if len(s.remote) == 1 && s.remote[0].seq == -1 {
+			// We're a client connection processing the first Initial packet
+			// from the server. Replace the transient remote connection ID
+			// with the Source Connection ID from the packet.
+			s.remote[0] = connID{
+				seq: 0,
+				cid: cloneBytes(srcConnID),
+			}
+		}
+	case ptype == packetTypeInitial && side == serverSide:
+		if len(s.remote) == 0 {
+			// We're a server connection processing the first Initial packet
+			// from the client. Set the client's connection ID.
+			s.remote = append(s.remote, connID{
+				seq: 0,
+				cid: cloneBytes(srcConnID),
+			})
+		}
+	case ptype == packetTypeHandshake && side == serverSide:
+		if len(s.local) > 0 && s.local[0].seq == -1 {
+			// We're a server connection processing the first Handshake packet from
+			// the client. Discard the transient, client-chosen connection ID used
+			// for Initial packets; the client will never send it again.
+			s.local = append(s.local[:0], s.local[1:]...)
+		}
+	}
+}
+
+func cloneBytes(b []byte) []byte {
+	n := make([]byte, len(b))
+	copy(n, b)
+	return n
+}
+
+type newConnIDFunc func() ([]byte, error)
+
+func newRandomConnID() ([]byte, error) {
+	// It is not necessary for connection IDs to be cryptographically secure,
+	// but it doesn't hurt.
+	id := make([]byte, connIDLen)
+	if _, err := rand.Read(id); err != nil {
+		return nil, err
+	}
+	return id, nil
+}
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
new file mode 100644
index 0000000..7c31e9d
--- /dev/null
+++ b/internal/quic/conn_id_test.go
@@ -0,0 +1,109 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"fmt"
+	"reflect"
+	"testing"
+)
+
+func TestConnIDClientHandshake(t *testing.T) {
+	// On initialization, the client chooses local and remote IDs.
+	//
+	// The order in which we allocate the two isn't actually important,
+	// but test is a lot simpler if we assume.
+	var s connIDState
+	s.initClient(newConnIDSequence())
+	if got, want := string(s.srcConnID()), "local-1"; got != want {
+		t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
+	}
+	if got, want := string(s.dstConnID()), "local-2"; got != want {
+		t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
+	}
+
+	// The server's first Initial packet provides the client with a
+	// non-transient remote connection ID.
+	s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1"))
+	if got, want := string(s.dstConnID()), "remote-1"; got != want {
+		t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want)
+	}
+
+	wantLocal := []connID{{
+		cid: []byte("local-1"),
+		seq: 0,
+	}}
+	if !reflect.DeepEqual(s.local, wantLocal) {
+		t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+	}
+	wantRemote := []connID{{
+		cid: []byte("remote-1"),
+		seq: 0,
+	}}
+	if !reflect.DeepEqual(s.remote, wantRemote) {
+		t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+	}
+}
+
+func TestConnIDServerHandshake(t *testing.T) {
+	// On initialization, the server is provided with the client-chosen
+	// transient connection ID, and allocates an ID of its own.
+	// The Initial packet sets the remote connection ID.
+	var s connIDState
+	s.initServer(newConnIDSequence(), []byte("transient"))
+	s.handlePacket(serverSide, packetTypeInitial, []byte("remote-1"))
+	if got, want := string(s.srcConnID()), "local-1"; got != want {
+		t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
+	}
+	if got, want := string(s.dstConnID()), "remote-1"; got != want {
+		t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
+	}
+
+	wantLocal := []connID{{
+		cid: []byte("transient"),
+		seq: -1,
+	}, {
+		cid: []byte("local-1"),
+		seq: 0,
+	}}
+	if !reflect.DeepEqual(s.local, wantLocal) {
+		t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+	}
+	wantRemote := []connID{{
+		cid: []byte("remote-1"),
+		seq: 0,
+	}}
+	if !reflect.DeepEqual(s.remote, wantRemote) {
+		t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+	}
+
+	// The client's first Handshake packet permits the server to discard the
+	// transient connection ID.
+	s.handlePacket(serverSide, packetTypeHandshake, []byte("remote-1"))
+	wantLocal = []connID{{
+		cid: []byte("local-1"),
+		seq: 0,
+	}}
+	if !reflect.DeepEqual(s.local, wantLocal) {
+		t.Errorf("after handshake local ids: %v, want %v", s.local, wantLocal)
+	}
+}
+
+func newConnIDSequence() newConnIDFunc {
+	var n uint64
+	return func() ([]byte, error) {
+		n++
+		return []byte(fmt.Sprintf("local-%v", n)), nil
+	}
+}
+
+func TestNewRandomConnID(t *testing.T) {
+	cid, err := newRandomConnID()
+	if len(cid) != connIDLen || err != nil {
+		t.Fatalf("newConnID() = %x, %v; want %v bytes", cid, connIDLen, err)
+	}
+}