quic: refactor keys for key updates

Refactor how we store encryption keys in preparation for adding
support for key updates.

Previously, we had a single "keys" type containing header and packet
protection key material. With key update, the 1-RTT header protection
keys are consistent across the lifetime of a connection, while
packet protection keys vary. Separate out the header and packet
protection keys into distinct types.

Add "fixed" key types for keys which remain fixed across a
connection's lifetime and do not update. For the moment,
1-RTT keys are still fixed.

Remove a number of can-never-happen error returns from
key handling paths. We were previously inconsistent about
where to panic and where to return an error on these paths;
we now consistently panic in paths where errors can only
occur due to a bug. (For example, attempting to create an
AEAD with an incorrect secret size.)

No functional changes, this is purely refactoring.

For golang/go#58547

Change-Id: I49f83091517186e452845b65a1597add60e5fc92
Reviewed-on: https://go-review.googlesource.com/c/net/+/529155
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index 26c25f8..4565e1a 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -42,10 +42,11 @@
 	idleTimeout    time.Time
 
 	// Packet protection keys, CRYPTO streams, and TLS state.
-	rkeys  [numberSpaceCount]keys
-	wkeys  [numberSpaceCount]keys
-	crypto [numberSpaceCount]cryptoStream
-	tls    *tls.QUICConn
+	keysInitial   fixedKeyPair
+	keysHandshake fixedKeyPair
+	keysAppData   fixedKeyPair
+	crypto        [numberSpaceCount]cryptoStream
+	tls           *tls.QUICConn
 
 	// handshakeConfirmed is set when the handshake is confirmed.
 	// For server connections, it tracks sending HANDSHAKE_DONE.
@@ -156,8 +157,12 @@
 // discardKeys discards unused packet protection keys.
 // https://www.rfc-editor.org/rfc/rfc9001#section-4.9
 func (c *Conn) discardKeys(now time.Time, space numberSpace) {
-	c.rkeys[space].discard()
-	c.wkeys[space].discard()
+	switch space {
+	case initialSpace:
+		c.keysInitial.discard()
+	case handshakeSpace:
+		c.keysHandshake.discard()
+	}
 	c.loss.discardKeys(now, space)
 }
 
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index 64e5f98..d1fa52d 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -26,9 +26,9 @@
 				// https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4
 				return
 			}
-			n = c.handleLongHeader(now, ptype, initialSpace, buf)
+			n = c.handleLongHeader(now, ptype, initialSpace, c.keysInitial.r, buf)
 		case packetTypeHandshake:
-			n = c.handleLongHeader(now, ptype, handshakeSpace, buf)
+			n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf)
 		case packetType1RTT:
 			n = c.handle1RTT(now, buf)
 		default:
@@ -43,13 +43,13 @@
 	}
 }
 
