quic: support stateless reset

Add a StatelessResetKey config field to permit
generating consistent stateless reset tokens
across server restarts.

Set the stateless_reset_token transport parameter
and populate the Token field in NEW_CONNECTION_ID
frames.

Detect reset tokens in datagrams which cannot
be associated with a connection or cannot be parsed.

RFC 9000, Section 10.3.

For golang/go#58547

Change-Id: Idb52ba07092ab5c08b323d6b531964a7e7e5ecea
Reviewed-on: https://go-review.googlesource.com/c/net/+/536315
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
diff --git a/internal/quic/config.go b/internal/quic/config.go
index 99ef68f..6278bf8 100644
--- a/internal/quic/config.go
+++ b/internal/quic/config.go
@@ -55,6 +55,23 @@
 	// source address information can cause a server to perform,
 	// at the cost of increased handshake latency.
 	RequireAddressValidation bool
+
+	// StatelessResetKey is used to provide stateless reset of connections.
+	// A restart may leave an endpoint without access to the state of
+	// existing connections. Stateless reset permits an endpoint to respond
+	// to a packet for a connection it does not recognize.
+	//
+	// This field should be filled with random bytes.
+	// The contents should remain stable across restarts,
+	// to permit an endpoint to send a reset for
+	// connections created before a restart.
+	//
+	// The contents of the StatelessResetKey should not be exposed.
+	// An attacker can use knowledge of this field's value to
+	// reset existing connections.
+	//
+	// If this field is left as zero, stateless reset is disabled.
+	StatelessResetKey [32]byte
 }
 
 func configDefault(v, def, limit int64) int64 {
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index 4acf5dd..b3d6fea 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -203,7 +203,7 @@
 // receiveTransportParameters applies transport parameters sent by the peer.
 func (c *Conn) receiveTransportParameters(p transportParameters) error {
 	isRetry := c.retryToken != nil
-	if err := c.connIDState.validateTransportParameters(c.side, isRetry, p); err != nil {
+	if err := c.connIDState.validateTransportParameters(c, isRetry, p); err != nil {
 		return err
 	}
 	c.streams.outflow.setMaxData(p.initialMaxData)
@@ -224,7 +224,7 @@
 			resetToken    [16]byte
 		)
 		copy(resetToken[:], p.preferredAddrResetToken)
-		if err := c.connIDState.handleNewConnID(seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
+		if err := c.connIDState.handleNewConnID(c, seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
 			return err
 		}
 	}
diff --git a/internal/quic/conn_close_test.go b/internal/quic/conn_close_test.go
index 20c00e7..d5c3499 100644
--- a/internal/quic/conn_close_test.go
+++ b/internal/quic/conn_close_test.go
@@ -15,7 +15,9 @@
 )
 
 func TestConnCloseResponseBackoff(t *testing.T) {
-	tc := newTestConn(t, clientSide)
+	tc := newTestConn(t, clientSide, func(c *Config) {
+		clear(c.StatelessResetKey[:])
+	})
 	tc.handshake()
 
 	tc.conn.Abort(nil)
diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go
index ff7e2d1..c236137 100644
--- a/internal/quic/conn_id.go
+++ b/internal/quic/conn_id.go
@@ -22,7 +22,7 @@
 	//
 	// These are []connID rather than []*connID to minimize allocations.
 	local  []connID
-	remote []connID
+	remote []remoteConnID
 
 	nextLocalSeq          int64
 	retireRemotePriorTo   int64 // largest Retire Prior To value sent by the peer
@@ -58,6 +58,12 @@
 	send sentVal
 }
 
+// A remoteConnID is a connection ID and stateless reset token.
+type remoteConnID struct {
+	connID
+	resetToken statelessResetToken
+}
+
 func (s *connIDState) initClient(c *Conn) error {
 	// Client chooses its initial connection ID, and sends it
 	// in the Source Connection ID field of the first Initial packet.
@@ -70,6 +76,9 @@
 		cid: locid,
 	})
 	s.nextLocalSeq = 1
+	c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+		conns.addConnID(c, locid)
+	})
 
 	// Client chooses an initial, transient connection ID for the server,
 	// and sends it in the Destination Connection ID field of the first Initial packet.
@@ -77,13 +86,13 @@
 	if err != nil {
 		return err
 	}
-	s.remote = append(s.remote, connID{
-		seq: -1,
-		cid: remid,
+	s.remote = append(s.remote, remoteConnID{
+		connID: connID{
+			seq: -1,
+			cid: remid,
+		},
 	})
 	s.originalDstConnID = remid
-	const retired = false
-	c.listener.connIDsChanged(c, retired, s.local[:])
 	return nil
 }
 
@@ -107,8 +116,10 @@
 		cid: locid,
 	})
 	s.nextLocalSeq = 1
-	const retired = false
-	c.listener.connIDsChanged(c, retired, s.local[:])
+	c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+		conns.addConnID(c, dstConnID)
+		conns.addConnID(c, locid)
+	})
 	return nil
 }
 
@@ -131,6 +142,19 @@
 	return nil, false
 }
 
+// isValidStatelessResetToken reports whether the given reset token is
+// associated with a non-retired connection ID which we have used.
+func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool {
+	for i := range s.remote {
+		// We currently only use the first available remote connection ID,
+		// so any other reset token is not valid.
+		if !s.remote[i].retired {
+			return s.remote[i].resetToken == resetToken
+		}
+	}
+	return false
+}
+
 // setPeerActiveConnIDLimit sets the active_connection_id_limit
 // transport parameter received from the peer.
 func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
@@ -145,12 +169,13 @@
 			toIssue--
 		}
 	}
-	prev := len(s.local)
+	var newIDs [][]byte
 	for toIssue > 0 {
 		cid, err := c.newConnID(s.nextLocalSeq)
 		if err != nil {
 			return err
 		}
+		newIDs = append(newIDs, cid)
 		s.local = append(s.local, connID{
 			seq: s.nextLocalSeq,
 			cid: cid,
@@ -160,14 +185,17 @@
 		s.needSend = true
 		toIssue--
 	}
-	const retired = false
-	c.listener.connIDsChanged(c, retired, s.local[prev:])
+	c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+		for _, cid := range newIDs {
+			conns.addConnID(c, cid)
+		}
+	})
 	return nil
 }
 
 // validateTransportParameters verifies the original_destination_connection_id and
 // initial_source_connection_id transport parameters match the expected values.
