quic: support stateless reset

Add a StatelessResetKey config field to permit
generating consistent stateless reset tokens
across server restarts.

Set the stateless_reset_token transport parameter
and populate the Token field in NEW_CONNECTION_ID
frames.

Detect stateless resets by checking for known
reset tokens in datagrams that cannot be associated
with a connection or cannot be parsed.

RFC 9000, Section 10.3.

For golang/go#58547

Change-Id: Idb52ba07092ab5c08b323d6b531964a7e7e5ecea
Reviewed-on: https://go-review.googlesource.com/c/net/+/536315
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
index 784c5e2..63feec9 100644
--- a/internal/quic/conn_id_test.go
+++ b/internal/quic/conn_id_test.go
@@ -47,12 +47,14 @@
 	if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
 		t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
 	}
-	wantRemote := []connID{{
-		cid: testPeerConnID(0),
-		seq: 0,
+	wantRemote := []remoteConnID{{
+		connID: connID{
+			cid: testPeerConnID(0),
+			seq: 0,
+		},
 	}}
-	if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
-		t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
+	if got := tc.conn.connIDState.remote; !remoteConnIDListEqual(got, wantRemote) {
+		t.Errorf("remote ids: %v, want %v", fmtRemoteConnIDList(got), fmtRemoteConnIDList(wantRemote))
 	}
 }
 
@@ -93,12 +95,14 @@
 	if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
 		t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
 	}
-	wantRemote := []connID{{
-		cid: testPeerConnID(0),
-		seq: 0,
+	wantRemote := []remoteConnID{{
+		connID: connID{
+			cid: testPeerConnID(0),
+			seq: 0,
+		},
 	}}
-	if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
-		t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
+	if got := tc.conn.connIDState.remote; !remoteConnIDListEqual(got, wantRemote) {
+		t.Errorf("remote ids: %v, want %v", fmtRemoteConnIDList(got), fmtRemoteConnIDList(wantRemote))
 	}
 
 	// The client's first Handshake packet permits the server to discard the
@@ -134,6 +138,24 @@
 	return true
 }
 
+// remoteConnIDListEqual reports whether a and b hold the same remote conn IDs,
+// in order, comparing each entry's seq, cid bytes, and reset token.
+func remoteConnIDListEqual(a, b []remoteConnID) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		x, y := &a[i], &b[i]
+		switch {
+		case x.seq != y.seq,
+			!bytes.Equal(x.cid, y.cid),
+			x.resetToken != y.resetToken:
+			return false
+		}
+	}
+	return true
+}
+
 func fmtConnIDList(s []connID) string {
 	var strs []string
 	for _, cid := range s {
@@ -142,6 +164,14 @@
 	return "{" + strings.Join(strs, " ") + "}"
 }
 
+func fmtRemoteConnIDList(s []remoteConnID) string {
+	var strs []string
+	for _, cid := range s {
+		strs = append(strs, fmt.Sprintf("[seq:%v cid:{%x} token:{%x}]", cid.seq, cid.cid, cid.resetToken))
+	}
+	return "{" + strings.Join(strs, " ") + "}"
+}
+
 func TestNewRandomConnID(t *testing.T) {
 	cid, err := newRandomConnID(0)
 	if len(cid) != connIDLen || err != nil {
@@ -174,16 +204,19 @@
 		packetType1RTT, debugFrameNewConnectionID{
 			seq:    1,
 			connID: testLocalConnID(1),
+			token:  testLocalStatelessResetToken(1),
 		})
 	tc.wantFrame("provide additional connection ID 2",
 		packetType1RTT, debugFrameNewConnectionID{
 			seq:    2,
 			connID: testLocalConnID(2),
+			token:  testLocalStatelessResetToken(2),
 		})
 	tc.wantFrame("provide additional connection ID 3",
 		packetType1RTT, debugFrameNewConnectionID{
 			seq:    3,
 			connID: testLocalConnID(3),
+			token:  testLocalStatelessResetToken(3),
 		})
 	tc.wantIdle("connection ID limit reached, no more to provide")
 }
@@ -255,6 +288,7 @@
 					seq:           2,
 					retirePriorTo: 1,
 					connID:        testLocalConnID(2),
+					token:         testLocalStatelessResetToken(2),
 				})
 		})
 	}
@@ -455,6 +489,7 @@
 			retirePriorTo: 1,
 			seq:           2,
 			connID:        testLocalConnID(2),
+			token:         testLocalStatelessResetToken(2),
 		})
 	tc.wantIdle("repeated RETIRE_CONNECTION_ID frames are not an error")
 }
@@ -583,3 +618,46 @@
 			})
 	})
 }
+
+func TestConnIDsCleanedUpAfterClose(t *testing.T) {
+	testSides(t, "", func(t *testing.T, side connSide) {
+		tc := newTestConn(t, side, func(p *transportParameters) {
+			if side == clientSide { // the peer (server) advertises reset token 0 in its transport params
+				token := testPeerStatelessResetToken(0)
+				p.statelessResetToken = token[:]
+			}
+		})
+		tc.handshake()
+		tc.ignoreFrame(frameTypeAck)
+		tc.writeFrames(packetType1RTT,
+			debugFrameNewConnectionID{
+				seq:           2,
+				retirePriorTo: 1,
+				connID:        testPeerConnID(2),
+				token:         testPeerStatelessResetToken(0), // NOTE(review): reuses the token-0 value set in the client-side transport params above; a distinct token may exercise cleanup more directly — confirm intent
+			})
+		tc.wantFrame("peer asked for conn id 0 to be retired",
+			packetType1RTT, debugFrameRetireConnectionID{
+				seq: 0,
+			})
+		tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{})
+		tc.conn.Abort(nil) // local close following the peer's CONNECTION_CLOSE
+		tc.wantFrame("CONN_CLOSE sent after user closes connection",
+			packetType1RTT, debugFrameConnectionCloseTransport{})
+
+		// Once the conn is gone, the listener's connsMap must hold none of its
+		// conn IDs or reset tokens. Advance to the next timer so the conn drains,
+		// wait for the conn loop to exit, then force an immediate sync of the
+		// connsMap updates (normally only done by the listener read loop).
+		tc.advanceToTimer()
+		<-tc.conn.donec
+		tc.listener.l.connsMap.applyUpdates()
+
+		if got := len(tc.listener.l.connsMap.byConnID); got != 0 {
+			t.Errorf("%v conn ids in listener map after closing, want 0", got)
+		}
+		if got := len(tc.listener.l.connsMap.byResetToken); got != 0 {
+			t.Errorf("%v reset tokens in listener map after closing, want 0", got)
+		}
+	})
+}