-func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, buf []byte) int {
-	if !c.rkeys[space].isSet() {
+func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int {
+	if !k.isSet() {
 		return skipLongHeaderPacket(buf)
 	}
 
 	pnumMax := c.acks[space].largestSeen()
-	p, n := parseLongHeaderPacket(buf, c.rkeys[space], pnumMax)
+	p, n := parseLongHeaderPacket(buf, k, pnumMax)
 	if n < 0 {
 		return -1
 	}
@@ -82,14 +82,14 @@
 }
 
 func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
-	if !c.rkeys[appDataSpace].isSet() {
+	if !c.keysAppData.canRead() {
 		// 1-RTT packets extend to the end of the datagram,
 		// so skip the remainder of the datagram if we can't parse this.
 		return len(buf)
 	}
 
 	pnumMax := c.acks[appDataSpace].largestSeen()
-	p, n := parse1RTTPacket(buf, c.rkeys[appDataSpace], connIDLen, pnumMax)
+	p, n := parse1RTTPacket(buf, c.keysAppData.r, connIDLen, pnumMax)
 	if n < 0 {
 		return -1
 	}
diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go
index 853c845..58a3df1 100644
--- a/internal/quic/conn_send.go
+++ b/internal/quic/conn_send.go
@@ -59,7 +59,7 @@
 		// Initial packet.
 		pad := false
 		var sentInitial *sentPacket
-		if k := c.wkeys[initialSpace]; k.isSet() {
+		if c.keysInitial.canWrite() {
 			pnumMaxAcked := c.acks[initialSpace].largestSeen()
 			pnum := c.loss.nextNumber(initialSpace)
 			p := longPacket{
@@ -74,7 +74,7 @@
 			if logPackets {
 				logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload())
 			}
-			sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p)
+			sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p)
 			if sentInitial != nil {
 				// Client initial packets need to be sent in a datagram padded to
 				// at least 1200 bytes. We can't add the padding yet, however,
@@ -86,7 +86,7 @@
 		}
 
 		// Handshake packet.
-		if k := c.wkeys[handshakeSpace]; k.isSet() {
+		if c.keysHandshake.canWrite() {
 			pnumMaxAcked := c.acks[handshakeSpace].largestSeen()
 			pnum := c.loss.nextNumber(handshakeSpace)
 			p := longPacket{
@@ -101,7 +101,7 @@
 			if logPackets {
 				logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload())
 			}
-			if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p); sent != nil {
+			if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil {
 				c.loss.packetSent(now, handshakeSpace, sent)
 				if c.side == clientSide {
 					// "[...] a client MUST discard Initial keys when it first
@@ -113,7 +113,7 @@
 		}
 
 		// 1-RTT packet.
-		if k := c.wkeys[appDataSpace]; k.isSet() {
+		if c.keysAppData.canWrite() {
 			pnumMaxAcked := c.acks[appDataSpace].largestSeen()
 			pnum := c.loss.nextNumber(appDataSpace)
 			c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID)
@@ -128,7 +128,7 @@
 			if logPackets {
 				logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload())
 			}
-			if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, k); sent != nil {
+			if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, c.keysAppData.w); sent != nil {
 				c.loss.packetSent(now, appDataSpace, sent)
 			}
 		}
@@ -157,7 +157,10 @@
 					sentInitial.inFlight = true
 				}
 			}
-			if k := c.wkeys[initialSpace]; k.isSet() {
+			// If we're a client and this Initial packet is coalesced
+			// with a Handshake packet, then we've discarded Initial keys
+			// since constructing the packet and shouldn't record it as in-flight.
+			if c.keysInitial.canWrite() {
 				c.loss.packetSent(now, initialSpace, sentInitial)
 			}
 		}
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index 4228ce7..3fef62d 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -113,15 +113,18 @@
 	timerLastFired time.Time
 	idlec          chan struct{} // only accessed on the conn's loop
 
-	// Read and write keys are distinct from the conn's keys,
+	// Keys are distinct from the conn's keys,
 	// because the test may know about keys before the conn does.
 	// For example, when sending a datagram with coalesced
 	// Initial and Handshake packets to a client conn,
 	// we use Handshake keys to encrypt the packet.
 	// The client only acquires those keys when it processes
 	// the Initial packet.
-	rkeys [numberSpaceCount]keyData // for packets sent to the conn
-	wkeys [numberSpaceCount]keyData // for packets sent by the conn
+	keysInitial   fixedKeyPair
+	keysHandshake fixedKeyPair
+	keysAppData   fixedKeyPair
+	rsecrets      [numberSpaceCount]testKeySecret
+	wsecrets      [numberSpaceCount]testKeySecret
 
 	// testConn uses a test hook to snoop on the conn's TLS events.
 	// CRYPTO data produced by the conn's QUICConn is placed in
@@ -156,10 +159,9 @@
 	asyncTestState
 }
 
-type keyData struct {
+type testKeySecret struct {
 	suite  uint16
 	secret []byte
-	k      keys
 }
 
 // newTestConn creates a Conn for testing.
@@ -225,8 +227,8 @@
 	}
 	tc.conn = conn
 
-	tc.wkeys[initialSpace].k = conn.wkeys[initialSpace]
-	tc.rkeys[initialSpace].k = conn.rkeys[initialSpace]
+	tc.keysInitial.r = conn.keysInitial.w
+	tc.keysInitial.w = conn.keysInitial.r
 
 	tc.wait()
 	return tc
@@ -611,14 +613,19 @@
 	for _, f := range p.frames {
 		f.write(&w)
 	}
-	space := spaceForPacketType(p.ptype)
-	if !tc.rkeys[space].k.isSet() {
-		tc.t.Fatalf("sending packet with no %v keys available", space)
-		return nil
-	}
 	w.appendPaddingTo(pad)
 	if p.ptype != packetType1RTT {
-		w.finishProtectedLongHeaderPacket(pnumMaxAcked, tc.rkeys[space].k, longPacket{
+		var k fixedKeyPair
+		switch p.ptype {
+		case packetTypeInitial:
+			k = tc.keysInitial
+		case packetTypeHandshake:
+			k = tc.keysHandshake
+		}
+		if !k.canWrite() {
+			tc.t.Fatalf("sending %v packet with no write key", p.ptype)
+		}
+		w.finishProtectedLongHeaderPacket(pnumMaxAcked, k.w, longPacket{
 			ptype:     p.ptype,
 			version:   p.version,
 			num:       p.num,
@@ -626,7 +633,10 @@
 			srcConnID: p.srcConnID,
 		})
 	} else {
-		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.rkeys[space].k)
+		if !tc.keysAppData.canWrite() {
+			tc.t.Fatalf("sending %v packet with no write key", p.ptype)
+		}
+		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.keysAppData.w)
 	}
 	return w.datagram()
 }
@@ -642,13 +652,19 @@
 			break
 		}
 		ptype := getPacketType(buf)