-func (s *connIDState) validateTransportParameters(side connSide, isRetry bool, p transportParameters) error {
+func (s *connIDState) validateTransportParameters(c *Conn, isRetry bool, p transportParameters) error {
 	// TODO: Consider returning more detailed errors, for debugging.
 	// Verify original_destination_connection_id matches
 	// the transient remote connection ID we chose (client)
@@ -189,6 +217,16 @@
 	if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
 		return localTransportError(errTransportParameter)
 	}
+	if len(p.statelessResetToken) > 0 {
+		if c.side == serverSide {
+			return localTransportError(errTransportParameter)
+		}
+		token := statelessResetToken(p.statelessResetToken)
+		s.remote[0].resetToken = token
+		c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+			conns.addResetToken(c, token)
+		})
+	}
 	return nil
 }
 
@@ -201,18 +239,22 @@
 			// We're a client connection processing the first Initial packet
 			// from the server. Replace the transient remote connection ID
 			// with the Source Connection ID from the packet.
-			s.remote[0] = connID{
-				seq: 0,
-				cid: cloneBytes(srcConnID),
+			s.remote[0] = remoteConnID{
+				connID: connID{
+					seq: 0,
+					cid: cloneBytes(srcConnID),
+				},
 			}
 		}
 	case ptype == packetTypeInitial && c.side == serverSide:
 		if len(s.remote) == 0 {
 			// We're a server connection processing the first Initial packet
 			// from the client. Set the client's connection ID.
-			s.remote = append(s.remote, connID{
-				seq: 0,
-				cid: cloneBytes(srcConnID),
+			s.remote = append(s.remote, remoteConnID{
+				connID: connID{
+					seq: 0,
+					cid: cloneBytes(srcConnID),
+				},
 			})
 		}
 	case ptype == packetTypeHandshake && c.side == serverSide:
@@ -220,8 +262,10 @@
 			// We're a server connection processing the first Handshake packet from
 			// the client. Discard the transient, client-chosen connection ID used
 			// for Initial packets; the client will never send it again.
-			const retired = true
-			c.listener.connIDsChanged(c, retired, s.local[0:1])
+			cid := s.local[0].cid
+			c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+				conns.retireConnID(c, cid)
+			})
 			s.local = append(s.local[:0], s.local[1:]...)
 		}
 	}
@@ -235,7 +279,7 @@
 	s.remote[0].cid = s.retrySrcConnID
 }
 
-func (s *connIDState) handleNewConnID(seq, retire int64, cid []byte, resetToken [16]byte) error {
+func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, resetToken statelessResetToken) error {
 	if len(s.remote[0].cid) == 0 {
 		// "An endpoint that is sending packets with a zero-length
 		// Destination Connection ID MUST treat receipt of a NEW_CONNECTION_ID
@@ -254,6 +298,9 @@
 		rcid := &s.remote[i]
 		if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo {
 			s.retireRemote(rcid)
+			c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+				conns.retireResetToken(c, rcid.resetToken)
+			})
 		}
 		if !rcid.retired {
 			active++
@@ -272,15 +319,21 @@
 		// We could take steps to keep the list of remote connection IDs
 		// sorted by sequence number, but there's no particular need
 		// so we don't bother.
-		s.remote = append(s.remote, connID{
-			seq: seq,
-			cid: cloneBytes(cid),
+		s.remote = append(s.remote, remoteConnID{
+			connID: connID{
+				seq: seq,
+				cid: cloneBytes(cid),
+			},
+			resetToken: resetToken,
 		})
 		if seq < s.retireRemotePriorTo {
 			// This ID was already retired by a previous NEW_CONNECTION_ID frame.
 			s.retireRemote(&s.remote[len(s.remote)-1])
 		} else {
 			active++
+			c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+				conns.addResetToken(c, resetToken)
+			})
 		}
 	}
 
@@ -305,7 +358,7 @@
 }
 
 // retireRemote marks a remote connection ID as retired.
