quic: support stateless reset
Add a StatelessResetKey config field to permit
generating consistent stateless reset tokens
across server restarts.
Set the stateless_reset_token transport parameter
and populate the Token field in NEW_CONNECTION_ID
frames.
Detect stateless reset tokens in received datagrams that cannot
be associated with a connection or cannot be parsed.
RFC 9000, Section 10.3.
For golang/go#58547
Change-Id: Idb52ba07092ab5c08b323d6b531964a7e7e5ecea
Reviewed-on: https://go-review.googlesource.com/c/net/+/536315
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
index 784c5e2..63feec9 100644
--- a/internal/quic/conn_id_test.go
+++ b/internal/quic/conn_id_test.go
@@ -47,12 +47,14 @@
if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
- wantRemote := []connID{{
- cid: testPeerConnID(0),
- seq: 0,
+ wantRemote := []remoteConnID{{
+ connID: connID{
+ cid: testPeerConnID(0),
+ seq: 0,
+ },
}}
- if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
- t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
+ if got := tc.conn.connIDState.remote; !remoteConnIDListEqual(got, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", fmtRemoteConnIDList(got), fmtRemoteConnIDList(wantRemote))
}
}
@@ -93,12 +95,14 @@
if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
- wantRemote := []connID{{
- cid: testPeerConnID(0),
- seq: 0,
+ wantRemote := []remoteConnID{{
+ connID: connID{
+ cid: testPeerConnID(0),
+ seq: 0,
+ },
}}
- if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
- t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
+ if got := tc.conn.connIDState.remote; !remoteConnIDListEqual(got, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", fmtRemoteConnIDList(got), fmtRemoteConnIDList(wantRemote))
}
// The client's first Handshake packet permits the server to discard the
@@ -134,6 +138,24 @@
return true
}
+// remoteConnIDListEqual reports whether a and b contain the same remote
+// connection IDs in the same order. Entries match when their sequence
+// numbers, connection ID bytes, and stateless reset tokens are all equal.
+func remoteConnIDListEqual(a, b []remoteConnID) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		x, y := &a[i], &b[i]
+		switch {
+		case x.seq != y.seq,
+			!bytes.Equal(x.cid, y.cid),
+			x.resetToken != y.resetToken:
+			return false
+		}
+	}
+	return true
+}
+
func fmtConnIDList(s []connID) string {
var strs []string
for _, cid := range s {
@@ -142,6 +164,14 @@
return "{" + strings.Join(strs, " ") + "}"
}
+// fmtRemoteConnIDList renders s for test failure messages as
+// "{[seq:N cid:{hex} token:{hex}] ...}", with entries separated by a space.
+func fmtRemoteConnIDList(s []remoteConnID) string {
+	var b strings.Builder
+	b.WriteString("{")
+	for i, rcid := range s {
+		if i > 0 {
+			b.WriteString(" ")
+		}
+		fmt.Fprintf(&b, "[seq:%v cid:{%x} token:{%x}]", rcid.seq, rcid.cid, rcid.resetToken)
+	}
+	b.WriteString("}")
+	return b.String()
+}
+
func TestNewRandomConnID(t *testing.T) {
cid, err := newRandomConnID(0)
if len(cid) != connIDLen || err != nil {
@@ -174,16 +204,19 @@
packetType1RTT, debugFrameNewConnectionID{
seq: 1,
connID: testLocalConnID(1),
+ token: testLocalStatelessResetToken(1),
})
tc.wantFrame("provide additional connection ID 2",
packetType1RTT, debugFrameNewConnectionID{
seq: 2,
connID: testLocalConnID(2),
+ token: testLocalStatelessResetToken(2),
})
tc.wantFrame("provide additional connection ID 3",
packetType1RTT, debugFrameNewConnectionID{
seq: 3,
connID: testLocalConnID(3),
+ token: testLocalStatelessResetToken(3),
})
tc.wantIdle("connection ID limit reached, no more to provide")
}
@@ -255,6 +288,7 @@
seq: 2,
retirePriorTo: 1,
connID: testLocalConnID(2),
+ token: testLocalStatelessResetToken(2),
})
})
}
@@ -455,6 +489,7 @@
retirePriorTo: 1,
seq: 2,
connID: testLocalConnID(2),
+ token: testLocalStatelessResetToken(2),
})
tc.wantIdle("repeated RETIRE_CONNECTION_ID frames are not an error")
}
@@ -583,3 +618,46 @@
})
})
}
+
+// TestConnIDsCleanedUpAfterClose verifies that, once a connection has closed
+// and its loop has exited, the listener's connsMap no longer holds any
+// connection IDs (byConnID) or stateless reset tokens (byResetToken) for it.
+func TestConnIDsCleanedUpAfterClose(t *testing.T) {
+ testSides(t, "", func(t *testing.T, side connSide) {
+ // On the client side, have the peer supply a stateless reset token
+ // through its transport parameters so that a token gets registered.
+ tc := newTestConn(t, side, func(p *transportParameters) {
+ if side == clientSide {
+ token := testPeerStatelessResetToken(0)
+ p.statelessResetToken = token[:]
+ }
+ })
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ // Peer issues conn ID 2 and asks us to retire everything below seq 1.
+ // NOTE(review): conn ID 2 reuses testPeerStatelessResetToken(0), the
+ // token of conn ID 0, which this same frame retires. Confirm this is
+ // intentional (e.g. exercising token re-registration) and not meant to
+ // be testPeerStatelessResetToken(2).
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ seq: 2,
+ retirePriorTo: 1,
+ connID: testPeerConnID(2),
+ token: testPeerStatelessResetToken(0),
+ })
+ tc.wantFrame("peer asked for conn id 0 to be retired",
+ packetType1RTT, debugFrameRetireConnectionID{
+ seq: 0,
+ })
+ // Peer closes the connection, then the local side aborts and is
+ // expected to answer with its own CONNECTION_CLOSE.
+ tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{})
+ tc.conn.Abort(nil)
+ tc.wantFrame("CONN_CLOSE sent after user closes connection",
+ packetType1RTT, debugFrameConnectionCloseTransport{})
+
+ // Wait for the conn to drain.
+ // Then wait for the conn loop to exit,
+ // and force an immediate sync of the connsMap updates
+ // (normally only done by the listener read loop).
+ tc.advanceToTimer()
+ <-tc.conn.donec
+ tc.listener.l.connsMap.applyUpdates()
+
+ // Both lookup tables must be empty once the closed conn is removed.
+ if got := len(tc.listener.l.connsMap.byConnID); got != 0 {
+ t.Errorf("%v conn ids in listener map after closing, want 0", got)
+ }
+ if got := len(tc.listener.l.connsMap.byResetToken); got != 0 {
+ t.Errorf("%v reset tokens in listener map after closing, want 0", got)
+ }
+ })
+}