-		space := spaceForPacketType(ptype)
-		if !tc.wkeys[space].k.isSet() {
-			tc.t.Fatalf("no keys for space %v, packet type %v", space, ptype)
-		}
 		if isLongHeader(buf[0]) {
+			var k fixedKeyPair
+			switch ptype {
+			case packetTypeInitial:
+				k = tc.keysInitial
+			case packetTypeHandshake:
+				k = tc.keysHandshake
+			}
+			if !k.canRead() {
+				tc.t.Fatalf("reading %v packet with no read key", ptype)
+			}
 			var pnumMax packetNumber // TODO: Track packet numbers.
-			p, n := parseLongHeaderPacket(buf, tc.wkeys[space].k, pnumMax)
+			p, n := parseLongHeaderPacket(buf, k.r, pnumMax)
 			if n < 0 {
 				tc.t.Fatalf("packet parse error")
 			}
@@ -666,8 +682,11 @@
 			})
 			buf = buf[n:]
 		} else {
+			if !tc.keysAppData.canRead() {
+				tc.t.Fatalf("reading 1-RTT packet with no read key")
+			}
 			var pnumMax packetNumber // TODO: Track packet numbers.
-			p, n := parse1RTTPacket(buf, tc.wkeys[space].k, len(tc.peerConnID), pnumMax)
+			p, n := parse1RTTPacket(buf, tc.keysAppData.r, len(tc.peerConnID), pnumMax)
 			if n < 0 {
 				tc.t.Fatalf("packet parse error")
 			}
@@ -747,12 +766,7 @@
 // and verify that both sides of the connection are getting
 // matching keys.
 func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
-	setKey := func(keys *[numberSpaceCount]keyData, e tls.QUICEvent) {
-		k, err := newKeys(e.Suite, e.Data)
-		if err != nil {
-			tc.t.Errorf("newKeys: %v", err)
-			return
-		}
+	checkKey := func(typ string, secrets *[numberSpaceCount]testKeySecret, e tls.QUICEvent) {
 		var space numberSpace
 		switch {
 		case e.Level == tls.QUICEncryptionLevelHandshake:
@@ -763,25 +777,30 @@
 			tc.t.Errorf("unexpected encryption level %v", e.Level)
 			return
 		}
-		s := "read"
-		if keys == &tc.wkeys {
-			s = "write"
+		if secrets[space].secret == nil {
+			secrets[space].suite = e.Suite
+			secrets[space].secret = append([]byte{}, e.Data...)
+		} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
+			tc.t.Errorf("%v key mismatch for level %v", typ, e.Level)
 		}
-		if keys[space].k.isSet() {
-			if keys[space].suite != e.Suite || !bytes.Equal(keys[space].secret, e.Data) {
-				tc.t.Errorf("%v key mismatch for level for level %v", s, e.Level)
-			}
-			return
-		}
-		keys[space].suite = e.Suite
-		keys[space].secret = append([]byte{}, e.Data...)
-		keys[space].k = k
 	}
 	switch e.Kind {
 	case tls.QUICSetReadSecret:
-		setKey(&tc.rkeys, e)
+		checkKey("read", &tc.rsecrets, e)
+		switch e.Level {
+		case tls.QUICEncryptionLevelHandshake:
+			tc.keysHandshake.w.init(e.Suite, e.Data)
+		case tls.QUICEncryptionLevelApplication:
+			tc.keysAppData.w.init(e.Suite, e.Data)
+		}
 	case tls.QUICSetWriteSecret:
-		setKey(&tc.wkeys, e)
+		checkKey("write", &tc.wsecrets, e)
+		switch e.Level {
+		case tls.QUICEncryptionLevelHandshake:
+			tc.keysHandshake.r.init(e.Suite, e.Data)
+		case tls.QUICEncryptionLevelApplication:
+			tc.keysAppData.r.init(e.Suite, e.Data)
+		}
 	case tls.QUICWriteData:
 		tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
 		tc.peerTLSConn.HandleData(e.Level, e.Data)
@@ -792,9 +811,21 @@
 		case tls.QUICNoEvent:
 			return
 		case tls.QUICSetReadSecret:
-			setKey(&tc.wkeys, e)
+			checkKey("write", &tc.wsecrets, e)
+			switch e.Level {
+			case tls.QUICEncryptionLevelHandshake:
+				tc.keysHandshake.r.init(e.Suite, e.Data)
+			case tls.QUICEncryptionLevelApplication:
+				tc.keysAppData.r.init(e.Suite, e.Data)
+			}
 		case tls.QUICSetWriteSecret:
-			setKey(&tc.rkeys, e)
+			checkKey("read", &tc.rsecrets, e)
+			switch e.Level {
+			case tls.QUICEncryptionLevelHandshake:
+				tc.keysHandshake.w.init(e.Suite, e.Data)
+			case tls.QUICEncryptionLevelApplication:
+				tc.keysAppData.w.init(e.Suite, e.Data)
+			}
 		case tls.QUICWriteData:
 			tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
 		case tls.QUICTransportParameters:
diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go
index 3503d24..7f0846f 100644
--- a/internal/quic/packet_codec_test.go
+++ b/internal/quic/packet_codec_test.go
@@ -17,7 +17,7 @@
 	// Example Initial packet from:
 	// https://www.rfc-editor.org/rfc/rfc9001.html#section-a.3
 	cid := unhex(`8394c8f03e515708`)
-	_, initialServerKeys := initialKeys(cid)
+	initialServerKeys := initialKeys(cid, clientSide).r
 	pkt := unhex(`
 		cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a
 		5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3
@@ -65,20 +65,21 @@
 	}
 
 	// Parse with the wrong keys.
-	_, invalidKeys := initialKeys([]byte{})
+	invalidKeys := initialKeys([]byte{}, clientSide).w
 	if _, n := parseLongHeaderPacket(pkt, invalidKeys, 0); n != -1 {
 		t.Fatalf("parse long header packet with wrong keys: n=%v, want -1", n)
 	}
 }
 
 func TestRoundtripEncodeLongPacket(t *testing.T) {
-	aes128Keys, _ := newKeys(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
-	aes256Keys, _ := newKeys(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
-	chachaKeys, _ := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
+	var aes128Keys, aes256Keys, chachaKeys fixedKeys
+	aes128Keys.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
+	aes256Keys.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
+	chachaKeys.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
 	for _, test := range []struct {
 		desc string
 		p    longPacket
-		k    keys
+		k    fixedKeys
 	}{{
 		desc: "Initial, 1-byte number, AES128",
 		p: longPacket{
@@ -145,9 +146,10 @@
 }
 
 func TestRoundtripEncodeShortPacket(t *testing.T) {
-	aes128Keys, _ := newKeys(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
-	aes256Keys, _ := newKeys(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
-	chachaKeys, _ := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
+	var aes128Keys, aes256Keys, chachaKeys fixedKeys
+	aes128Keys.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
+	aes256Keys.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
+	chachaKeys.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
 	connID := make([]byte, connIDLen)
 	for i := range connID {
 		connID[i] = byte(i)
@@ -156,7 +158,7 @@
 		desc    string
 		num     packetNumber
 		payload []byte
-		k       keys
+		k       fixedKeys
 	}{{
 		desc:    "1-byte number, AES128",
 		num:     0, // 1-byte encoding,
@@ -700,7 +702,7 @@
 
 func FuzzParseLongHeaderPacket(f *testing.F) {
 	cid := unhex(`0000000000000000`)
-	_, initialServerKeys := initialKeys(cid)
+	initialServerKeys := initialKeys(cid, clientSide).r
 	f.Fuzz(func(t *testing.T, in []byte) {
 		parseLongHeaderPacket(in, initialServerKeys, 0)
 	})
diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go
index 4323882..458cd3a 100644
--- a/internal/quic/packet_parser.go
+++ b/internal/quic/packet_parser.go
@@ -18,7 +18,7 @@
 // and its length in bytes.
 //
 // It returns an empty packet and -1 if the packet could not be parsed.
-func parseLongHeaderPacket(pkt []byte, k keys, pnumMax packetNumber) (p longPacket, n int) {
+func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p longPacket, n int) {
 	if len(pkt) < 5 || !isLongHeader(pkt[0]) {
 		return longPacket{}, -1
 	}
@@ -143,7 +143,7 @@
 //
 // On input, pkt contains a short header packet, k the decryption keys for the packet,
 // and pnumMax the largest packet number seen in the number space of this packet.
-func parse1RTTPacket(pkt []byte, k keys, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, n int) {
+func parse1RTTPacket(pkt []byte, k fixedKeys, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, n int) {
 	var err error
 	p.payload, p.num, err = k.unprotect(pkt, 1+dstConnIDLen, pnumMax)
 	if err != nil {
diff --git a/internal/quic/packet_protection.go b/internal/quic/packet_protection.go
index 1847053..2f9b9ce 100644
--- a/internal/quic/packet_protection.go
+++ b/internal/quic/packet_protection.go
@@ -13,7 +13,6 @@
 	"crypto/sha256"
 	"crypto/tls"
 	"errors"
-	"fmt"
 	"hash"
 
 	"golang.org/x/crypto/chacha20"
@@ -24,44 +23,145 @@
 
 var errInvalidPacket = errors.New("quic: invalid packet")
 
-// keys holds the cryptographic material used to protect packets
-// at an encryption level and direction. (e.g., Initial client keys.)
-//
-// keys are not safe for concurrent use.
-type keys struct {
-	// AEAD function used for packet protection.
-	aead cipher.AEAD
+// headerProtectionSampleSize is the size of the ciphertext sample used for header protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2
+const headerProtectionSampleSize = 16
 
-	// The header_protection function as defined in:
-	// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1
-	//
-	// This function takes a sample of the packet ciphertext
-	// and returns a 5-byte mask which will be applied to the
-	// protected portions of the packet header.
-	headerProtection func(sample []byte) (mask [5]byte)
+// aeadOverhead is the difference in size between the AEAD output and input.
+// All cipher suites defined for use with QUIC have 16 bytes of overhead.
+const aeadOverhead = 16
 
-	// IV used to construct the AEAD nonce.
-	iv []byte
+// A headerKey applies or removes header protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4
+type headerKey struct {
+	hp headerProtection
 }
 
-// newKeys creates keys for a given cipher suite and secret.
-//
-// It returns an error if the suite is unknown.
-func newKeys(suite uint16, secret []byte) (keys, error) {
+func (k *headerKey) init(suite uint16, secret []byte) {
+	h, keySize := hashForSuite(suite)
+	hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keySize)
 	switch suite {
-	case tls.TLS_AES_128_GCM_SHA256:
-		return newAESKeys(secret, crypto.SHA256, 128/8), nil
-	case tls.TLS_AES_256_GCM_SHA384:
-		return newAESKeys(secret, crypto.SHA384, 256/8), nil
+	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
+		c, err := aes.NewCipher(hpKey)
+		if err != nil {
+			panic(err)
+		}
+		k.hp = &aesHeaderProtection{cipher: c}
 	case tls.TLS_CHACHA20_POLY1305_SHA256:
-		return newChaCha20Keys(secret), nil
+		k.hp = chaCha20HeaderProtection{hpKey}
+	default:
+		panic("BUG: unknown cipher suite")
 	}
-	return keys{}, fmt.Errorf("unknown cipher suite %x", suite)
 }
 
-func newAESKeys(secret []byte, h crypto.Hash, keyBytes int) keys {
+// protect applies header protection.
+// pnumOff is the offset of the packet number in the packet.
+func (k headerKey) protect(hdr []byte, pnumOff int) {
+	// Apply header protection.
+	pnumSize := int(hdr[0]&0x03) + 1
+	sample := hdr[pnumOff+4:][:headerProtectionSampleSize]
+	mask := k.hp.headerProtection(sample)
+	if isLongHeader(hdr[0]) {
+		hdr[0] ^= mask[0] & 0x0f
+	} else {
+		hdr[0] ^= mask[0] & 0x1f
+	}
+	for i := 0; i < pnumSize; i++ {
+		hdr[pnumOff+i] ^= mask[1+i]
+	}
+}
+
+// unprotect removes header protection.
+// pnumOff is the offset of the packet number in the packet.
+// pnumMax is the largest packet number seen in the number space of this packet.
+func (k headerKey) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (hdr, pay []byte, pnum packetNumber, _ error) {
+	if len(pkt) < pnumOff+4+headerProtectionSampleSize {
+		return nil, nil, 0, errInvalidPacket
+	}
+	numpay := pkt[pnumOff:]
+	sample := numpay[4:][:headerProtectionSampleSize]
+	mask := k.hp.headerProtection(sample)
+	if isLongHeader(pkt[0]) {
+		pkt[0] ^= mask[0] & 0x0f
+	} else {
+		pkt[0] ^= mask[0] & 0x1f
+	}
+	pnumLen := int(pkt[0]&0x03) + 1
+	pnum = packetNumber(0)
+	for i := 0; i < pnumLen; i++ {
+		numpay[i] ^= mask[1+i]
+		pnum = (pnum << 8) | packetNumber(numpay[i])
+	}
+	pnum = decodePacketNumber(pnumMax, pnum, pnumLen)
+	hdr = pkt[:pnumOff+pnumLen]
+	pay = numpay[pnumLen:]
+	return hdr, pay, pnum, nil
+}
+
+// headerProtection is the  header_protection function as defined in:
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1
+//
+// This function takes a sample of the packet ciphertext
+// and returns a 5-byte mask which will be applied to the
+// protected portions of the packet header.
+type headerProtection interface {
+	headerProtection(sample []byte) (mask [5]byte)
+}
+
+// AES-based header protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3
+type aesHeaderProtection struct {
+	cipher  cipher.Block
+	scratch [aes.BlockSize]byte
+}
+
+func (hp *aesHeaderProtection) headerProtection(sample []byte) (mask [5]byte) {
+	hp.cipher.Encrypt(hp.scratch[:], sample)
+	copy(mask[:], hp.scratch[:])
+	return mask
+}
+
+// ChaCha20-based header protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4
+type chaCha20HeaderProtection struct {
+	key []byte
+}
+
+func (hp chaCha20HeaderProtection) headerProtection(sample []byte) (mask [5]byte) {
+	counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0])
+	nonce := sample[4:16]
+	c, err := chacha20.NewUnauthenticatedCipher(hp.key, nonce)
+	if err != nil {
+		panic(err)
+	}
+	c.SetCounter(counter)
+	c.XORKeyStream(mask[:], mask[:])
+	return mask
+}
+
+// A packetKey applies or removes packet protection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.1
+type packetKey struct {
+	aead cipher.AEAD // AEAD function used for packet protection.
+	iv   []byte      // IV used to construct the AEAD nonce.
+}
+
+func (k *packetKey) init(suite uint16, secret []byte) {
 	// https://www.rfc-editor.org/rfc/rfc9001#section-5.1
-	key := hkdfExpandLabel(h.New, secret, "quic key", nil, keyBytes)
+	h, keySize := hashForSuite(suite)
+	key := hkdfExpandLabel(h.New, secret, "quic key", nil, keySize)
+	switch suite {
+	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
+		k.aead = newAESAEAD(key)
+	case tls.TLS_CHACHA20_POLY1305_SHA256:
+		k.aead = newChaCha20AEAD(key)
+	default:
+		panic("BUG: unknown cipher suite")
+	}
+	k.iv = hkdfExpandLabel(h.New, secret, "quic iv", nil, k.aead.NonceSize())
+}
+
+func newAESAEAD(key []byte) cipher.AEAD {
 	c, err := aes.NewCipher(key)
 	if err != nil {
 		panic(err)
@@ -70,53 +170,109 @@
 	if err != nil {
 		panic(err)
 	}
-	iv := hkdfExpandLabel(h.New, secret, "quic iv", nil, aead.NonceSize())
-	// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3
-	hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keyBytes)
-	hp, err := aes.NewCipher(hpKey)
-	if err != nil {
-		panic(err)
-	}
-	var scratch [aes.BlockSize]byte
-	headerProtection := func(sample []byte) (mask [5]byte) {
-		hp.Encrypt(scratch[:], sample)
-		copy(mask[:], scratch[:])
-		return mask
-	}
-	return keys{
-		aead:             aead,
-		iv:               iv,
-		headerProtection: headerProtection,
-	}
+	return aead
 }
 
-func newChaCha20Keys(secret []byte) keys {
-	// https://www.rfc-editor.org/rfc/rfc9001#section-5.1
-	key := hkdfExpandLabel(sha256.New, secret, "quic key", nil, chacha20poly1305.KeySize)
+func newChaCha20AEAD(key []byte) cipher.AEAD {
+	var err error
 	aead, err := chacha20poly1305.New(key)
 	if err != nil {
 		panic(err)
 	}
-	iv := hkdfExpandLabel(sha256.New, secret, "quic iv", nil, aead.NonceSize())
-	// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4
-	hpKey := hkdfExpandLabel(sha256.New, secret, "quic hp", nil, chacha20.KeySize)
-	headerProtection := func(sample []byte) [5]byte {
-		counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0])
-		nonce := sample[4:16]
-		c, err := chacha20.NewUnauthenticatedCipher(hpKey, nonce)
-		if err != nil {
-			panic(err)
-		}
-		c.SetCounter(counter)
-		var mask [5]byte
-		c.XORKeyStream(mask[:], mask[:])
-		return mask
+	return aead
+}
+
+func (k packetKey) protect(hdr, pay []byte, pnum packetNumber) []byte {
+	k.xorIV(pnum)
+	defer k.xorIV(pnum)
+	return k.aead.Seal(hdr, k.iv, pay, hdr)
+}
+
+func (k packetKey) unprotect(hdr, pay []byte, pnum packetNumber) (dec []byte, err error) {
+	k.xorIV(pnum)
+	defer k.xorIV(pnum)
+	return k.aead.Open(pay[:0], k.iv, pay, hdr)
+}
+
+// xorIV xors the packet protection IV with the packet number.
+func (k packetKey) xorIV(pnum packetNumber) {
+	k.iv[len(k.iv)-8] ^= uint8(pnum >> 56)
+	k.iv[len(k.iv)-7] ^= uint8(pnum >> 48)
+	k.iv[len(k.iv)-6] ^= uint8(pnum >> 40)
+	k.iv[len(k.iv)-5] ^= uint8(pnum >> 32)
+	k.iv[len(k.iv)-4] ^= uint8(pnum >> 24)
+	k.iv[len(k.iv)-3] ^= uint8(pnum >> 16)
+	k.iv[len(k.iv)-2] ^= uint8(pnum >> 8)
+	k.iv[len(k.iv)-1] ^= uint8(pnum)
+}
+
+// A fixedKeys is a header protection key and fixed packet protection key.
+// The packet protection key is fixed (it does not update).
+//
+// Fixed keys are used for Initial and Handshake keys, which do not update.
+type fixedKeys struct {
+	hdr headerKey
+	pkt packetKey
+}
+
+func (k *fixedKeys) init(suite uint16, secret []byte) {
+	k.hdr.init(suite, secret)
+	k.pkt.init(suite, secret)
+}
+
+func (k fixedKeys) isSet() bool {
+	return k.hdr.hp != nil
+}
+
+// protect applies packet protection to a packet.
+//
+// On input, hdr contains the packet header, pay the unencrypted payload,
+// pnumOff the offset of the packet number in the header, and pnum the untruncated
+// packet number.
+//
+// protect returns the result of appending the encrypted payload to hdr and
+// applying header protection.
+func (k fixedKeys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
+	pkt := k.pkt.protect(hdr, pay, pnum)
+	k.hdr.protect(pkt, pnumOff)
+	return pkt
+}
+
+// unprotect removes packet protection from a packet.
+//
+// On input, pkt contains the full protected packet, pnumOff the offset of
+// the packet number in the header, and pnumMax the largest packet number
+// seen in the number space of this packet.
+//
+// unprotect removes header protection from the header in pkt, and returns
+// the unprotected payload and packet number.
+func (k fixedKeys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) {
+	hdr, pay, pnum, err := k.hdr.unprotect(pkt, pnumOff, pnumMax)
+	if err != nil {
+		return nil, 0, err
 	}
-	return keys{
-		aead:             aead,
-		iv:               iv,
-		headerProtection: headerProtection,
+	pay, err = k.pkt.unprotect(hdr, pay, pnum)
+	if err != nil {
+		return nil, 0, err
 	}
+	return pay, pnum, nil
+}
+
+// A fixedKeyPair is a read/write pair of fixed keys.
+type fixedKeyPair struct {
+	r, w fixedKeys
+}
+
+func (k *fixedKeyPair) discard() {
+	*k = fixedKeyPair{}
+}
+
+func (k *fixedKeyPair) canRead() bool {
+	return k.r.isSet()
+}
+
+func (k *fixedKeyPair) canWrite() bool {
+	return k.w.isSet()
 }
 
 // https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2
@@ -128,121 +284,44 @@
 // field in the client's first Initial packet.
 //
 // https://www.rfc-editor.org/rfc/rfc9001#section-5.2
-func initialKeys(cid []byte) (clientKeys, serverKeys keys) {
+func initialKeys(cid []byte, side connSide) fixedKeyPair {
 	initialSecret := hkdf.Extract(sha256.New, cid, initialSalt)
-	clientInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size)
-	clientKeys, err := newKeys(tls.TLS_AES_128_GCM_SHA256, clientInitialSecret)
-	if err != nil {
-		panic(err)
-	}
-
-	serverInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size)
-	serverKeys, err = newKeys(tls.TLS_AES_128_GCM_SHA256, serverInitialSecret)
-	if err != nil {
-		panic(err)
-	}
-
-	return clientKeys, serverKeys
-}
-
-const headerProtectionSampleSize = 16
-
-// aeadOverhead is the difference in size between the AEAD output and input.
-// All cipher suites defined for use with QUIC have 16 bytes of overhead.
-const aeadOverhead = 16
-
-// xorIV xors the packet protection IV with the packet number.
-func (k keys) xorIV(pnum packetNumber) {
-	k.iv[len(k.iv)-8] ^= uint8(pnum >> 56)
-	k.iv[len(k.iv)-7] ^= uint8(pnum >> 48)
-	k.iv[len(k.iv)-6] ^= uint8(pnum >> 40)
-	k.iv[len(k.iv)-5] ^= uint8(pnum >> 32)
-	k.iv[len(k.iv)-4] ^= uint8(pnum >> 24)
-	k.iv[len(k.iv)-3] ^= uint8(pnum >> 16)
-	k.iv[len(k.iv)-2] ^= uint8(pnum >> 8)
-	k.iv[len(k.iv)-1] ^= uint8(pnum)
-}
-
-// isSet returns true if valid keys are available.
-func (k keys) isSet() bool {
-	return k.aead != nil
-}
-
-// discard discards the keys (in the sense that we won't use them any more,
-// not that the keys are securely erased).
-//
-// https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9
-func (k *keys) discard() {
-	*k = keys{}
-}
-
-// protect applies packet protection to a packet.
-//
-// On input, hdr contains the packet header, pay the unencrypted payload,
-// pnumOff the offset of the packet number in the header, and pnum the untruncated
-// packet number.
-//
-// protect returns the result of appending the encrypted payload to hdr and
-// applying header protection.
-func (k keys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
-	k.xorIV(pnum)
-	hdr = k.aead.Seal(hdr, k.iv, pay, hdr)
-	k.xorIV(pnum)
-
-	// Apply header protection.
-	pnumSize := int(hdr[0]&0x03) + 1
-	sample := hdr[pnumOff+4:][:headerProtectionSampleSize]
-	mask := k.headerProtection(sample)
-	if isLongHeader(hdr[0]) {
-		hdr[0] ^= mask[0] & 0x0f
+	var clientKeys fixedKeys
+	clientSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size)
+	clientKeys.init(tls.TLS_AES_128_GCM_SHA256, clientSecret)
+	var serverKeys fixedKeys
+	serverSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size)
+	serverKeys.init(tls.TLS_AES_128_GCM_SHA256, serverSecret)
+	if side == clientSide {
+		return fixedKeyPair{r: serverKeys, w: clientKeys}
 	} else {
-		hdr[0] ^= mask[0] & 0x1f
+		return fixedKeyPair{w: serverKeys, r: clientKeys}
 	}
-	for i := 0; i < pnumSize; i++ {
-		hdr[pnumOff+i] ^= mask[1+i]
-	}
-
-	return hdr
 }
 
-// unprotect removes packet protection from a packet.
-//
-// On input, pkt contains the full protected packet, pnumOff the offset of
-// the packet number in the header, and pnumMax the largest packet number
-// seen in the number space of this packet.
-//
-// unprotect removes header protection from the header in pkt, and returns
-// the unprotected payload and packet number.
-func (k keys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) {
-	if len(pkt) < pnumOff+4+headerProtectionSampleSize {
-		return nil, 0, errInvalidPacket
+// checkCipherSuite returns an error if suite is not a supported cipher suite.
+func checkCipherSuite(suite uint16) error {
+	switch suite {
+	case tls.TLS_AES_128_GCM_SHA256:
+	case tls.TLS_AES_256_GCM_SHA384:
+	case tls.TLS_CHACHA20_POLY1305_SHA256:
+	default:
+		return errors.New("invalid cipher suite")
 	}
-	numpay := pkt[pnumOff:]
-	sample := numpay[4:][:headerProtectionSampleSize]
-	mask := k.headerProtection(sample)
-	if isLongHeader(pkt[0]) {
-		pkt[0] ^= mask[0] & 0x0f
-	} else {
-		pkt[0] ^= mask[0] & 0x1f
-	}
-	pnumLen := int(pkt[0]&0x03) + 1
-	pnum := packetNumber(0)
-	for i := 0; i < pnumLen; i++ {
-		numpay[i] ^= mask[1+i]
-		pnum = (pnum << 8) | packetNumber(numpay[i])
-	}
-	pnum = decodePacketNumber(pnumMax, pnum, pnumLen)
+	return nil
+}
 
-	hdr := pkt[:pnumOff+pnumLen]
-	pay = numpay[pnumLen:]
-	k.xorIV(pnum)
-	pay, err = k.aead.Open(pay[:0], k.iv, pay, hdr)
-	k.xorIV(pnum)
-	if err != nil {
-		return nil, 0, err
+func hashForSuite(suite uint16) (h crypto.Hash, keySize int) {
+	switch suite {
+	case tls.TLS_AES_128_GCM_SHA256:
+		return crypto.SHA256, 128 / 8
+	case tls.TLS_AES_256_GCM_SHA384:
+		return crypto.SHA384, 256 / 8
+	case tls.TLS_CHACHA20_POLY1305_SHA256:
+		return crypto.SHA256, chacha20.KeySize
+	default:
+		panic("BUG: unknown cipher suite")
 	}
-
-	return pay, pnum, nil
 }
 
 // hdkfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
diff --git a/internal/quic/packet_protection_test.go b/internal/quic/packet_protection_test.go
index 6495360..1fe1307 100644
--- a/internal/quic/packet_protection_test.go
+++ b/internal/quic/packet_protection_test.go
@@ -16,10 +16,11 @@
 	// Test cases from:
 	// https://www.rfc-editor.org/rfc/rfc9001#section-appendix.a
 	cid := unhex(`8394c8f03e515708`)
-	initialClientKeys, initialServerKeys := initialKeys(cid)
+	k := initialKeys(cid, clientSide)
+	initialClientKeys, initialServerKeys := k.w, k.r
 	for _, test := range []struct {
 		name string
-		k    keys
+		k    fixedKeys
 		pnum packetNumber
 		hdr  []byte
 		pay  []byte
@@ -103,15 +104,13 @@
 		`),
 	}, {
 		name: "ChaCha20_Poly1305 Short Header",
-		k: func() keys {
+		k: func() fixedKeys {
 			secret := unhex(`
 				9ac312a7f877468ebe69422748ad00a1
 				5443f18203a07d6060f688f30f21632b
 			`)
-			k, err := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, secret)
-			if err != nil {
-				t.Fatal(err)
-			}
+			var k fixedKeys
+			k.init(tls.TLS_CHACHA20_POLY1305_SHA256, secret)
 			return k
 		}(),
 		pnum: 654360564,
diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go
index a80b471..2009895 100644
--- a/internal/quic/packet_writer.go
+++ b/internal/quic/packet_writer.go
@@ -100,7 +100,7 @@
 // finishProtectedLongHeaderPacket finishes writing an Initial, 0-RTT, or Handshake packet,
 // canceling the packet if it contains no payload.
 // It returns a sentPacket describing the packet, or nil if no packet was written.
-func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber, k keys, p longPacket) *sentPacket {
+func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber, k fixedKeys, p longPacket) *sentPacket {
 	if len(w.b) == w.payOff {
 		// The payload is empty, so just abandon the packet.
 		w.b = w.b[:w.pktOff]
@@ -135,7 +135,8 @@
 	pnumOff := len(hdr)
 	hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked)
 
-	return w.protect(hdr[w.pktOff:], p.num, pnumOff, k)
+	k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, p.num)
+	return w.finish(p.num)
 }
 
 // start1RTTPacket starts writing a 1-RTT (short header) packet.
@@ -162,7 +163,7 @@
 // finish1RTTPacket finishes writing a 1-RTT packet,
 // canceling the packet if it contains no payload.
 // It returns a sentPacket describing the packet, or nil if no packet was written.
-func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k keys) *sentPacket {
+func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k fixedKeys) *sentPacket {
 	if len(w.b) == w.payOff {
 		// The payload is empty, so just abandon the packet.
 		w.b = w.b[:w.pktOff]
@@ -177,7 +178,8 @@
 	pnumOff := len(hdr)
 	hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked)
 	w.padPacketLength(pnumLen)
-	return w.protect(hdr[w.pktOff:], pnum, pnumOff, k)
+	k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, pnum)
+	return w.finish(pnum)
 }
 
 // padPacketLength pads out the payload of the current packet to the minimum size,