-func (s *connIDState) retireRemote(rcid *connID) {
+func (s *connIDState) retireRemote(rcid *remoteConnID) {
 	rcid.retired = true
 	rcid.send.setUnsent()
 	s.needSend = true
@@ -317,8 +370,10 @@
 	}
 	for i := range s.local {
 		if s.local[i].seq == seq {
-			const retired = true
-			c.listener.connIDsChanged(c, retired, s.local[i:i+1])
+			cid := s.local[i].cid
+			c.listener.connsMap.updateConnIDs(func(conns *connsMap) {
+				conns.retireConnID(c, cid)
+			})
 			s.local = append(s.local[:i], s.local[i+1:]...)
 			break
 		}
@@ -363,7 +418,7 @@
 //
 // It returns true if no more frames need appending,
 // false if not everything fit in the current packet.
-func (s *connIDState) appendFrames(w *packetWriter, pnum packetNumber, pto bool) bool {
+func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool {
 	if !s.needSend && !pto {
 		// Fast path: We don't need to send anything.
 		return true
@@ -376,11 +431,11 @@
 		if !s.local[i].send.shouldSendPTO(pto) {
 			continue
 		}
-		if !w.appendNewConnectionIDFrame(
+		if !c.w.appendNewConnectionIDFrame(
 			s.local[i].seq,
 			retireBefore,
 			s.local[i].cid,
-			[16]byte{}, // TODO: stateless reset token
+			c.listener.resetGen.tokenForConnID(s.local[i].cid),
 		) {
 			return false
 		}
@@ -390,7 +445,7 @@
 		if !s.remote[i].send.shouldSendPTO(pto) {
 			continue
 		}
-		if !w.appendRetireConnectionIDFrame(s.remote[i].seq) {
+		if !c.w.appendRetireConnectionIDFrame(s.remote[i].seq) {
 			return false
 		}
 		s.remote[i].send.setSent(pnum)
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
index 784c5e2..63feec9 100644
--- a/internal/quic/conn_id_test.go
+++ b/internal/quic/conn_id_test.go
@@ -47,12 +47,14 @@
 	if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
 		t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
 	}
-	wantRemote := []connID{{
-		cid: testPeerConnID(0),
-		seq: 0,
+	wantRemote := []remoteConnID{{
+		connID: connID{
+			cid: testPeerConnID(0),
+			seq: 0,
+		},
 	}}
-	if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
-		t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
+	if got := tc.conn.connIDState.remote; !remoteConnIDListEqual(got, wantRemote) {
+		t.Errorf("remote ids: %v, want %v", fmtRemoteConnIDList(got), fmtRemoteConnIDList(wantRemote))
 	}
 }
 
@@ -93,12 +95,14 @@
 	if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
 		t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
 	}
-	wantRemote := []connID{{
-		cid: testPeerConnID(0),
-		seq: 0,
+	wantRemote := []remoteConnID{{
+		connID: connID{
+			cid: testPeerConnID(0),
+			seq: 0,
+		},
 	}}
-	if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
-		t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
+	if got := tc.conn.connIDState.remote; !remoteConnIDListEqual(got, wantRemote) {
+		t.Errorf("remote ids: %v, want %v", fmtRemoteConnIDList(got), fmtRemoteConnIDList(wantRemote))
 	}
 
 	// The client's first Handshake packet permits the server to discard the
@@ -134,6 +138,24 @@
 	return true
 }
 
+func remoteConnIDListEqual(a, b []remoteConnID) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		if a[i].seq != b[i].seq {
+			return false
+		}
+		if !bytes.Equal(a[i].cid, b[i].cid) {
+			return false
+		}
+		if a[i].resetToken != b[i].resetToken {
+			return false
+		}
+	}
+	return true
+}
+
 func fmtConnIDList(s []connID) string {
 	var strs []string
 	for _, cid := range s {
@@ -142,6 +164,14 @@
 	return "{" + strings.Join(strs, " ") + "}"
 }
 
+func fmtRemoteConnIDList(s []remoteConnID) string {
+	var strs []string
+	for _, cid := range s {
+		strs = append(strs, fmt.Sprintf("[seq:%v cid:{%x} token:{%x}]", cid.seq, cid.cid, cid.resetToken))
+	}
+	return "{" + strings.Join(strs, " ") + "}"
+}
+
 func TestNewRandomConnID(t *testing.T) {
 	cid, err := newRandomConnID(0)
 	if len(cid) != connIDLen || err != nil {
@@ -174,16 +204,19 @@
 		packetType1RTT, debugFrameNewConnectionID{
 			seq:    1,
 			connID: testLocalConnID(1),
+			token:  testLocalStatelessResetToken(1),
 		})
 	tc.wantFrame("provide additional connection ID 2",
 		packetType1RTT, debugFrameNewConnectionID{
 			seq:    2,
 			connID: testLocalConnID(2),
+			token:  testLocalStatelessResetToken(2),
 		})
 	tc.wantFrame("provide additional connection ID 3",
 		packetType1RTT, debugFrameNewConnectionID{
 			seq:    3,
 			connID: testLocalConnID(3),
+			token:  testLocalStatelessResetToken(3),
 		})
 	tc.wantIdle("connection ID limit reached, no more to provide")
 }
@@ -255,6 +288,7 @@
 					seq:           2,
 					retirePriorTo: 1,
 					connID:        testLocalConnID(2),
+					token:         testLocalStatelessResetToken(2),
 				})
 		})
 	}
@@ -455,6 +489,7 @@
 			retirePriorTo: 1,
 			seq:           2,
 			connID:        testLocalConnID(2),
+			token:         testLocalStatelessResetToken(2),
 		})
 	tc.wantIdle("repeated RETIRE_CONNECTION_ID frames are not an error")
 }
@@ -583,3 +618,46 @@
 			})
 	})
 }
+
+func TestConnIDsCleanedUpAfterClose(t *testing.T) {
+	testSides(t, "", func(t *testing.T, side connSide) {
+		tc := newTestConn(t, side, func(p *transportParameters) {
+			if side == clientSide {
+				token := testPeerStatelessResetToken(0)
+				p.statelessResetToken = token[:]
+			}
+		})
+		tc.handshake()
+		tc.ignoreFrame(frameTypeAck)
+		tc.writeFrames(packetType1RTT,
+			debugFrameNewConnectionID{
+				seq:           2,
+				retirePriorTo: 1,
+				connID:        testPeerConnID(2),
+				token:         testPeerStatelessResetToken(0),
+			})
+		tc.wantFrame("peer asked for conn id 0 to be retired",
+			packetType1RTT, debugFrameRetireConnectionID{
+				seq: 0,
+			})
+		tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{})
+		tc.conn.Abort(nil)
+		tc.wantFrame("CONN_CLOSE sent after user closes connection",
+			packetType1RTT, debugFrameConnectionCloseTransport{})
+
+		// Wait for the conn to drain.
+		// Then wait for the conn loop to exit,
+		// and force an immediate sync of the connsMap updates
+		// (normally only done by the listener read loop).
+		tc.advanceToTimer()
+		<-tc.conn.donec
+		tc.listener.l.connsMap.applyUpdates()
+
+		if got := len(tc.listener.l.connsMap.byConnID); got != 0 {
+			t.Errorf("%v conn ids in listener map after closing, want 0", got)
+		}
+		if got := len(tc.listener.l.connsMap.byResetToken); got != 0 {
+			t.Errorf("%v reset tokens in listener map after closing, want 0", got)
+		}
+	})
+}
diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go
index 9b88462..5144be6 100644
--- a/internal/quic/conn_loss_test.go
+++ b/internal/quic/conn_loss_test.go
@@ -160,6 +160,7 @@
 			packetType1RTT, debugFrameNewConnectionID{
 				seq:    1,
 				connID: testLocalConnID(1),
+				token:  testLocalStatelessResetToken(1),
 			})
 		tc.triggerLossOrPTO(packetTypeHandshake, pto)
 		tc.wantFrame("client resends Handshake CRYPTO frame",
@@ -607,6 +608,7 @@
 			packetType1RTT, debugFrameNewConnectionID{
 				seq:    2,
 				connID: testLocalConnID(2),
+				token:  testLocalStatelessResetToken(2),
 			})
 
 		tc.triggerLossOrPTO(packetType1RTT, pto)
