quic: handle peer-initiated key updates
RFC 9001, Section 6.
For golang/go#58547
Change-Id: I3700043d27ab41536521b547ecf5e632a08eb1b5
Reviewed-on: https://go-review.googlesource.com/c/net/+/528835
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index 4565e1a..dc3a985 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -44,7 +44,7 @@
// Packet protection keys, CRYPTO streams, and TLS state.
keysInitial fixedKeyPair
keysHandshake fixedKeyPair
- keysAppData fixedKeyPair
+ keysAppData updatingKeyPair
crypto [numberSpaceCount]cryptoStream
tls *tls.QUICConn
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index d1fa52d..4fc4eec 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -89,7 +89,7 @@
}
pnumMax := c.acks[appDataSpace].largestSeen()
- p, n := parse1RTTPacket(buf, c.keysAppData.r, connIDLen, pnumMax)
+ p, n := parse1RTTPacket(buf, &c.keysAppData, connIDLen, pnumMax)
if n < 0 {
return -1
}
@@ -247,7 +247,7 @@
func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) int {
c.loss.receiveAckStart()
- _, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
+ largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
if end > c.loss.nextNumber(space) {
// Acknowledgement of a packet we never sent.
c.abort(now, localTransportError(errProtocolViolation))
@@ -280,6 +280,9 @@
delay = ackDelay.Duration(uint8(c.peerAckDelayExponent))
}
c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss)
+ if space == appDataSpace {
+ c.keysAppData.handleAckFor(largest)
+ }
return n
}
diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go
index 58a3df1..63f65b5 100644
--- a/internal/quic/conn_send.go
+++ b/internal/quic/conn_send.go
@@ -128,7 +128,7 @@
if logPackets {
logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload())
}
- if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, c.keysAppData.w); sent != nil {
+ if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil {
c.loss.packetSent(now, appDataSpace, sent)
}
}
@@ -197,16 +197,23 @@
// All frames other than ACK and PADDING are ack-eliciting,
// so if the packet is ack-eliciting we've added additional
// frames to it.
- if shouldSendAck || c.w.sent.ackEliciting {
- // Either we are willing to send an ACK-only packet,
- // or we've added additional frames.
- c.acks[space].sentAck()
- } else {
+ if !shouldSendAck && !c.w.sent.ackEliciting {
// There's nothing in this packet but ACK frames, and
// we don't want to send an ACK-only packet at this time.
// Abandoning the packet means we wrote an ACK frame for
// nothing, but constructing the frame is cheap.
c.w.abandonPacket()
+ return
+ }
+ // Either we are willing to send an ACK-only packet,
+ // or we've added additional frames.
+ c.acks[space].sentAck()
+ if !c.w.sent.ackEliciting && c.keysAppData.needAckEliciting() {
+ // The peer has initiated a key update.
+ // We haven't sent them any packets yet in the new phase.
+ // Make this an ack-eliciting packet.
+ // Their ack of this packet will complete the key update.
+ c.w.appendPingFrame()
}
}()
}
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index 3fef62d..76774cc 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -76,12 +76,14 @@
}
type testPacket struct {
- ptype packetType
- version uint32
- num packetNumber
- dstConnID []byte
- srcConnID []byte
- frames []debugFrame
+ ptype packetType
+ version uint32
+ num packetNumber
+ keyPhaseBit bool
+ keyNumber int
+ dstConnID []byte
+ srcConnID []byte
+ frames []debugFrame
}
func (p testPacket) String() string {
@@ -102,6 +104,9 @@
return b.String()
}
+// maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test.
+const maxTestKeyPhases = 3
+
// A testConn is a Conn whose external interactions (sending and receiving packets,
// setting timers) can be manipulated in tests.
type testConn struct {
@@ -122,9 +127,10 @@
// the Initial packet.
keysInitial fixedKeyPair
keysHandshake fixedKeyPair
- keysAppData fixedKeyPair
- rsecrets [numberSpaceCount]testKeySecret
- wsecrets [numberSpaceCount]testKeySecret
+ rkeyAppData test1RTTKeys
+ wkeyAppData test1RTTKeys
+ rsecrets [numberSpaceCount]keySecret
+ wsecrets [numberSpaceCount]keySecret
// testConn uses a test hook to snoop on the conn's TLS events.
// CRYPTO data produced by the conn's QUICConn is placed in
@@ -156,10 +162,19 @@
// Frame types to ignore in tests.
ignoreFrames map[byte]bool
+ // Values to set in packets sent to the conn.
+ sendKeyNumber int
+ sendKeyPhaseBit bool
+
asyncTestState
}
-type testKeySecret struct {
+type test1RTTKeys struct {
+ hdr headerKey
+ pkt [maxTestKeyPhases]packetKey
+}
+
+type keySecret struct {
suite uint16
secret []byte
}
@@ -333,12 +348,20 @@
}
tc.t.Logf("%v datagram%v", text, pad)
for _, p := range d.packets {
+ var s string
switch p.ptype {
case packetType1RTT:
- tc.t.Logf(" %v pnum=%v", p.ptype, p.num)
+ s = fmt.Sprintf(" %v pnum=%v", p.ptype, p.num)
default:
- tc.t.Logf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
+ s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
}
+ if p.keyPhaseBit {
+ s += fmt.Sprintf(" KeyPhase")
+ }
+ if p.keyNumber != 0 {
+ s += fmt.Sprintf(" keynum=%v", p.keyNumber)
+ }
+ tc.t.Log(s)
for _, f := range p.frames {
tc.t.Logf(" %v", f)
}
@@ -381,12 +404,14 @@
}
d := &testDatagram{
packets: []*testPacket{{
- ptype: ptype,
- num: tc.peerNextPacketNum[space],
- frames: frames,
- version: 1,
- dstConnID: dstConnID,
- srcConnID: tc.peerConnID,
+ ptype: ptype,
+ num: tc.peerNextPacketNum[space],
+ keyNumber: tc.sendKeyNumber,
+ keyPhaseBit: tc.sendKeyPhaseBit,
+ frames: frames,
+ version: 1,
+ dstConnID: dstConnID,
+ srcConnID: tc.peerConnID,
}},
}
if ptype == packetTypeInitial && tc.conn.side == serverSide {
@@ -580,6 +605,22 @@
}
}
+// wantFrameType indicates that we expect the Conn to send a frame,
+// although we don't care about the contents.
+func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
+ tc.t.Helper()
+ got, gotType := tc.readFrame()
+ if got == nil {
+ tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
+ }
+ if gotType != wantType {
+ tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
+ }
+ if reflect.TypeOf(got) != reflect.TypeOf(want) {
+ tc.t.Fatalf("%v:\ngot frame: %v\nwant frame of type: %v", expectation, got, want)
+ }
+}
+
// wantIdle indicates that we expect the Conn to not send any more frames.
func (tc *testConn) wantIdle(expectation string) {
tc.t.Helper()
@@ -615,17 +656,17 @@
}
w.appendPaddingTo(pad)
if p.ptype != packetType1RTT {
- var k fixedKeyPair
+ var k fixedKeys
switch p.ptype {
case packetTypeInitial:
- k = tc.keysInitial
+ k = tc.keysInitial.w
case packetTypeHandshake:
- k = tc.keysHandshake
+ k = tc.keysHandshake.w
}
- if !k.canWrite() {
+ if !k.isSet() {
tc.t.Fatalf("sending %v packet with no write key", p.ptype)
}
- w.finishProtectedLongHeaderPacket(pnumMaxAcked, k.w, longPacket{
+ w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{
ptype: p.ptype,
version: p.version,
num: p.num,
@@ -633,10 +674,24 @@
srcConnID: p.srcConnID,
})
} else {
- if !tc.keysAppData.canWrite() {
- tc.t.Fatalf("sending %v packet with no write key", p.ptype)
+ if !tc.wkeyAppData.hdr.isSet() {
+ tc.t.Fatalf("sending 1-RTT packet with no write key")
}
- w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.keysAppData.w)
+ // Somewhat hackish: Generate a temporary updatingKeyPair that will
+ // always use our desired key phase.
+ k := &updatingKeyPair{
+ w: updatingKeys{
+ hdr: tc.wkeyAppData.hdr,
+ pkt: [2]packetKey{
+ tc.wkeyAppData.pkt[p.keyNumber],
+ tc.wkeyAppData.pkt[p.keyNumber],
+ },
+ },
+ }
+ if p.keyPhaseBit {
+ k.phase |= keyPhaseBit
+ }
+ w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k)
}
return w.datagram()
}
@@ -682,25 +737,45 @@
})
buf = buf[n:]
} else {
- if !tc.keysAppData.canRead() {
+ if !tc.rkeyAppData.hdr.isSet() {
tc.t.Fatalf("reading 1-RTT packet with no read key")
}
var pnumMax packetNumber // TODO: Track packet numbers.
- p, n := parse1RTTPacket(buf, tc.keysAppData.r, len(tc.peerConnID), pnumMax)
- if n < 0 {
- tc.t.Fatalf("packet parse error")
+ pnumOff := 1 + len(tc.peerConnID)
+ // Try unprotecting the packet with the first maxTestKeyPhases keys.
+ var phase int
+ var pnum packetNumber
+ var hdr []byte
+ var pay []byte
+ var err error
+ for phase = 0; phase < maxTestKeyPhases; phase++ {
+ b := append([]byte{}, buf...)
+ hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
+ if err != nil {
+ tc.t.Fatalf("1-RTT packet header parse error")
+ }
+ k := tc.rkeyAppData.pkt[phase]
+ pay, err = k.unprotect(hdr, pay, pnum)
+ if err == nil {
+ break
+ }
}
- frames, err := tc.parseTestFrames(p.payload)
+ if err != nil {
+ tc.t.Fatalf("1-RTT packet payload parse error")
+ }
+ frames, err := tc.parseTestFrames(pay)
if err != nil {
tc.t.Fatal(err)
}
d.packets = append(d.packets, &testPacket{
- ptype: packetType1RTT,
- num: p.num,
- dstConnID: buf[1:][:len(tc.peerConnID)],
- frames: frames,
+ ptype: packetType1RTT,
+ num: pnum,
+ dstConnID: hdr[1:][:len(tc.peerConnID)],
+ keyPhaseBit: hdr[0]&keyPhaseBit != 0,
+ keyNumber: phase,
+ frames: frames,
})
- buf = buf[n:]
+ buf = buf[len(buf):]
}
}
// This is rather hackish: If the last frame in the last packet
@@ -766,7 +841,7 @@
// and verify that both sides of the connection are getting
// matching keys.
func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
- checkKey := func(typ string, secrets *[numberSpaceCount]testKeySecret, e tls.QUICEvent) {
+ checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
var space numberSpace
switch {
case e.Level == tls.QUICEncryptionLevelHandshake:
@@ -781,25 +856,32 @@
secrets[space].suite = e.Suite
secrets[space].secret = append([]byte{}, e.Data...)
} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
- tc.t.Errorf("%v key mismatch for level %v", typ, e.Level)
+ tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level)
+ }
+ }
+ setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
+ k.hdr.init(suite, secret)
+ for i := 0; i < len(k.pkt); i++ {
+ k.pkt[i].init(suite, secret)
+ secret = updateSecret(suite, secret)
}
}
switch e.Kind {
case tls.QUICSetReadSecret:
- checkKey("read", &tc.rsecrets, e)
+ checkKey("write", &tc.wsecrets, e)
switch e.Level {
case tls.QUICEncryptionLevelHandshake:
tc.keysHandshake.w.init(e.Suite, e.Data)
case tls.QUICEncryptionLevelApplication:
- tc.keysAppData.w.init(e.Suite, e.Data)
+ setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
}
case tls.QUICSetWriteSecret:
- checkKey("write", &tc.wsecrets, e)
+ checkKey("read", &tc.rsecrets, e)
switch e.Level {
case tls.QUICEncryptionLevelHandshake:
tc.keysHandshake.r.init(e.Suite, e.Data)
case tls.QUICEncryptionLevelApplication:
- tc.keysAppData.r.init(e.Suite, e.Data)
+ setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
}
case tls.QUICWriteData:
tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
@@ -811,20 +893,20 @@
case tls.QUICNoEvent:
return
case tls.QUICSetReadSecret:
- checkKey("write", &tc.wsecrets, e)
+ checkKey("write", &tc.rsecrets, e)
switch e.Level {
case tls.QUICEncryptionLevelHandshake:
tc.keysHandshake.r.init(e.Suite, e.Data)
case tls.QUICEncryptionLevelApplication:
- tc.keysAppData.r.init(e.Suite, e.Data)
+ setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
}
case tls.QUICSetWriteSecret:
- checkKey("read", &tc.rsecrets, e)
+ checkKey("read", &tc.wsecrets, e)
switch e.Level {
case tls.QUICEncryptionLevelHandshake:
tc.keysHandshake.w.init(e.Suite, e.Data)
case tls.QUICEncryptionLevelApplication:
- tc.keysAppData.w.init(e.Suite, e.Data)
+ setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
}
case tls.QUICWriteData:
tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
diff --git a/internal/quic/key_update_test.go b/internal/quic/key_update_test.go
new file mode 100644
index 0000000..6b6bb79
--- /dev/null
+++ b/internal/quic/key_update_test.go
@@ -0,0 +1,163 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "testing"
+)
+
+func TestKeyUpdatePeerUpdates(t *testing.T) {
+ tc := newTestConn(t, serverSide)
+ tc.handshake()
+ tc.ignoreFrames = nil // ignore nothing
+
+ // Peer initiates a key update.
+ tc.sendKeyNumber = 1
+ tc.sendKeyPhaseBit = true
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+
+ // We update to the new key.
+ tc.advanceToTimer()
+ tc.wantFrameType("conn ACKs last packet",
+ packetType1RTT, debugFrameAck{})
+ tc.wantFrame("first packet after a key update is always ack-eliciting",
+ packetType1RTT, debugFramePing{})
+ if got, want := tc.lastPacket.keyNumber, 1; got != want {
+ t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+ }
+ if !tc.lastPacket.keyPhaseBit {
+ t.Errorf("after key rotation, conn failed to change Key Phase bit")
+ }
+ tc.wantIdle("conn has nothing to send")
+
+ // Peer's ACK of a packet we sent in the new phase completes the update.
+ tc.writeAckForAll()
+
+ // Peer initiates a second key update.
+ tc.sendKeyNumber = 2
+ tc.sendKeyPhaseBit = false
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+
+ // We update to the new key.
+ tc.advanceToTimer()
+ tc.wantFrameType("conn ACKs last packet",
+ packetType1RTT, debugFrameAck{})
+ tc.wantFrame("first packet after a key update is always ack-eliciting",
+ packetType1RTT, debugFramePing{})
+ if got, want := tc.lastPacket.keyNumber, 2; got != want {
+ t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+ }
+ if tc.lastPacket.keyPhaseBit {
+ t.Errorf("after second key rotation, conn failed to change Key Phase bit")
+ }
+ tc.wantIdle("conn has nothing to send")
+}
+
+func TestKeyUpdateAcceptPreviousPhaseKeys(t *testing.T) {
+ // "An endpoint SHOULD retain old keys for some time after
+ // unprotecting a packet sent using the new keys."
+ // https://www.rfc-editor.org/rfc/rfc9001#section-6.1-8
+ tc := newTestConn(t, serverSide)
+ tc.handshake()
+ tc.ignoreFrames = nil // ignore nothing
+
+ // Peer initiates a key update, skipping one packet number.
+ pnum0 := tc.peerNextPacketNum[appDataSpace]
+ tc.peerNextPacketNum[appDataSpace]++
+ tc.sendKeyNumber = 1
+ tc.sendKeyPhaseBit = true
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+
+ // We update to the new key.
+ // This ACK is not delayed, because we've skipped a packet number.
+ tc.wantFrame("conn ACKs last packet",
+ packetType1RTT, debugFrameAck{
+ ranges: []i64range[packetNumber]{
+ {0, pnum0},
+ {pnum0 + 1, pnum0 + 2},
+ },
+ })
+ tc.wantFrame("first packet after a key update is always ack-eliciting",
+ packetType1RTT, debugFramePing{})
+ if got, want := tc.lastPacket.keyNumber, 1; got != want {
+ t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+ }
+ if !tc.lastPacket.keyPhaseBit {
+ t.Errorf("after key rotation, conn failed to change Key Phase bit")
+ }
+ tc.wantIdle("conn has nothing to send")
+
+ // We receive the previously-skipped packet in the earlier key phase.
+ tc.peerNextPacketNum[appDataSpace] = pnum0
+ tc.sendKeyNumber = 0
+ tc.sendKeyPhaseBit = false
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+
+ // We ack the reordered packet immediately, still in the new key phase.
+ tc.wantFrame("conn ACKs reordered packet",
+ packetType1RTT, debugFrameAck{
+ ranges: []i64range[packetNumber]{
+ {0, pnum0 + 2},
+ },
+ })
+ tc.wantIdle("packet is not ack-eliciting")
+ if got, want := tc.lastPacket.keyNumber, 1; got != want {
+ t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+ }
+ if !tc.lastPacket.keyPhaseBit {
+ t.Errorf("after key rotation, conn failed to change Key Phase bit")
+ }
+}
+
+func TestKeyUpdateRejectPacketFromPriorPhase(t *testing.T) {
+ // "Packets with higher packet numbers MUST be protected with either
+ // the same or newer packet protection keys than packets with lower packet numbers."
+ // https://www.rfc-editor.org/rfc/rfc9001#section-6.4-2
+ tc := newTestConn(t, serverSide)
+ tc.handshake()
+ tc.ignoreFrames = nil // ignore nothing
+
+ // Peer initiates a key update.
+ tc.sendKeyNumber = 1
+ tc.sendKeyPhaseBit = true
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+
+ // We update to the new key.
+ tc.advanceToTimer()
+ tc.wantFrameType("conn ACKs last packet",
+ packetType1RTT, debugFrameAck{})
+ tc.wantFrame("first packet after a key update is always ack-eliciting",
+ packetType1RTT, debugFramePing{})
+ if got, want := tc.lastPacket.keyNumber, 1; got != want {
+ t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+ }
+ if !tc.lastPacket.keyPhaseBit {
+ t.Errorf("after key rotation, conn failed to change Key Phase bit")
+ }
+ tc.wantIdle("conn has nothing to send")
+
+ // Peer sends an ack-eliciting packet using the prior phase keys.
+ // We fail to unprotect the packet and ignore it.
+ skipped := tc.peerNextPacketNum[appDataSpace]
+ tc.sendKeyNumber = 0
+ tc.sendKeyPhaseBit = false
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+
+ // Peer sends an ack-eliciting packet using the current phase keys.
+ tc.sendKeyNumber = 1
+ tc.sendKeyPhaseBit = true
+ tc.writeFrames(packetType1RTT, debugFramePing{})
+
+ // We ack the peer's packets, not including the one sent with the wrong keys.
+ tc.wantFrame("conn ACKs packets, not including packet sent with wrong keys",
+ packetType1RTT, debugFrameAck{
+ ranges: []i64range[packetNumber]{
+ {0, skipped},
+ {skipped + 1, skipped + 2},
+ },
+ })
+}
diff --git a/internal/quic/packet.go b/internal/quic/packet.go
index a1bcead..8242bd0 100644
--- a/internal/quic/packet.go
+++ b/internal/quic/packet.go
@@ -45,6 +45,7 @@
fixedBit = 0x40 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.4.1
reservedLongBits = 0x0c // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1
reserved1RTTBits = 0x18 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1
+ keyPhaseBit = 0x04 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.10.1
)
// Long Packet Type bits.
diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go
index 7f0846f..c8b1f9b 100644
--- a/internal/quic/packet_codec_test.go
+++ b/internal/quic/packet_codec_test.go
@@ -146,10 +146,13 @@
}
func TestRoundtripEncodeShortPacket(t *testing.T) {
- var aes128Keys, aes256Keys, chachaKeys fixedKeys
- aes128Keys.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
- aes256Keys.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
- chachaKeys.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
+ var aes128Keys, aes256Keys, chachaKeys updatingKeyPair
+ aes128Keys.r.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
+ aes256Keys.r.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
+ chachaKeys.r.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
+ aes128Keys.w = aes128Keys.r
+ aes256Keys.w = aes256Keys.r
+ chachaKeys.w = chachaKeys.r
connID := make([]byte, connIDLen)
for i := range connID {
connID[i] = byte(i)
@@ -158,7 +161,7 @@
desc string
num packetNumber
payload []byte
- k fixedKeys
+ k updatingKeyPair
}{{
desc: "1-byte number, AES128",
num: 0, // 1-byte encoding,
@@ -185,9 +188,9 @@
w.reset(1200)
w.start1RTTPacket(test.num, 0, connID)
w.b = append(w.b, test.payload...)
- w.finish1RTTPacket(test.num, 0, connID, test.k)
+ w.finish1RTTPacket(test.num, 0, connID, &test.k)
pkt := w.datagram()
- p, n := parse1RTTPacket(pkt, test.k, connIDLen, 0)
+ p, n := parse1RTTPacket(pkt, &test.k, connIDLen, 0)
if n != len(pkt) {
t.Errorf("parse1RTTPacket: n=%v, want %v", n, len(pkt))
}
diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go
index 458cd3a..8bb3cae 100644
--- a/internal/quic/packet_parser.go
+++ b/internal/quic/packet_parser.go
@@ -143,12 +143,13 @@
//
// On input, pkt contains a short header packet, k the decryption keys for the packet,
// and pnumMax the largest packet number seen in the number space of this packet.
-func parse1RTTPacket(pkt []byte, k fixedKeys, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, n int) {
- var err error
- p.payload, p.num, err = k.unprotect(pkt, 1+dstConnIDLen, pnumMax)
+func parse1RTTPacket(pkt []byte, k *updatingKeyPair, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, n int) {
+ pay, pnum, err := k.unprotect(pkt, 1+dstConnIDLen, pnumMax)
if err != nil {
return shortPacket{}, -1
}
+ p.num = pnum
+ p.payload = pay
return p, len(pkt)
}
diff --git a/internal/quic/packet_protection.go b/internal/quic/packet_protection.go
index 2f9b9ce..aab1eaf 100644
--- a/internal/quic/packet_protection.go
+++ b/internal/quic/packet_protection.go
@@ -37,6 +37,10 @@
hp headerProtection
}
+func (k headerKey) isSet() bool {
+ return k.hp != nil
+}
+
func (k *headerKey) init(suite uint16, secret []byte) {
h, keySize := hashForSuite(suite)
hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keySize)
@@ -275,6 +279,148 @@
return k.w.isSet()
}
+// An updatingKeys is a header protection key and updatable packet protection key.
+// updatingKeys are used for 1-RTT keys, where the packet protection key changes
+// over the lifetime of a connection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-6
+type updatingKeys struct {
+ suite uint16
+ hdr headerKey
+ pkt [2]packetKey // current, next
+ nextSecret []byte // secret used to generate pkt[1]
+}
+
+func (k *updatingKeys) init(suite uint16, secret []byte) {
+ k.suite = suite
+ k.hdr.init(suite, secret)
+ // Initialize pkt[1] with secret_0, and then call update to generate secret_1.
+ k.pkt[1].init(suite, secret)
+ k.nextSecret = secret
+ k.update()
+}
+
+// update performs a key update.
+// The current key in pkt[0] is discarded.
+// The next key in pkt[1] becomes the current key.
+// A new next key is generated in pkt[1].
+func (k *updatingKeys) update() {
+ k.nextSecret = updateSecret(k.suite, k.nextSecret)
+ k.pkt[0] = k.pkt[1]
+ k.pkt[1].init(k.suite, k.nextSecret)
+}
+
+func updateSecret(suite uint16, secret []byte) (nextSecret []byte) {
+ h, _ := hashForSuite(suite)
+ return hkdfExpandLabel(h.New, secret, "quic ku", nil, len(secret))
+}
+
+// An updatingKeyPair is a read/write pair of updating keys.
+//
+// We keep two keys (current and next) in both read and write directions.
+// When an incoming packet's phase matches the current phase bit,
+// we unprotect it using the current keys; otherwise we use the next keys.
+//
+// When updating=false, outgoing packets are protected using the current phase.
+//
+// An update is initiated and updating is set to true when:
+// - we decide to initiate a key update; or
+// - we successfully unprotect a packet using the next keys,
+// indicating the peer has initiated a key update.
+//
+// When updating=true, outgoing packets are protected using the next phase.
+// We do not change the current phase bit or generate new keys yet.
+//
+// The update concludes when we receive an ACK frame for a packet sent
+// with the next keys. At this time, we set updating to false, flip the
+// phase bit, and update the keys. This permits us to handle up to 1-RTT
+// of reordered packets before discarding the previous phase's keys after
+// an update.
+type updatingKeyPair struct {
+ phase uint8 // current key phase (r.pkt[0], w.pkt[0])
+ updating bool
+ minSent packetNumber // min packet number sent since entering the updating state
+ minReceived packetNumber // min packet number received in the next phase
+ r, w updatingKeys
+}
+
+func (k *updatingKeyPair) canRead() bool {
+ return k.r.hdr.hp != nil
+}
+
+func (k *updatingKeyPair) canWrite() bool {
+ return k.w.hdr.hp != nil
+}
+
+// handleAckFor finishes a key update after receiving an ACK for a packet in the next phase.
+func (k *updatingKeyPair) handleAckFor(pnum packetNumber) {
+ if k.updating && pnum >= k.minSent {
+ k.updating = false
+ k.phase ^= keyPhaseBit
+ k.r.update()
+ k.w.update()
+ }
+}
+
+// needAckEliciting reports whether we should send an ack-eliciting packet in the next phase.
+// The first packet sent in a phase is ack-eliciting, since the peer must acknowledge a
+// packet in the new phase for us to finish the update.
+func (k *updatingKeyPair) needAckEliciting() bool {
+ return k.updating && k.minSent == maxPacketNumber
+}
+
+// protect applies packet protection to a packet.
+// Parameters and returns are as for fixedKeyPair.protect.
+func (k *updatingKeyPair) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
+ // TODO: Initiate key updates as required to avoid the AEAD usage limit.
+ // https://www.rfc-editor.org/rfc/rfc9001#section-6.6
+ var pkt []byte
+ if k.updating {
+ hdr[0] |= k.phase ^ keyPhaseBit
+ pkt = k.w.pkt[1].protect(hdr, pay, pnum)
+ k.minSent = min(pnum, k.minSent)
+ } else {
+ hdr[0] |= k.phase
+ pkt = k.w.pkt[0].protect(hdr, pay, pnum)
+ }
+ k.w.hdr.protect(pkt, pnumOff)
+ return pkt
+}
+
+// unprotect removes packet protection from a packet.
+// Parameters and returns are as for fixedKeyPair.unprotect.
+func (k *updatingKeyPair) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, pnum packetNumber, err error) {
+ hdr, pay, pnum, err := k.r.hdr.unprotect(pkt, pnumOff, pnumMax)
+ if err != nil {
+ return nil, 0, err
+ }
+ // To avoid timing signals that might indicate the key phase bit is invalid,
+ // we always attempt to unprotect the packet with one key.
+ //
+ // If the key phase bit matches and the packet number doesn't come after
+ // the start of an in-progress update, use the current phase.
+ // Otherwise, use the next phase.
+ if hdr[0]&keyPhaseBit == k.phase && (!k.updating || pnum < k.minReceived) {
+ pay, err = k.r.pkt[0].unprotect(hdr, pay, pnum)
+ if err != nil {
+ return nil, 0, err
+ }
+ } else {
+ pay, err = k.r.pkt[1].unprotect(hdr, pay, pnum)
+ if err != nil {
+ return nil, 0, err
+ }
+ if !k.updating {
+ // The peer has initiated a key update.
+ k.updating = true
+ k.minSent = maxPacketNumber
+ k.minReceived = pnum
+ } else {
+ k.minReceived = min(pnum, k.minReceived)
+ }
+ }
+ return pay, pnum, nil
+}
+
// https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2
var initialSalt = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go
index 2009895..0c2b2ee 100644
--- a/internal/quic/packet_writer.go
+++ b/internal/quic/packet_writer.go
@@ -163,14 +163,13 @@
// finish1RTTPacket finishes writing a 1-RTT packet,
// canceling the packet if it contains no payload.
// It returns a sentPacket describing the packet, or nil if no packet was written.
-func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k fixedKeys) *sentPacket {
+func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k *updatingKeyPair) *sentPacket {
if len(w.b) == w.payOff {
// The payload is empty, so just abandon the packet.
w.b = w.b[:w.pktOff]
return nil
}
// TODO: Spin
- // TODO: Key phase
pnumLen := packetNumberLength(pnum, pnumMaxAcked)
hdr := w.b[:w.pktOff]
hdr = append(hdr, 0x40|byte(pnumLen-1))