@@ -197,9 +199,8 @@
 	return plen
 }
 
-// protect applies packet protection and finishes the current packet.
-func (w *packetWriter) protect(hdr []byte, pnum packetNumber, pnumOff int, k keys) *sentPacket {
-	k.protect(hdr, w.b[w.pktOff+len(hdr):], pnumOff-w.pktOff, pnum)
+// finish finishes the current packet after protection is applied.
+func (w *packetWriter) finish(pnum packetNumber) *sentPacket {
 	w.b = w.b[:len(w.b)+aeadOverhead]
 	w.sent.size = len(w.b) - w.pktOff
 	w.sent.num = pnum
diff --git a/internal/quic/tls.go b/internal/quic/tls.go
index e3a430e..a37e26f 100644
--- a/internal/quic/tls.go
+++ b/internal/quic/tls.go
@@ -16,12 +16,7 @@
 
 // startTLS starts the TLS handshake.
 func (c *Conn) startTLS(now time.Time, initialConnID []byte, params transportParameters) error {
-	clientKeys, serverKeys := initialKeys(initialConnID)
-	if c.side == clientSide {
-		c.wkeys[initialSpace], c.rkeys[initialSpace] = clientKeys, serverKeys
-	} else {
-		c.wkeys[initialSpace], c.rkeys[initialSpace] = serverKeys, clientKeys
-	}
+	c.keysInitial = initialKeys(initialConnID, c.side)
 
 	qconfig := &tls.QUICConfig{TLSConfig: c.config.TLSConfig}
 	if c.side == clientSide {
@@ -49,21 +44,36 @@
 		case tls.QUICNoEvent:
 			return nil
 		case tls.QUICSetReadSecret:
-			space, k, err := tlsKey(e)
-			if err != nil {
+			if err := checkCipherSuite(e.Suite); err != nil {
 				return err
 			}
-			c.rkeys[space] = k
+			switch e.Level {
+			case tls.QUICEncryptionLevelHandshake:
+				c.keysHandshake.r.init(e.Suite, e.Data)
+			case tls.QUICEncryptionLevelApplication:
+				c.keysAppData.r.init(e.Suite, e.Data)
+			}
 		case tls.QUICSetWriteSecret:
-			space, k, err := tlsKey(e)
-			if err != nil {
+			if err := checkCipherSuite(e.Suite); err != nil {
 				return err
 			}
-			c.wkeys[space] = k
+			switch e.Level {
+			case tls.QUICEncryptionLevelHandshake:
+				c.keysHandshake.w.init(e.Suite, e.Data)
+			case tls.QUICEncryptionLevelApplication:
+				c.keysAppData.w.init(e.Suite, e.Data)
+			}
 		case tls.QUICWriteData:
-			space, err := spaceForLevel(e.Level)
-			if err != nil {
-				return err
+			var space numberSpace
+			switch e.Level {
+			case tls.QUICEncryptionLevelInitial:
+				space = initialSpace
+			case tls.QUICEncryptionLevelHandshake:
+				space = handshakeSpace
+			case tls.QUICEncryptionLevelApplication:
+				space = appDataSpace
+			default:
+				return fmt.Errorf("quic: internal error: write handshake data at level %v", e.Level)
 			}
 			c.crypto[space].write(e.Data)
 		case tls.QUICHandshakeDone:
@@ -86,32 +96,6 @@
 	}
 }
 
-// tlsKey returns the keys in a QUICSetReadSecret or QUICSetWriteSecret event.
-func tlsKey(e tls.QUICEvent) (numberSpace, keys, error) {
-	space, err := spaceForLevel(e.Level)
-	if err != nil {
-		return 0, keys{}, err
-	}
-	k, err := newKeys(e.Suite, e.Data)
-	if err != nil {
-		return 0, keys{}, err
-	}
-	return space, k, nil
-}
-
-func spaceForLevel(level tls.QUICEncryptionLevel) (numberSpace, error) {
-	switch level {
-	case tls.QUICEncryptionLevelInitial:
-		return initialSpace, nil
-	case tls.QUICEncryptionLevelHandshake:
-		return handshakeSpace, nil
-	case tls.QUICEncryptionLevelApplication:
-		return appDataSpace, nil
-	default:
-		return 0, fmt.Errorf("quic: internal error: write handshake data at level %v", level)
-	}
-}
-
 // handleCrypto processes data received in a CRYPTO frame.
 func (c *Conn) handleCrypto(now time.Time, space numberSpace, off int64, data []byte) error {
 	var level tls.QUICEncryptionLevel