@@ -614,6 +616,7 @@
 			packetType1RTT, debugFrameNewConnectionID{
 				seq:    2,
 				connID: testLocalConnID(2),
+				token:  testLocalStatelessResetToken(2),
 			})
 	})
 }
@@ -669,6 +672,7 @@
 			packetType1RTT, debugFrameNewConnectionID{
 				seq:    1,
 				connID: testLocalConnID(1),
+				token:  testLocalStatelessResetToken(1),
 			})
 		tc.writeFrames(packetTypeHandshake,
 			debugFrameCrypto{
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index e789ae0..1833167 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -41,9 +41,23 @@
 			c.handleVersionNegotiation(now, buf)
 			return
 		default:
-			return
+			n = -1
 		}
 		if n <= 0 {
+			// We don't expect to get a stateless reset with a valid
+			// destination connection ID, since the sender of a stateless
+			// reset doesn't know what the connection ID is.
+			//
+			// We're required to perform this check anyway.
+			//
+			// "[...] the comparison MUST be performed when the first packet
+			// in an incoming datagram [...] cannot be decrypted."
+			// https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-2
+			if len(buf) == len(dgram.b) && len(buf) > statelessResetTokenLen {
+				var token statelessResetToken
+				copy(token[:], buf[len(buf)-len(token):])
+				c.handleStatelessReset(token)
+			}
 			// Invalid data at the end of a datagram is ignored.
 			break
 		}
@@ -468,7 +482,7 @@
 	if n < 0 {
 		return -1
 	}
-	if err := c.connIDState.handleNewConnID(seq, retire, connID, resetToken); err != nil {
+	if err := c.connIDState.handleNewConnID(c, seq, retire, connID, resetToken); err != nil {
 		c.abort(now, err)
 	}
 	return n
@@ -515,3 +529,12 @@
 	}
 	return 1
 }
+
+var errStatelessReset = errors.New("received stateless reset")
+
+func (c *Conn) handleStatelessReset(resetToken statelessResetToken) {
+	if !c.connIDState.isValidStatelessResetToken(resetToken) {
+		return
+	}
+	c.enterDraining(errStatelessReset)
+}
diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go
index efeb04f..f512518 100644
--- a/internal/quic/conn_send.go
+++ b/internal/quic/conn_send.go
@@ -250,7 +250,7 @@
 		}
 
 		// NEW_CONNECTION_ID, RETIRE_CONNECTION_ID
-		if !c.connIDState.appendFrames(&c.w, pnum, pto) {
+		if !c.connIDState.appendFrames(c, pnum, pto) {
 			return
 		}
 
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index cfb0d06..df28907 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -190,7 +190,8 @@
 func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
 	t.Helper()
 	config := &Config{
-		TLSConfig: newTestTLSConfig(side),
+		TLSConfig:         newTestTLSConfig(side),
+		StatelessResetKey: testStatelessResetKey,
 	}
 	var configTransportParams []func(*transportParameters)
 	for _, o := range opts {
@@ -1041,6 +1042,13 @@
 	return []byte{0xbe, 0xee, 0xff, byte(seq)}
 }
 
+func testPeerStatelessResetToken(seq int64) statelessResetToken {
+	return statelessResetToken{
+		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee,
+		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq),
+	}
+}
+
 // canceledContext returns a canceled Context.
 //
 // Functions which take a context preference progress over cancelation.
diff --git a/internal/quic/frame_debug.go b/internal/quic/frame_debug.go
index 7a5aee5..dc80090 100644
--- a/internal/quic/frame_debug.go
+++ b/internal/quic/frame_debug.go
@@ -368,7 +368,7 @@
 	seq           int64
 	retirePriorTo int64
 	connID        []byte
-	token         [16]byte
+	token         statelessResetToken
 }
 
 func parseDebugFrameNewConnectionID(b []byte) (f debugFrameNewConnectionID, n int) {
diff --git a/internal/quic/listener.go b/internal/quic/listener.go
index aa25839..668d270 100644
--- a/internal/quic/listener.go
+++ b/internal/quic/listener.go
@@ -8,6 +8,7 @@
 
 import (
 	"context"
+	"crypto/rand"
 	"errors"
 	"net"
 	"net/netip"
@@ -24,21 +25,16 @@
 	config    *Config
 	udpConn   udpConn
 	testHooks listenerTestHooks
+	resetGen  statelessResetTokenGenerator
 	retry     retryState
 
 	acceptQueue queue[*Conn] // new inbound connections
+	connsMap    connsMap     // maps are only accessed by the listen loop; other goroutines queue changes via updateConnIDs
 
 	connsMu sync.Mutex
 	conns   map[*Conn]struct{}
 	closing bool          // set when Close is called
 	closec  chan struct{} // closed when the listen loop exits
-
-	// The datagram receive loop keeps a mapping of connection IDs to conns.
-	// When a conn's connection IDs change, we add it to connIDUpdates and set
-	// connIDUpdateNeeded, and the receive loop updates its map.
-	connIDUpdateMu     sync.Mutex
-	connIDUpdateNeeded atomic.Bool
-	connIDUpdates      []connIDUpdate
 }
 
 type listenerTestHooks interface {
@@ -55,12 +51,6 @@
 	WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error)
 }
 
-type connIDUpdate struct {
-	conn    *Conn
-	retired bool
-	cid     []byte
-}
-
 // Listen listens on a local network address.
 // The configuration config must be non-nil.
 func Listen(network, address string, config *Config) (*Listener, error) {
@@ -87,6 +77,8 @@
 		acceptQueue: newQueue[*Conn](),
 		closec:      make(chan struct{}),
 	}
