quic: refactor keys for key updates
Refactor how we store encryption keys in preparation for adding
support for key updates.
Previously, we had a single "keys" type containing header and packet
protection key material. With key update, the 1-RTT header protection
keys are consistent across the lifetime of a connection, while
packet protection keys vary. Separate out the header and packet
protection keys into distinct types.
Add "fixed" key types for keys which remain fixed across a
connection's lifetime and do not update. For the moment,
1-RTT keys are still fixed.
Remove a number of can-never-happen error returns from
key handling paths. We were previously inconsistent about
where to panic and where to return an error on these paths;
we now consistently panic in paths where errors can only
occur due to a bug. (For example, attempting to create an
AEAD with an incorrect secret size.)
No functional changes, this is purely refactoring.
For golang/go#58547
Change-Id: I49f83091517186e452845b65a1597add60e5fc92
Reviewed-on: https://go-review.googlesource.com/c/net/+/529155
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index 26c25f8..4565e1a 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -42,10 +42,11 @@
idleTimeout time.Time
// Packet protection keys, CRYPTO streams, and TLS state.
- rkeys [numberSpaceCount]keys
- wkeys [numberSpaceCount]keys
- crypto [numberSpaceCount]cryptoStream
- tls *tls.QUICConn
+ keysInitial fixedKeyPair
+ keysHandshake fixedKeyPair
+ keysAppData fixedKeyPair
+ crypto [numberSpaceCount]cryptoStream
+ tls *tls.QUICConn
// handshakeConfirmed is set when the handshake is confirmed.
// For server connections, it tracks sending HANDSHAKE_DONE.
@@ -156,8 +157,12 @@
// discardKeys discards unused packet protection keys.
// https://www.rfc-editor.org/rfc/rfc9001#section-4.9
func (c *Conn) discardKeys(now time.Time, space numberSpace) {
- c.rkeys[space].discard()
- c.wkeys[space].discard()
+ switch space {
+ case initialSpace:
+ c.keysInitial.discard()
+ case handshakeSpace:
+ c.keysHandshake.discard()
+ }
c.loss.discardKeys(now, space)
}
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index 64e5f98..d1fa52d 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -26,9 +26,9 @@
// https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4
return
}
- n = c.handleLongHeader(now, ptype, initialSpace, buf)
+ n = c.handleLongHeader(now, ptype, initialSpace, c.keysInitial.r, buf)
case packetTypeHandshake:
- n = c.handleLongHeader(now, ptype, handshakeSpace, buf)
+ n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf)
case packetType1RTT:
n = c.handle1RTT(now, buf)
default:
@@ -43,13 +43,13 @@
}
}
-func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, buf []byte) int {
- if !c.rkeys[space].isSet() {
+func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int {
+ if !k.isSet() {
return skipLongHeaderPacket(buf)
}
pnumMax := c.acks[space].largestSeen()
- p, n := parseLongHeaderPacket(buf, c.rkeys[space], pnumMax)
+ p, n := parseLongHeaderPacket(buf, k, pnumMax)
if n < 0 {
return -1
}
@@ -82,14 +82,14 @@
}
func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
- if !c.rkeys[appDataSpace].isSet() {
+ if !c.keysAppData.canRead() {
// 1-RTT packets extend to the end of the datagram,
// so skip the remainder of the datagram if we can't parse this.
return len(buf)
}
pnumMax := c.acks[appDataSpace].largestSeen()
- p, n := parse1RTTPacket(buf, c.rkeys[appDataSpace], connIDLen, pnumMax)
+ p, n := parse1RTTPacket(buf, c.keysAppData.r, connIDLen, pnumMax)
if n < 0 {
return -1
}
diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go
index 853c845..58a3df1 100644
--- a/internal/quic/conn_send.go
+++ b/internal/quic/conn_send.go
@@ -59,7 +59,7 @@
// Initial packet.
pad := false
var sentInitial *sentPacket
- if k := c.wkeys[initialSpace]; k.isSet() {
+ if c.keysInitial.canWrite() {
pnumMaxAcked := c.acks[initialSpace].largestSeen()
pnum := c.loss.nextNumber(initialSpace)
p := longPacket{
@@ -74,7 +74,7 @@
if logPackets {
logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload())
}
- sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p)
+ sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p)
if sentInitial != nil {
// Client initial packets need to be sent in a datagram padded to
// at least 1200 bytes. We can't add the padding yet, however,
@@ -86,7 +86,7 @@
}
// Handshake packet.
- if k := c.wkeys[handshakeSpace]; k.isSet() {
+ if c.keysHandshake.canWrite() {
pnumMaxAcked := c.acks[handshakeSpace].largestSeen()
pnum := c.loss.nextNumber(handshakeSpace)
p := longPacket{
@@ -101,7 +101,7 @@
if logPackets {
logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload())
}
- if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p); sent != nil {
+ if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil {
c.loss.packetSent(now, handshakeSpace, sent)
if c.side == clientSide {
// "[...] a client MUST discard Initial keys when it first
@@ -113,7 +113,7 @@
}
// 1-RTT packet.
- if k := c.wkeys[appDataSpace]; k.isSet() {
+ if c.keysAppData.canWrite() {
pnumMaxAcked := c.acks[appDataSpace].largestSeen()
pnum := c.loss.nextNumber(appDataSpace)
c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID)
@@ -128,7 +128,7 @@
if logPackets {
logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload())
}
- if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, k); sent != nil {
+ if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, c.keysAppData.w); sent != nil {
c.loss.packetSent(now, appDataSpace, sent)
}
}
@@ -157,7 +157,10 @@
sentInitial.inFlight = true
}
}
- if k := c.wkeys[initialSpace]; k.isSet() {
+ // If we're a client and this Initial packet is coalesced
+ // with a Handshake packet, then we've discarded Initial keys
+ // since constructing the packet and shouldn't record it as in-flight.
+ if c.keysInitial.canWrite() {
c.loss.packetSent(now, initialSpace, sentInitial)
}
}
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index 4228ce7..3fef62d 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -113,15 +113,18 @@
timerLastFired time.Time
idlec chan struct{} // only accessed on the conn's loop
- // Read and write keys are distinct from the conn's keys,
+ // Keys are distinct from the conn's keys,
// because the test may know about keys before the conn does.
// For example, when sending a datagram with coalesced
// Initial and Handshake packets to a client conn,
// we use Handshake keys to encrypt the packet.
// The client only acquires those keys when it processes
// the Initial packet.
- rkeys [numberSpaceCount]keyData // for packets sent to the conn
- wkeys [numberSpaceCount]keyData // for packets sent by the conn
+ keysInitial fixedKeyPair
+ keysHandshake fixedKeyPair
+ keysAppData fixedKeyPair
+ rsecrets [numberSpaceCount]testKeySecret
+ wsecrets [numberSpaceCount]testKeySecret
// testConn uses a test hook to snoop on the conn's TLS events.
// CRYPTO data produced by the conn's QUICConn is placed in
@@ -156,10 +159,9 @@
asyncTestState
}
-type keyData struct {
+type testKeySecret struct {
suite uint16
secret []byte
- k keys
}
// newTestConn creates a Conn for testing.
@@ -225,8 +227,8 @@
}
tc.conn = conn
- tc.wkeys[initialSpace].k = conn.wkeys[initialSpace]
- tc.rkeys[initialSpace].k = conn.rkeys[initialSpace]
+ tc.keysInitial.r = conn.keysInitial.w
+ tc.keysInitial.w = conn.keysInitial.r
tc.wait()
return tc
@@ -611,14 +613,19 @@
for _, f := range p.frames {
f.write(&w)
}
- space := spaceForPacketType(p.ptype)
- if !tc.rkeys[space].k.isSet() {
- tc.t.Fatalf("sending packet with no %v keys available", space)
- return nil
- }
w.appendPaddingTo(pad)
if p.ptype != packetType1RTT {
- w.finishProtectedLongHeaderPacket(pnumMaxAcked, tc.rkeys[space].k, longPacket{
+ var k fixedKeyPair
+ switch p.ptype {
+ case packetTypeInitial:
+ k = tc.keysInitial
+ case packetTypeHandshake:
+ k = tc.keysHandshake
+ }
+ if !k.canWrite() {
+ tc.t.Fatalf("sending %v packet with no write key", p.ptype)
+ }
+ w.finishProtectedLongHeaderPacket(pnumMaxAcked, k.w, longPacket{
ptype: p.ptype,
version: p.version,
num: p.num,
@@ -626,7 +633,10 @@
srcConnID: p.srcConnID,
})
} else {
- w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.rkeys[space].k)
+ if !tc.keysAppData.canWrite() {
+ tc.t.Fatalf("sending %v packet with no write key", p.ptype)
+ }
+ w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.keysAppData.w)
}
return w.datagram()
}
@@ -642,13 +652,19 @@
break
}
ptype := getPacketType(buf)
- space := spaceForPacketType(ptype)
- if !tc.wkeys[space].k.isSet() {
- tc.t.Fatalf("no keys for space %v, packet type %v", space, ptype)
- }
if isLongHeader(buf[0]) {
+ var k fixedKeyPair
+ switch ptype {
+ case packetTypeInitial:
+ k = tc.keysInitial
+ case packetTypeHandshake:
+ k = tc.keysHandshake
+ }
+ if !k.canRead() {
+ tc.t.Fatalf("reading %v packet with no read key", ptype)
+ }
var pnumMax packetNumber // TODO: Track packet numbers.
- p, n := parseLongHeaderPacket(buf, tc.wkeys[space].k, pnumMax)
+ p, n := parseLongHeaderPacket(buf, k.r, pnumMax)
if n < 0 {
tc.t.Fatalf("packet parse error")
}
@@ -666,8 +682,11 @@
})
buf = buf[n:]
} else {
+ if !tc.keysAppData.canRead() {
+ tc.t.Fatalf("reading 1-RTT packet with no read key")
+ }
var pnumMax packetNumber // TODO: Track packet numbers.
- p, n := parse1RTTPacket(buf, tc.wkeys[space].k, len(tc.peerConnID), pnumMax)
+ p, n := parse1RTTPacket(buf, tc.keysAppData.r, len(tc.peerConnID), pnumMax)
if n < 0 {
tc.t.Fatalf("packet parse error")
}
@@ -747,12 +766,7 @@
// and verify that both sides of the connection are getting
// matching keys.
func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
- setKey := func(keys *[numberSpaceCount]keyData, e tls.QUICEvent) {
- k, err := newKeys(e.Suite, e.Data)
- if err != nil {
- tc.t.Errorf("newKeys: %v", err)
- return
- }
+ checkKey := func(typ string, secrets *[numberSpaceCount]testKeySecret, e tls.QUICEvent) {
var space numberSpace
switch {
case e.Level == tls.QUICEncryptionLevelHandshake:
@@ -763,25 +777,30 @@
tc.t.Errorf("unexpected encryption level %v", e.Level)
return
}
- s := "read"
- if keys == &tc.wkeys {
- s = "write"
+ if secrets[space].secret == nil {
+ secrets[space].suite = e.Suite
+ secrets[space].secret = append([]byte{}, e.Data...)
+ } else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
+ tc.t.Errorf("%v key mismatch for level %v", typ, e.Level)
}
- if keys[space].k.isSet() {
- if keys[space].suite != e.Suite || !bytes.Equal(keys[space].secret, e.Data) {
- tc.t.Errorf("%v key mismatch for level for level %v", s, e.Level)
- }
- return
- }
- keys[space].suite = e.Suite
- keys[space].secret = append([]byte{}, e.Data...)
- keys[space].k = k
}
switch e.Kind {
case tls.QUICSetReadSecret:
- setKey(&tc.rkeys, e)
+ checkKey("read", &tc.rsecrets, e)
+ switch e.Level {
+ case tls.QUICEncryptionLevelHandshake:
+ tc.keysHandshake.w.init(e.Suite, e.Data)
+ case tls.QUICEncryptionLevelApplication:
+ tc.keysAppData.w.init(e.Suite, e.Data)
+ }
case tls.QUICSetWriteSecret:
- setKey(&tc.wkeys, e)
+ checkKey("write", &tc.wsecrets, e)
+ switch e.Level {
+ case tls.QUICEncryptionLevelHandshake:
+ tc.keysHandshake.r.init(e.Suite, e.Data)
+ case tls.QUICEncryptionLevelApplication:
+ tc.keysAppData.r.init(e.Suite, e.Data)
+ }
case tls.QUICWriteData:
tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
tc.peerTLSConn.HandleData(e.Level, e.Data)
@@ -792,9 +811,21 @@
case tls.QUICNoEvent:
return
case tls.QUICSetReadSecret:
- setKey(&tc.wkeys, e)
+ checkKey("write", &tc.wsecrets, e)
+ switch e.Level {
+ case tls.QUICEncryptionLevelHandshake:
+ tc.keysHandshake.r.init(e.Suite, e.Data)
+ case tls.QUICEncryptionLevelApplication:
+ tc.keysAppData.r.init(e.Suite, e.Data)
+ }
case tls.QUICSetWriteSecret:
- setKey(&tc.rkeys, e)
+ checkKey("read", &tc.rsecrets, e)
+ switch e.Level {
+ case tls.QUICEncryptionLevelHandshake:
+ tc.keysHandshake.w.init(e.Suite, e.Data)
+ case tls.QUICEncryptionLevelApplication:
+ tc.keysAppData.w.init(e.Suite, e.Data)
+ }
case tls.QUICWriteData:
tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
case tls.QUICTransportParameters:
diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go
index 3503d24..7f0846f 100644
--- a/internal/quic/packet_codec_test.go
+++ b/internal/quic/packet_codec_test.go
@@ -17,7 +17,7 @@
// Example Initial packet from:
// https://www.rfc-editor.org/rfc/rfc9001.html#section-a.3
cid := unhex(`8394c8f03e515708`)
- _, initialServerKeys := initialKeys(cid)
+ initialServerKeys := initialKeys(cid, clientSide).r
pkt := unhex(`
cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a
5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3
@@ -65,20 +65,21 @@
}
// Parse with the wrong keys.
- _, invalidKeys := initialKeys([]byte{})
+ invalidKeys := initialKeys([]byte{}, clientSide).w
if _, n := parseLongHeaderPacket(pkt, invalidKeys, 0); n != -1 {
t.Fatalf("parse long header packet with wrong keys: n=%v, want -1", n)
}
}
func TestRoundtripEncodeLongPacket(t *testing.T) {
- aes128Keys, _ := newKeys(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
- aes256Keys, _ := newKeys(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
- chachaKeys, _ := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
+ var aes128Keys, aes256Keys, chachaKeys fixedKeys
+ aes128Keys.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
+ aes256Keys.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
+ chachaKeys.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
for _, test := range []struct {
desc string
p longPacket
- k keys
+ k fixedKeys
}{{
desc: "Initial, 1-byte number, AES128",
p: longPacket{
@@ -145,9 +146,10 @@
}
func TestRoundtripEncodeShortPacket(t *testing.T) {
- aes128Keys, _ := newKeys(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
- aes256Keys, _ := newKeys(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
- chachaKeys, _ := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
+ var aes128Keys, aes256Keys, chachaKeys fixedKeys
+ aes128Keys.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
+ aes256Keys.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
+ chachaKeys.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
connID := make([]byte, connIDLen)
for i := range connID {
connID[i] = byte(i)
@@ -156,7 +158,7 @@
desc string
num packetNumber
payload []byte
- k keys
+ k fixedKeys
}{{
desc: "1-byte number, AES128",
num: 0, // 1-byte encoding,
@@ -700,7 +702,7 @@
func FuzzParseLongHeaderPacket(f *testing.F) {
cid := unhex(`0000000000000000`)
- _, initialServerKeys := initialKeys(cid)
+ initialServerKeys := initialKeys(cid, clientSide).r
f.Fuzz(func(t *testing.T, in []byte) {
parseLongHeaderPacket(in, initialServerKeys, 0)
})
diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go
index 4323882..458cd3a 100644
--- a/internal/quic/packet_parser.go
+++ b/internal/quic/packet_parser.go
@@ -18,7 +18,7 @@
// and its length in bytes.
//
// It returns an empty packet and -1 if the packet could not be parsed.
-func parseLongHeaderPacket(pkt []byte, k keys, pnumMax packetNumber) (p longPacket, n int) {
+func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p longPacket, n int) {
if len(pkt) < 5 || !isLongHeader(pkt[0]) {
return longPacket{}, -1
}
@@ -143,7 +143,7 @@
//
// On input, pkt contains a short header packet, k the decryption keys for the packet,
// and pnumMax the largest packet number seen in the number space of this packet.
-func parse1RTTPacket(pkt []byte, k keys, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, n int) {
+func parse1RTTPacket(pkt []byte, k fixedKeys, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, n int) {
var err error
p.payload, p.num, err = k.unprotect(pkt, 1+dstConnIDLen, pnumMax)
if err != nil {
diff --git a/internal/quic/packet_protection.go b/internal/quic/packet_protection.go
index 1847053..2f9b9ce 100644
--- a/internal/quic/packet_protection.go
+++ b/internal/quic/packet_protection.go
@@ -13,7 +13,6 @@
"crypto/sha256"
"crypto/tls"
"errors"
- "fmt"
"hash"
"golang.org/x/crypto/chacha20"
@@ -24,44 +23,145 @@
var errInvalidPacket = errors.New("quic: invalid packet")
-// keys holds the cryptographic material used to protect packets
-// at an encryption level and direction. (e.g., Initial client keys.)
-//
-// keys are not safe for concurrent use.
-type keys struct {
- // AEAD function used for packet protection.
- aead cipher.AEAD
+// headerProtectionSampleSize is the size of the ciphertext sample used for header protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2
+const headerProtectionSampleSize = 16
- // The header_protection function as defined in:
- // https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1
- //
- // This function takes a sample of the packet ciphertext
- // and returns a 5-byte mask which will be applied to the
- // protected portions of the packet header.
- headerProtection func(sample []byte) (mask [5]byte)
+// aeadOverhead is the difference in size between the AEAD output and input.
+// All cipher suites defined for use with QUIC have 16 bytes of overhead.
+const aeadOverhead = 16
- // IV used to construct the AEAD nonce.
- iv []byte
+// A headerKey applies or removes header protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4
+type headerKey struct {
+ hp headerProtection
}
-// newKeys creates keys for a given cipher suite and secret.
-//
-// It returns an error if the suite is unknown.
-func newKeys(suite uint16, secret []byte) (keys, error) {
+func (k *headerKey) init(suite uint16, secret []byte) {
+ h, keySize := hashForSuite(suite)
+ hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keySize)
switch suite {
- case tls.TLS_AES_128_GCM_SHA256:
- return newAESKeys(secret, crypto.SHA256, 128/8), nil
- case tls.TLS_AES_256_GCM_SHA384:
- return newAESKeys(secret, crypto.SHA384, 256/8), nil
+ case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
+ c, err := aes.NewCipher(hpKey)
+ if err != nil {
+ panic(err)
+ }
+ k.hp = &aesHeaderProtection{cipher: c}
case tls.TLS_CHACHA20_POLY1305_SHA256:
- return newChaCha20Keys(secret), nil
+ k.hp = chaCha20HeaderProtection{hpKey}
+ default:
+ panic("BUG: unknown cipher suite")
}
- return keys{}, fmt.Errorf("unknown cipher suite %x", suite)
}
-func newAESKeys(secret []byte, h crypto.Hash, keyBytes int) keys {
+// protect applies header protection.
+// pnumOff is the offset of the packet number in the packet.
+func (k headerKey) protect(hdr []byte, pnumOff int) {
+ // Apply header protection.
+ pnumSize := int(hdr[0]&0x03) + 1
+ sample := hdr[pnumOff+4:][:headerProtectionSampleSize]
+ mask := k.hp.headerProtection(sample)
+ if isLongHeader(hdr[0]) {
+ hdr[0] ^= mask[0] & 0x0f
+ } else {
+ hdr[0] ^= mask[0] & 0x1f
+ }
+ for i := 0; i < pnumSize; i++ {
+ hdr[pnumOff+i] ^= mask[1+i]
+ }
+}
+
+// unprotect removes header protection.
+// pnumOff is the offset of the packet number in the packet.
+// pnumMax is the largest packet number seen in the number space of this packet.
+func (k headerKey) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (hdr, pay []byte, pnum packetNumber, _ error) {
+ if len(pkt) < pnumOff+4+headerProtectionSampleSize {
+ return nil, nil, 0, errInvalidPacket
+ }
+ numpay := pkt[pnumOff:]
+ sample := numpay[4:][:headerProtectionSampleSize]
+ mask := k.hp.headerProtection(sample)
+ if isLongHeader(pkt[0]) {
+ pkt[0] ^= mask[0] & 0x0f
+ } else {
+ pkt[0] ^= mask[0] & 0x1f
+ }
+ pnumLen := int(pkt[0]&0x03) + 1
+ pnum = packetNumber(0)
+ for i := 0; i < pnumLen; i++ {
+ numpay[i] ^= mask[1+i]
+ pnum = (pnum << 8) | packetNumber(numpay[i])
+ }
+ pnum = decodePacketNumber(pnumMax, pnum, pnumLen)
+ hdr = pkt[:pnumOff+pnumLen]
+ pay = numpay[pnumLen:]
+ return hdr, pay, pnum, nil
+}
+
+// headerProtection is the header_protection function as defined in:
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1
+//
+// This function takes a sample of the packet ciphertext
+// and returns a 5-byte mask which will be applied to the
+// protected portions of the packet header.
+type headerProtection interface {
+ headerProtection(sample []byte) (mask [5]byte)
+}
+
+// AES-based header protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3
+type aesHeaderProtection struct {
+ cipher cipher.Block
+ scratch [aes.BlockSize]byte
+}
+
+func (hp *aesHeaderProtection) headerProtection(sample []byte) (mask [5]byte) {
+ hp.cipher.Encrypt(hp.scratch[:], sample)
+ copy(mask[:], hp.scratch[:])
+ return mask
+}
+
+// ChaCha20-based header protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4
+type chaCha20HeaderProtection struct {
+ key []byte
+}
+
+func (hp chaCha20HeaderProtection) headerProtection(sample []byte) (mask [5]byte) {
+ counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0])
+ nonce := sample[4:16]
+ c, err := chacha20.NewUnauthenticatedCipher(hp.key, nonce)
+ if err != nil {
+ panic(err)
+ }
+ c.SetCounter(counter)
+ c.XORKeyStream(mask[:], mask[:])
+ return mask
+}
+
+// A packetKey applies or removes packet protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.1
+type packetKey struct {
+ aead cipher.AEAD // AEAD function used for packet protection.
+ iv []byte // IV used to construct the AEAD nonce.
+}
+
+func (k *packetKey) init(suite uint16, secret []byte) {
// https://www.rfc-editor.org/rfc/rfc9001#section-5.1
- key := hkdfExpandLabel(h.New, secret, "quic key", nil, keyBytes)
+ h, keySize := hashForSuite(suite)
+ key := hkdfExpandLabel(h.New, secret, "quic key", nil, keySize)
+ switch suite {
+ case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
+ k.aead = newAESAEAD(key)
+ case tls.TLS_CHACHA20_POLY1305_SHA256:
+ k.aead = newChaCha20AEAD(key)
+ default:
+ panic("BUG: unknown cipher suite")
+ }
+ k.iv = hkdfExpandLabel(h.New, secret, "quic iv", nil, k.aead.NonceSize())
+}
+
+func newAESAEAD(key []byte) cipher.AEAD {
c, err := aes.NewCipher(key)
if err != nil {
panic(err)
@@ -70,53 +170,109 @@
if err != nil {
panic(err)
}
- iv := hkdfExpandLabel(h.New, secret, "quic iv", nil, aead.NonceSize())
- // https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3
- hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keyBytes)
- hp, err := aes.NewCipher(hpKey)
- if err != nil {
- panic(err)
- }
- var scratch [aes.BlockSize]byte
- headerProtection := func(sample []byte) (mask [5]byte) {
- hp.Encrypt(scratch[:], sample)
- copy(mask[:], scratch[:])
- return mask
- }
- return keys{
- aead: aead,
- iv: iv,
- headerProtection: headerProtection,
- }
+ return aead
}
-func newChaCha20Keys(secret []byte) keys {
- // https://www.rfc-editor.org/rfc/rfc9001#section-5.1
- key := hkdfExpandLabel(sha256.New, secret, "quic key", nil, chacha20poly1305.KeySize)
+func newChaCha20AEAD(key []byte) cipher.AEAD {
+ var err error
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
- iv := hkdfExpandLabel(sha256.New, secret, "quic iv", nil, aead.NonceSize())
- // https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4
- hpKey := hkdfExpandLabel(sha256.New, secret, "quic hp", nil, chacha20.KeySize)
- headerProtection := func(sample []byte) [5]byte {
- counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0])
- nonce := sample[4:16]
- c, err := chacha20.NewUnauthenticatedCipher(hpKey, nonce)
- if err != nil {
- panic(err)
- }
- c.SetCounter(counter)
- var mask [5]byte
- c.XORKeyStream(mask[:], mask[:])
- return mask
+ return aead
+}
+
+func (k packetKey) protect(hdr, pay []byte, pnum packetNumber) []byte {
+ k.xorIV(pnum)
+ defer k.xorIV(pnum)
+ return k.aead.Seal(hdr, k.iv, pay, hdr)
+}
+
+func (k packetKey) unprotect(hdr, pay []byte, pnum packetNumber) (dec []byte, err error) {
+ k.xorIV(pnum)
+ defer k.xorIV(pnum)
+ return k.aead.Open(pay[:0], k.iv, pay, hdr)
+}
+
+// xorIV xors the packet protection IV with the packet number.
+func (k packetKey) xorIV(pnum packetNumber) {
+ k.iv[len(k.iv)-8] ^= uint8(pnum >> 56)
+ k.iv[len(k.iv)-7] ^= uint8(pnum >> 48)
+ k.iv[len(k.iv)-6] ^= uint8(pnum >> 40)
+ k.iv[len(k.iv)-5] ^= uint8(pnum >> 32)
+ k.iv[len(k.iv)-4] ^= uint8(pnum >> 24)
+ k.iv[len(k.iv)-3] ^= uint8(pnum >> 16)
+ k.iv[len(k.iv)-2] ^= uint8(pnum >> 8)
+ k.iv[len(k.iv)-1] ^= uint8(pnum)
+}
+
+// A fixedKeys is a header protection key and fixed packet protection key.
+// The packet protection key is fixed (it does not update).
+//
+// Fixed keys are used for Initial and Handshake keys, which do not update.
+type fixedKeys struct {
+ hdr headerKey
+ pkt packetKey
+}
+
+func (k *fixedKeys) init(suite uint16, secret []byte) {
+ k.hdr.init(suite, secret)
+ k.pkt.init(suite, secret)
+}
+
+func (k fixedKeys) isSet() bool {
+ return k.hdr.hp != nil
+}
+
+// protect applies packet protection to a packet.
+//
+// On input, hdr contains the packet header, pay the unencrypted payload,
+// pnumOff the offset of the packet number in the header, and pnum the untruncated
+// packet number.
+//
+// protect returns the result of appending the encrypted payload to hdr and
+// applying header protection.
+func (k fixedKeys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
+ pkt := k.pkt.protect(hdr, pay, pnum)
+ k.hdr.protect(pkt, pnumOff)
+ return pkt
+}
+
+// unprotect removes packet protection from a packet.
+//
+// On input, pkt contains the full protected packet, pnumOff the offset of
+// the packet number in the header, and pnumMax the largest packet number
+// seen in the number space of this packet.
+//
+// unprotect removes header protection from the header in pkt, and returns
+// the unprotected payload and packet number.
+func (k fixedKeys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) {
+ hdr, pay, pnum, err := k.hdr.unprotect(pkt, pnumOff, pnumMax)
+ if err != nil {
+ return nil, 0, err
}
- return keys{
- aead: aead,
- iv: iv,
- headerProtection: headerProtection,
+ pay, err = k.pkt.unprotect(hdr, pay, pnum)
+ if err != nil {
+ return nil, 0, err
}
+ return pay, pnum, nil
+}
+
+// A fixedKeyPair is a read/write pair of fixed keys.
+type fixedKeyPair struct {
+ r, w fixedKeys
+}
+
+func (k *fixedKeyPair) discard() {
+ *k = fixedKeyPair{}
+}
+
+func (k *fixedKeyPair) canRead() bool {
+ return k.r.isSet()
+}
+
+func (k *fixedKeyPair) canWrite() bool {
+ return k.w.isSet()
}
// https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2
@@ -128,121 +284,44 @@
// field in the client's first Initial packet.
//
// https://www.rfc-editor.org/rfc/rfc9001#section-5.2
-func initialKeys(cid []byte) (clientKeys, serverKeys keys) {
+func initialKeys(cid []byte, side connSide) fixedKeyPair {
initialSecret := hkdf.Extract(sha256.New, cid, initialSalt)
- clientInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size)
- clientKeys, err := newKeys(tls.TLS_AES_128_GCM_SHA256, clientInitialSecret)
- if err != nil {
- panic(err)
- }
-
- serverInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size)
- serverKeys, err = newKeys(tls.TLS_AES_128_GCM_SHA256, serverInitialSecret)
- if err != nil {
- panic(err)
- }
-
- return clientKeys, serverKeys
-}
-
-const headerProtectionSampleSize = 16
-
-// aeadOverhead is the difference in size between the AEAD output and input.
-// All cipher suites defined for use with QUIC have 16 bytes of overhead.
-const aeadOverhead = 16
-
-// xorIV xors the packet protection IV with the packet number.
-func (k keys) xorIV(pnum packetNumber) {
- k.iv[len(k.iv)-8] ^= uint8(pnum >> 56)
- k.iv[len(k.iv)-7] ^= uint8(pnum >> 48)
- k.iv[len(k.iv)-6] ^= uint8(pnum >> 40)
- k.iv[len(k.iv)-5] ^= uint8(pnum >> 32)
- k.iv[len(k.iv)-4] ^= uint8(pnum >> 24)
- k.iv[len(k.iv)-3] ^= uint8(pnum >> 16)
- k.iv[len(k.iv)-2] ^= uint8(pnum >> 8)
- k.iv[len(k.iv)-1] ^= uint8(pnum)
-}
-
-// isSet returns true if valid keys are available.
-func (k keys) isSet() bool {
- return k.aead != nil
-}
-
-// discard discards the keys (in the sense that we won't use them any more,
-// not that the keys are securely erased).
-//
-// https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9
-func (k *keys) discard() {
- *k = keys{}
-}
-
-// protect applies packet protection to a packet.
-//
-// On input, hdr contains the packet header, pay the unencrypted payload,
-// pnumOff the offset of the packet number in the header, and pnum the untruncated
-// packet number.
-//
-// protect returns the result of appending the encrypted payload to hdr and
-// applying header protection.
-func (k keys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
- k.xorIV(pnum)
- hdr = k.aead.Seal(hdr, k.iv, pay, hdr)
- k.xorIV(pnum)
-
- // Apply header protection.
- pnumSize := int(hdr[0]&0x03) + 1
- sample := hdr[pnumOff+4:][:headerProtectionSampleSize]
- mask := k.headerProtection(sample)
- if isLongHeader(hdr[0]) {
- hdr[0] ^= mask[0] & 0x0f
+ var clientKeys fixedKeys
+ clientSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size)
+ clientKeys.init(tls.TLS_AES_128_GCM_SHA256, clientSecret)
+ var serverKeys fixedKeys
+ serverSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size)
+ serverKeys.init(tls.TLS_AES_128_GCM_SHA256, serverSecret)
+ if side == clientSide {
+ return fixedKeyPair{r: serverKeys, w: clientKeys}
} else {
- hdr[0] ^= mask[0] & 0x1f
+ return fixedKeyPair{w: serverKeys, r: clientKeys}
}
- for i := 0; i < pnumSize; i++ {
- hdr[pnumOff+i] ^= mask[1+i]
- }
-
- return hdr
}
-// unprotect removes packet protection from a packet.
-//
-// On input, pkt contains the full protected packet, pnumOff the offset of
-// the packet number in the header, and pnumMax the largest packet number
-// seen in the number space of this packet.
-//
-// unprotect removes header protection from the header in pkt, and returns
-// the unprotected payload and packet number.
-func (k keys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) {
- if len(pkt) < pnumOff+4+headerProtectionSampleSize {
- return nil, 0, errInvalidPacket
+// checkCipherSuite returns an error if suite is not a supported cipher suite.
+func checkCipherSuite(suite uint16) error {
+ switch suite {
+ case tls.TLS_AES_128_GCM_SHA256:
+ case tls.TLS_AES_256_GCM_SHA384:
+ case tls.TLS_CHACHA20_POLY1305_SHA256:
+ default:
+ return errors.New("invalid cipher suite")
}
- numpay := pkt[pnumOff:]
- sample := numpay[4:][:headerProtectionSampleSize]
- mask := k.headerProtection(sample)
- if isLongHeader(pkt[0]) {
- pkt[0] ^= mask[0] & 0x0f
- } else {
- pkt[0] ^= mask[0] & 0x1f
- }
- pnumLen := int(pkt[0]&0x03) + 1
- pnum := packetNumber(0)
- for i := 0; i < pnumLen; i++ {
- numpay[i] ^= mask[1+i]
- pnum = (pnum << 8) | packetNumber(numpay[i])
- }
- pnum = decodePacketNumber(pnumMax, pnum, pnumLen)
+ return nil
+}
- hdr := pkt[:pnumOff+pnumLen]
- pay = numpay[pnumLen:]
- k.xorIV(pnum)
- pay, err = k.aead.Open(pay[:0], k.iv, pay, hdr)
- k.xorIV(pnum)
- if err != nil {
- return nil, 0, err
+func hashForSuite(suite uint16) (h crypto.Hash, keySize int) {
+ switch suite {
+ case tls.TLS_AES_128_GCM_SHA256:
+ return crypto.SHA256, 128 / 8
+ case tls.TLS_AES_256_GCM_SHA384:
+ return crypto.SHA384, 256 / 8
+ case tls.TLS_CHACHA20_POLY1305_SHA256:
+ return crypto.SHA256, chacha20.KeySize
+ default:
+ panic("BUG: unknown cipher suite")
}
-
- return pay, pnum, nil
}
// hdkfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
diff --git a/internal/quic/packet_protection_test.go b/internal/quic/packet_protection_test.go
index 6495360..1fe1307 100644
--- a/internal/quic/packet_protection_test.go
+++ b/internal/quic/packet_protection_test.go
@@ -16,10 +16,11 @@
// Test cases from:
// https://www.rfc-editor.org/rfc/rfc9001#section-appendix.a
cid := unhex(`8394c8f03e515708`)
- initialClientKeys, initialServerKeys := initialKeys(cid)
+ k := initialKeys(cid, clientSide)
+ initialClientKeys, initialServerKeys := k.w, k.r
for _, test := range []struct {
name string
- k keys
+ k fixedKeys
pnum packetNumber
hdr []byte
pay []byte
@@ -103,15 +104,13 @@
`),
}, {
name: "ChaCha20_Poly1305 Short Header",
- k: func() keys {
+ k: func() fixedKeys {
secret := unhex(`
9ac312a7f877468ebe69422748ad00a1
5443f18203a07d6060f688f30f21632b
`)
- k, err := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, secret)
- if err != nil {
- t.Fatal(err)
- }
+ var k fixedKeys
+ k.init(tls.TLS_CHACHA20_POLY1305_SHA256, secret)
return k
}(),
pnum: 654360564,
diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go
index a80b471..2009895 100644
--- a/internal/quic/packet_writer.go
+++ b/internal/quic/packet_writer.go
@@ -100,7 +100,7 @@
// finishProtectedLongHeaderPacket finishes writing an Initial, 0-RTT, or Handshake packet,
// canceling the packet if it contains no payload.
// It returns a sentPacket describing the packet, or nil if no packet was written.
-func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber, k keys, p longPacket) *sentPacket {
+func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber, k fixedKeys, p longPacket) *sentPacket {
if len(w.b) == w.payOff {
// The payload is empty, so just abandon the packet.
w.b = w.b[:w.pktOff]
@@ -135,7 +135,8 @@
pnumOff := len(hdr)
hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked)
- return w.protect(hdr[w.pktOff:], p.num, pnumOff, k)
+ k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, p.num)
+ return w.finish(p.num)
}
// start1RTTPacket starts writing a 1-RTT (short header) packet.
@@ -162,7 +163,7 @@
// finish1RTTPacket finishes writing a 1-RTT packet,
// canceling the packet if it contains no payload.
// It returns a sentPacket describing the packet, or nil if no packet was written.
-func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k keys) *sentPacket {
+func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k fixedKeys) *sentPacket {
if len(w.b) == w.payOff {
// The payload is empty, so just abandon the packet.
w.b = w.b[:w.pktOff]
@@ -177,7 +178,8 @@
pnumOff := len(hdr)
hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked)
w.padPacketLength(pnumLen)
- return w.protect(hdr[w.pktOff:], pnum, pnumOff, k)
+ k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, pnum)
+ return w.finish(pnum)
}
// padPacketLength pads out the payload of the current packet to the minimum size,
@@ -197,9 +199,8 @@
return plen
}
-// protect applies packet protection and finishes the current packet.
-func (w *packetWriter) protect(hdr []byte, pnum packetNumber, pnumOff int, k keys) *sentPacket {
- k.protect(hdr, w.b[w.pktOff+len(hdr):], pnumOff-w.pktOff, pnum)
+// finish finishes the current packet after protection is applied.
+func (w *packetWriter) finish(pnum packetNumber) *sentPacket {
w.b = w.b[:len(w.b)+aeadOverhead]
w.sent.size = len(w.b) - w.pktOff
w.sent.num = pnum
diff --git a/internal/quic/tls.go b/internal/quic/tls.go
index e3a430e..a37e26f 100644
--- a/internal/quic/tls.go
+++ b/internal/quic/tls.go
@@ -16,12 +16,7 @@
// startTLS starts the TLS handshake.
func (c *Conn) startTLS(now time.Time, initialConnID []byte, params transportParameters) error {
- clientKeys, serverKeys := initialKeys(initialConnID)
- if c.side == clientSide {
- c.wkeys[initialSpace], c.rkeys[initialSpace] = clientKeys, serverKeys
- } else {
- c.wkeys[initialSpace], c.rkeys[initialSpace] = serverKeys, clientKeys
- }
+ c.keysInitial = initialKeys(initialConnID, c.side)
qconfig := &tls.QUICConfig{TLSConfig: c.config.TLSConfig}
if c.side == clientSide {
@@ -49,21 +44,36 @@
case tls.QUICNoEvent:
return nil
case tls.QUICSetReadSecret:
- space, k, err := tlsKey(e)
- if err != nil {
+ if err := checkCipherSuite(e.Suite); err != nil {
return err
}
- c.rkeys[space] = k
+ switch e.Level {
+ case tls.QUICEncryptionLevelHandshake:
+ c.keysHandshake.r.init(e.Suite, e.Data)
+ case tls.QUICEncryptionLevelApplication:
+ c.keysAppData.r.init(e.Suite, e.Data)
+ }
case tls.QUICSetWriteSecret:
- space, k, err := tlsKey(e)
- if err != nil {
+ if err := checkCipherSuite(e.Suite); err != nil {
return err
}
- c.wkeys[space] = k
+ switch e.Level {
+ case tls.QUICEncryptionLevelHandshake:
+ c.keysHandshake.w.init(e.Suite, e.Data)
+ case tls.QUICEncryptionLevelApplication:
+ c.keysAppData.w.init(e.Suite, e.Data)
+ }
case tls.QUICWriteData:
- space, err := spaceForLevel(e.Level)
- if err != nil {
- return err
+ var space numberSpace
+ switch e.Level {
+ case tls.QUICEncryptionLevelInitial:
+ space = initialSpace
+ case tls.QUICEncryptionLevelHandshake:
+ space = handshakeSpace
+ case tls.QUICEncryptionLevelApplication:
+ space = appDataSpace
+ default:
+ return fmt.Errorf("quic: internal error: write handshake data at level %v", e.Level)
}
c.crypto[space].write(e.Data)
case tls.QUICHandshakeDone:
@@ -86,32 +96,6 @@
}
}
-// tlsKey returns the keys in a QUICSetReadSecret or QUICSetWriteSecret event.
-func tlsKey(e tls.QUICEvent) (numberSpace, keys, error) {
- space, err := spaceForLevel(e.Level)
- if err != nil {
- return 0, keys{}, err
- }
- k, err := newKeys(e.Suite, e.Data)
- if err != nil {
- return 0, keys{}, err
- }
- return space, k, nil
-}
-
-func spaceForLevel(level tls.QUICEncryptionLevel) (numberSpace, error) {
- switch level {
- case tls.QUICEncryptionLevelInitial:
- return initialSpace, nil
- case tls.QUICEncryptionLevelHandshake:
- return handshakeSpace, nil
- case tls.QUICEncryptionLevelApplication:
- return appDataSpace, nil
- default:
- return 0, fmt.Errorf("quic: internal error: write handshake data at level %v", level)
- }
-}
-
// handleCrypto processes data received in a CRYPTO frame.
func (c *Conn) handleCrypto(now time.Time, space numberSpace, off int64, data []byte) error {
var level tls.QUICEncryptionLevel