+	l.resetGen.init(config.StatelessResetKey)
+	l.connsMap.init()
 	if config.RequireAddressValidation {
 		if err := l.retry.init(); err != nil {
 			return nil, err
@@ -181,6 +173,22 @@
 // connDrained is called by a conn when it leaves the draining state,
 // either when the peer acknowledges connection closure or the drain timeout expires.
 func (l *Listener) connDrained(c *Conn) {
+	var cids [][]byte
+	for i := range c.connIDState.local {
+		cids = append(cids, c.connIDState.local[i].cid)
+	}
+	var tokens []statelessResetToken
+	for i := range c.connIDState.remote {
+		tokens = append(tokens, c.connIDState.remote[i].resetToken)
+	}
+	l.connsMap.updateConnIDs(func(conns *connsMap) {
+		for _, cid := range cids {
+			conns.retireConnID(c, cid)
+		}
+		for _, token := range tokens {
+			conns.retireResetToken(c, token)
+		}
+	})
 	l.connsMu.Lock()
 	defer l.connsMu.Unlock()
 	delete(l.conns, c)
@@ -189,39 +197,8 @@
 	}
 }
 
-// connIDsChanged is called by a conn when its connection IDs change.
-func (l *Listener) connIDsChanged(c *Conn, retired bool, cids []connID) {
-	l.connIDUpdateMu.Lock()
-	defer l.connIDUpdateMu.Unlock()
-	for _, cid := range cids {
-		l.connIDUpdates = append(l.connIDUpdates, connIDUpdate{
-			conn:    c,
-			retired: retired,
-			cid:     cid.cid,
-		})
-	}
-	l.connIDUpdateNeeded.Store(true)
-}
-
-// updateConnIDs is called by the datagram receive loop to update its connection ID map.
-func (l *Listener) updateConnIDs(conns map[string]*Conn) {
-	l.connIDUpdateMu.Lock()
-	defer l.connIDUpdateMu.Unlock()
-	for i, u := range l.connIDUpdates {
-		if u.retired {
-			delete(conns, string(u.cid))
-		} else {
-			conns[string(u.cid)] = u.conn
-		}
-		l.connIDUpdates[i] = connIDUpdate{} // drop refs
-	}
-	l.connIDUpdates = l.connIDUpdates[:0]
-	l.connIDUpdateNeeded.Store(false)
-}
-
 func (l *Listener) listen() {
 	defer close(l.closec)
-	conns := map[string]*Conn{}
 	for {
 		m := newDatagram()
 		// TODO: Read and process the ECN (explicit congestion notification) field.
@@ -237,22 +214,22 @@
 		if n == 0 {
 			continue
 		}
-		if l.connIDUpdateNeeded.Load() {
-			l.updateConnIDs(conns)
+		if l.connsMap.updateNeeded.Load() {
+			l.connsMap.applyUpdates()
 		}
 		m.addr = addr
 		m.b = m.b[:n]
-		l.handleDatagram(m, conns)
+		l.handleDatagram(m)
 	}
 }
 
-func (l *Listener) handleDatagram(m *datagram, conns map[string]*Conn) {
+func (l *Listener) handleDatagram(m *datagram) {
 	dstConnID, ok := dstConnIDForDatagram(m.b)
 	if !ok {
 		m.recycle()
 		return
 	}
-	c := conns[string(dstConnID)]
+	c := l.connsMap.byConnID[string(dstConnID)]
 	if c == nil {
 		// TODO: Move this branch into a separate goroutine to avoid blocking
 		// the listener while processing packets.
@@ -271,18 +248,29 @@
 			m.recycle()
 		}
 	}()
-	if len(m.b) < minimumClientInitialDatagramSize {
+	const minimumValidPacketSize = 21
+	if len(m.b) < minimumValidPacketSize {
+		return
+	}
+	// Check to see if this is a stateless reset.
+	var token statelessResetToken
+	copy(token[:], m.b[len(m.b)-len(token):])
+	if c := l.connsMap.byResetToken[token]; c != nil {
+		c.sendMsg(func(now time.Time, c *Conn) {
+			c.handleStatelessReset(token)
+		})
+		return
+	}
+	// If this is a 1-RTT packet, there's nothing productive we can do with it.
+	// Send a stateless reset if possible.
+	if !isLongHeader(m.b[0]) {
+		l.maybeSendStatelessReset(m.b, m.addr)
 		return
 	}
 	p, ok := parseGenericLongHeaderPacket(m.b)
-	if !ok {
-		// Not a long header packet, or not parseable.
-		// Short header (1-RTT) packets don't contain enough information
-		// to do anything useful with if we don't recognize the
-		// connection ID.
+	if !ok || len(m.b) < minimumClientInitialDatagramSize {
 		return
 	}
-
 	switch p.version {
 	case quicVersion1:
 	case 0:
@@ -296,8 +284,9 @@
 	if getPacketType(m.b) != packetTypeInitial {
 		// This packet isn't trying to create a new connection.
 		// It might be associated with some connection we've lost state for.
-		// TODO: Send a stateless reset when appropriate.
-		// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.3
+		// We are technically permitted to send a stateless reset for
+		// a long-header packet, but this isn't generally useful. See:
+		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-16
 		return
 	}
 	var now time.Time
@@ -330,6 +319,50 @@
 	m = nil // don't recycle, sendMsg takes ownership
 }
 
+func (l *Listener) maybeSendStatelessReset(b []byte, addr netip.AddrPort) {
+	if !l.resetGen.canReset {
+		// Config.StatelessResetKey isn't set, so we don't send stateless resets.
+		return
+	}
+	// The smallest possible valid packet a peer can send us is:
+	//   1 byte of header
+	//   connIDLen bytes of destination connection ID
+	//   1 byte of packet number
+	//   1 byte of payload
+	//   16 bytes AEAD expansion
+	if len(b) < 1+connIDLen+1+1+16 {
+		return
+	}
+	// TODO: Rate limit stateless resets.
+	cid := b[1:][:connIDLen]
+	token := l.resetGen.tokenForConnID(cid)
+	// We want to generate a stateless reset that is as short as possible,
+	// but long enough to be difficult to distinguish from a 1-RTT packet.
+	//
+	// The minimal 1-RTT packet is:
+	//   1 byte of header
+	//   0-20 bytes of destination connection ID
+	//   1-4 bytes of packet number
+	//   1 byte of payload
+	//   16 bytes AEAD expansion
+	//
+	// Assuming the maximum possible connection ID and packet number size,
+	// this gives 1 + 20 + 4 + 1 + 16 = 42 bytes.
+	//
+	// We also must generate a stateless reset that is shorter than the datagram
+	// we are responding to, in order to ensure that reset loops terminate.
+	//
+	// See: https://www.rfc-editor.org/rfc/rfc9000#section-10.3
+	size := min(len(b)-1, 42)
+	// Reuse the input buffer for generating the stateless reset.
+	b = b[:size]
+	rand.Read(b[:len(b)-statelessResetTokenLen])
+	b[0] &^= headerFormLong // clear long header bit
+	b[0] |= fixedBit        // set fixed bit
+	copy(b[len(b)-statelessResetTokenLen:], token[:])
+	l.sendDatagram(b, addr)
+}
+
 func (l *Listener) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) {
 	m := newDatagram()
 	m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1)
@@ -363,3 +396,53 @@
 	_, err := l.udpConn.WriteToUDPAddrPort(p, addr)
 	return err
 }
+
+// A connsMap is a listener's mapping of conn ids and reset tokens to conns.
+type connsMap struct {
+	byConnID     map[string]*Conn
+	byResetToken map[statelessResetToken]*Conn
+
+	updateMu     sync.Mutex
+	updateNeeded atomic.Bool
+	updates      []func(*connsMap)
+}
+
+func (m *connsMap) init() {
+	m.byConnID = map[string]*Conn{}
+	m.byResetToken = map[statelessResetToken]*Conn{}
+}
+
+func (m *connsMap) addConnID(c *Conn, cid []byte) {
+	m.byConnID[string(cid)] = c
+}
+
+func (m *connsMap) retireConnID(c *Conn, cid []byte) {
+	delete(m.byConnID, string(cid))
+}
+
+func (m *connsMap) addResetToken(c *Conn, token statelessResetToken) {
+	m.byResetToken[token] = c
+}
+
+func (m *connsMap) retireResetToken(c *Conn, token statelessResetToken) {
+	delete(m.byResetToken, token)
+}
+
+func (m *connsMap) updateConnIDs(f func(*connsMap)) {
+	m.updateMu.Lock()
+	defer m.updateMu.Unlock()
+	m.updates = append(m.updates, f)
+	m.updateNeeded.Store(true)
+}
+
+// applyUpdates is called by the datagram receive loop to apply queued updates to its connection ID and reset token maps.
+func (m *connsMap) applyUpdates() {
+	m.updateMu.Lock()
+	defer m.updateMu.Unlock()
+	for _, f := range m.updates {
+		f(m)
+	}
+	clear(m.updates)
+	m.updates = m.updates[:0]
+	m.updateNeeded.Store(false)
+}
diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go
index 8bcd866..02ef9fb 100644
--- a/internal/quic/packet_parser.go
+++ b/internal/quic/packet_parser.go
@@ -420,32 +420,32 @@
 	return typ, max, n
 }
 
-func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, resetToken [16]byte, n int) {
+func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, resetToken statelessResetToken, n int) {
 	n = 1
 	var nn int
 	seq, nn = consumeVarintInt64(b[n:])
 	if nn < 0 {
-		return 0, 0, nil, [16]byte{}, -1
+		return 0, 0, nil, statelessResetToken{}, -1
 	}
 	n += nn
 	retire, nn = consumeVarintInt64(b[n:])
 	if nn < 0 {
-		return 0, 0, nil, [16]byte{}, -1
+		return 0, 0, nil, statelessResetToken{}, -1
 	}
 	n += nn
 	if seq < retire {
-		return 0, 0, nil, [16]byte{}, -1
+		return 0, 0, nil, statelessResetToken{}, -1
 	}
 	connID, nn = consumeVarintBytes(b[n:])
 	if nn < 0 {
-		return 0, 0, nil, [16]byte{}, -1
+		return 0, 0, nil, statelessResetToken{}, -1
 	}
 	if len(connID) < 1 || len(connID) > 20 {
-		return 0, 0, nil, [16]byte{}, -1
+		return 0, 0, nil, statelessResetToken{}, -1
 	}
 	n += nn
 	if len(b[n:]) < len(resetToken) {
-		return 0, 0, nil, [16]byte{}, -1
+		return 0, 0, nil, statelessResetToken{}, -1
 	}
 	copy(resetToken[:], b[n:])
 	n += len(resetToken)
diff --git a/internal/quic/stateless_reset.go b/internal/quic/stateless_reset.go
new file mode 100644
index 0000000..53c3ba5
--- /dev/null
+++ b/internal/quic/stateless_reset.go
@@ -0,0 +1,61 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"crypto/hmac"
+	"crypto/rand"
+	"crypto/sha256"
+	"hash"
+	"sync"
+)
+
+const statelessResetTokenLen = 128 / 8
+
+// A statelessResetToken is a stateless reset token.
+// https://www.rfc-editor.org/rfc/rfc9000#section-10.3
+type statelessResetToken [statelessResetTokenLen]byte
+
+type statelessResetTokenGenerator struct {
+	canReset bool
+
+	// The hash.Hash interface is not concurrency safe,
+	// so we need a mutex here.
+	//
+	// There shouldn't be much contention on stateless reset token generation.
+	// If this proves to be a problem, we could avoid the mutex by using a separate
+	// generator per Conn, or by using a concurrency-safe generator.
+	mu  sync.Mutex
+	mac hash.Hash
+}
+
+func (g *statelessResetTokenGenerator) init(secret [32]byte) {
+	zero := true
+	for _, b := range secret {
+		if b != 0 {
+			zero = false
+			break
+		}
+	}
+	if zero {
+		// Generate tokens using a random secret, but don't send stateless resets.
+		rand.Read(secret[:])
+		g.canReset = false
+	} else {
+		g.canReset = true
+	}
+	g.mac = hmac.New(sha256.New, secret[:])
+}
+
+func (g *statelessResetTokenGenerator) tokenForConnID(cid []byte) (token statelessResetToken) {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	defer g.mac.Reset()
+	g.mac.Write(cid)
+	copy(token[:], g.mac.Sum(nil))
+	return token
+}
diff --git a/internal/quic/stateless_reset_test.go b/internal/quic/stateless_reset_test.go
new file mode 100644
index 0000000..b12e975
--- /dev/null
+++ b/internal/quic/stateless_reset_test.go
@@ -0,0 +1,277 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"bytes"
+	"context"
+	"crypto/rand"
+	"crypto/tls"
+	"errors"
+	"net/netip"
+	"testing"
+)
+
+func TestStatelessResetClientSendsStatelessResetTokenTransportParameter(t *testing.T) {
+	// "[The stateless_reset_token] transport parameter MUST NOT be sent by a client [...]"
+	// https://www.rfc-editor.org/rfc/rfc9000#section-18.2-4.6.1
+	resetToken := testPeerStatelessResetToken(0)
+	tc := newTestConn(t, serverSide, func(p *transportParameters) {
+		p.statelessResetToken = resetToken[:]
+	})
+	tc.writeFrames(packetTypeInitial,
+		debugFrameCrypto{
+			data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+		})
+	tc.wantFrame("client provided stateless_reset_token transport parameter",
+		packetTypeInitial, debugFrameConnectionCloseTransport{
+			code: errTransportParameter,
+		})
+}
+
+var testStatelessResetKey = func() (key [32]byte) {
+	if _, err := rand.Read(key[:]); err != nil {
+		panic(err)
+	}
+	return key
+}()
+
+func testStatelessResetToken(cid []byte) statelessResetToken {
+	var gen statelessResetTokenGenerator
+	gen.init(testStatelessResetKey)
+	return gen.tokenForConnID(cid)
+}
+
+func testLocalStatelessResetToken(seq int64) statelessResetToken {
+	return testStatelessResetToken(testLocalConnID(seq))
+}
+
+func newDatagramForReset(cid []byte, size int, addr netip.AddrPort) *datagram {
+	dgram := append([]byte{headerFormShort | fixedBit}, cid...)
+	for len(dgram) < size {
+		dgram = append(dgram, byte(len(dgram))) // semi-random junk
+	}
+	return &datagram{
+		b:    dgram,
+		addr: addr,
+	}
+}
+
+func TestStatelessResetSentSizes(t *testing.T) {
+	config := &Config{
+		TLSConfig:         newTestTLSConfig(serverSide),
+		StatelessResetKey: testStatelessResetKey,
+	}
+	addr := netip.MustParseAddr("127.0.0.1")
+	tl := newTestListener(t, config)
+	for i, test := range []struct {
+		reqSize  int
+		wantSize int
+	}{{
+		// Datagrams larger than 42 bytes result in a 42-byte stateless reset.
+		// This isn't specifically mandated by RFC 9000, but is implied.
+		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-11
+		reqSize:  1200,
+		wantSize: 42,
+	}, {
+		// "An endpoint that sends a Stateless Reset in response to a packet
+		// that is 43 bytes or shorter SHOULD send a Stateless Reset that is
+		// one byte shorter than the packet it responds to."
+		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-11
+		reqSize:  43,
+		wantSize: 42,
+	}, {
+		reqSize:  42,
+		wantSize: 41,
+	}, {
+		// We should send a stateless reset in response to the smallest possible
+		// valid datagram the peer can send us.
+		// The smallest packet is 1-RTT:
+		// header byte, conn id, packet num, payload, AEAD.
+		reqSize:  1 + connIDLen + 1 + 1 + 16,
+		wantSize: 1 + connIDLen + 1 + 1 + 16 - 1,
+	}, {
+		// The smallest possible stateless reset datagram is 21 bytes.
+		// Since our response must be smaller than the incoming datagram,
+		// we must not respond to a 21 byte or smaller packet.
+		reqSize:  21,
+		wantSize: 0,
+	}} {
+		cid := testLocalConnID(int64(i))
+		token := testStatelessResetToken(cid)
+		addrport := netip.AddrPortFrom(addr, uint16(8000+i))
+		tl.write(newDatagramForReset(cid, test.reqSize, addrport))
+
+		got := tl.read()
+		if len(got) != test.wantSize {
+			t.Errorf("got %v-byte response to %v-byte req, want %v",
+				len(got), test.reqSize, test.wantSize)
+		}
+		if len(got) == 0 {
+			continue
+		}
+		// "Endpoints MUST send Stateless Resets formatted as
+		// a packet with a short header."
+		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-15
+		if isLongHeader(got[0]) {
+			t.Errorf("response to %v-byte request is not a short-header packet\ngot: %x", test.reqSize, got)
+		}
+		if !bytes.HasSuffix(got, token[:]) {
+			t.Errorf("response to %v-byte request does not end in stateless reset token\ngot: %x\nwant suffix: %x", test.reqSize, got, token)
+		}
+	}
+}
+
+func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) {
+	// "[...] Stateless Reset Token field values from [...] NEW_CONNECTION_ID frames [...]"
+	// https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-1
+	tc := newTestConn(t, clientSide)
+	tc.handshake()
+	tc.ignoreFrame(frameTypeAck)
+
+	// Retire connection ID 0.
+	tc.writeFrames(packetType1RTT,
+		debugFrameNewConnectionID{
+			retirePriorTo: 1,
+			seq:           2,
+			connID:        testPeerConnID(2),
+		})
+	tc.wantFrame("peer requested we retire conn id 0",
+		packetType1RTT, debugFrameRetireConnectionID{
+			seq: 0,
+		})
+
+	resetToken := testPeerStatelessResetToken(1) // provided during handshake
+	dgram := append(make([]byte, 100), resetToken[:]...)
+	tc.listener.write(&datagram{
+		b: dgram,
+	})
+
+	if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) {
+		t.Errorf("conn.Wait() = %v, want errStatelessReset", err)
+	}
+	tc.wantIdle("closed connection is idle")
+}
+
+func TestStatelessResetSuccessfulTransportParameter(t *testing.T) {
+	// "[...] Stateless Reset Token field values from [...]
+	// the server's transport parameters [...]"
+	// https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-1
+	resetToken := testPeerStatelessResetToken(0)
+	tc := newTestConn(t, clientSide, func(p *transportParameters) {
+		p.statelessResetToken = resetToken[:]
+	})
+	tc.handshake()
+
+	dgram := append(make([]byte, 100), resetToken[:]...)
+	tc.listener.write(&datagram{
+		b: dgram,
+	})
+
+	if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) {
+		t.Errorf("conn.Wait() = %v, want errStatelessReset", err)
+	}
+	tc.wantIdle("closed connection is idle")
+}
+
+func TestStatelessResetSuccessfulPrefix(t *testing.T) {
+	for _, test := range []struct {
+		name   string
+		prefix []byte
+		size   int
+	}{{
+		name: "short header and fixed bit",
+		prefix: []byte{
+			headerFormShort | fixedBit,
+		},
+		size: 100,
+	}, {
+		// "[...] endpoints MUST treat [long header packets] ending in a
+		// valid stateless reset token as a Stateless Reset [...]"
+		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3-15
+		name: "long header no fixed bit",
+		prefix: []byte{
+			headerFormLong,
+		},
+		size: 100,
+	}, {
+		// "[...] the comparison MUST be performed when the first packet
+		// in an incoming datagram [...] cannot be decrypted."
+		// https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-2
+		name: "short header valid DCID",
+		prefix: append([]byte{
+			headerFormShort | fixedBit,
+		}, testLocalConnID(0)...),
+		size: 100,
+	}, {
+		name: "handshake valid DCID",
+		prefix: append([]byte{
+			headerFormLong | fixedBit | longPacketTypeHandshake,
+		}, testLocalConnID(0)...),
+		size: 100,
+	}, {
+		name: "no fixed bit valid DCID",
+		prefix: append([]byte{
+			0,
+		}, testLocalConnID(0)...),
+		size: 100,
+	}} {
+		t.Run(test.name, func(t *testing.T) {
+			resetToken := testPeerStatelessResetToken(0)
+			tc := newTestConn(t, clientSide, func(p *transportParameters) {
+				p.statelessResetToken = resetToken[:]
+			})
+			tc.handshake()
+
+			dgram := test.prefix
+			for len(dgram) < test.size-len(resetToken) {
+				dgram = append(dgram, byte(len(dgram))) // semi-random junk
+			}
+			dgram = append(dgram, resetToken[:]...)
+			tc.listener.write(&datagram{
+				b: dgram,
+			})
+			if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errStatelessReset) {
+				t.Errorf("conn.Wait() = %v, want errStatelessReset", err)
+			}
+		})
+	}
+}
+
+func TestStatelessResetRetiredConnID(t *testing.T) {
+	// "An endpoint MUST NOT check for any stateless reset tokens [...]
+	// for connection IDs that have been retired."
+	// https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-3
+	resetToken := testPeerStatelessResetToken(0)
+	tc := newTestConn(t, clientSide, func(p *transportParameters) {
+		p.statelessResetToken = resetToken[:]
+	})
+	tc.handshake()
+	tc.ignoreFrame(frameTypeAck)
+
+	// We retire connection ID 0.
+	tc.writeFrames(packetType1RTT,
+		debugFrameNewConnectionID{
+			seq:           2,
+			retirePriorTo: 1,
+			connID:        testPeerConnID(2),
+		})
+	tc.wantFrame("peer asked for conn id 0 to be retired",
+		packetType1RTT, debugFrameRetireConnectionID{
+			seq: 0,
+		})
+
+	// Receive a stateless reset for connection ID 0.
+	dgram := append(make([]byte, 100), resetToken[:]...)
+	tc.listener.write(&datagram{
+		b: dgram,
+	})
+
+	if err := tc.conn.Wait(canceledContext()); !errors.Is(err, context.Canceled) {
+		t.Errorf("conn.Wait() = %v, want connection to be alive", err)
+	}
+}
diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go
index 337657e..6f4e065 100644
--- a/internal/quic/tls_test.go
+++ b/internal/quic/tls_test.go
@@ -71,9 +71,11 @@
 
 func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) {
 	var (
-		clientConnIDs   [][]byte
-		serverConnIDs   [][]byte
-		transientConnID []byte
+		clientConnIDs    [][]byte
+		serverConnIDs    [][]byte
+		clientResetToken statelessResetToken
+		serverResetToken statelessResetToken
+		transientConnID  []byte
 	)
 	localConnIDs := [][]byte{
 		testLocalConnID(0),
@@ -83,13 +85,19 @@
 		testPeerConnID(0),
 		testPeerConnID(1),
 	}
+	localResetToken := tc.listener.l.resetGen.tokenForConnID(localConnIDs[1])
+	peerResetToken := testPeerStatelessResetToken(1)
 	if tc.conn.side == clientSide {
 		clientConnIDs = localConnIDs
 		serverConnIDs = peerConnIDs
+		clientResetToken = localResetToken
+		serverResetToken = peerResetToken
 		transientConnID = testLocalConnID(-1)
 	} else {
 		clientConnIDs = peerConnIDs
 		serverConnIDs = localConnIDs
+		clientResetToken = peerResetToken
+		serverResetToken = localResetToken
 		transientConnID = testPeerConnID(-1)
 	}
 	return []*testDatagram{{
@@ -136,6 +144,7 @@
 				debugFrameNewConnectionID{
 					seq:    1,
 					connID: serverConnIDs[1],
+					token:  serverResetToken,
 				},
 			},
 		}},
@@ -175,6 +184,7 @@
 				debugFrameNewConnectionID{
 					seq:    1,
 					connID: clientConnIDs[1],
+					token:  clientResetToken,
 				},
 			},
 		}},
@@ -337,6 +347,7 @@
 		packetType1RTT, debugFrameNewConnectionID{
 			seq:    1,
 			connID: testLocalConnID(1),
+			token:  testLocalStatelessResetToken(1),
 		})
 
 	// The client discards Initial keys after sending a Handshake packet.
@@ -390,6 +401,7 @@
 		packetType1RTT, debugFrameNewConnectionID{
 			seq:    1,
 			connID: testLocalConnID(1),
+			token:  testLocalStatelessResetToken(1),
 		})
 	tc.wantIdle("server has discarded Initial keys, cannot read CONNECTION_CLOSE")
 
@@ -546,7 +558,9 @@
 	// exceeds the integrity limit for the selected AEAD,
 	// the endpoint MUST immediately close the connection [...]"
 	// https://www.rfc-editor.org/rfc/rfc9001#section-6.6-6
-	tc := newTestConn(t, clientSide)
+	tc := newTestConn(t, clientSide, func(c *Config) {
+		clear(c.StatelessResetKey[:])
+	})
 	tc.handshake()
 
 	var limit int64