quic: handle peer-initiated key updates

RFC 9001, Section 6.

For golang/go#58547

Change-Id: I3700043d27ab41536521b547ecf5e632a08eb1b5
Reviewed-on: https://go-review.googlesource.com/c/net/+/528835
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index 4565e1a..dc3a985 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -44,7 +44,7 @@
 	// Packet protection keys, CRYPTO streams, and TLS state.
 	keysInitial   fixedKeyPair
 	keysHandshake fixedKeyPair
-	keysAppData   fixedKeyPair
+	keysAppData   updatingKeyPair
 	crypto        [numberSpaceCount]cryptoStream
 	tls           *tls.QUICConn
 
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index d1fa52d..4fc4eec 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -89,7 +89,7 @@
 	}
 
 	pnumMax := c.acks[appDataSpace].largestSeen()
-	p, n := parse1RTTPacket(buf, c.keysAppData.r, connIDLen, pnumMax)
+	p, n := parse1RTTPacket(buf, &c.keysAppData, connIDLen, pnumMax)
 	if n < 0 {
 		return -1
 	}
@@ -247,7 +247,7 @@
 
 func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) int {
 	c.loss.receiveAckStart()
-	_, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
+	largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
 		if end > c.loss.nextNumber(space) {
 			// Acknowledgement of a packet we never sent.
 			c.abort(now, localTransportError(errProtocolViolation))
@@ -280,6 +280,9 @@
 		delay = ackDelay.Duration(uint8(c.peerAckDelayExponent))
 	}
 	c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss)
+	if space == appDataSpace {
+		c.keysAppData.handleAckFor(largest)
+	}
 	return n
 }
 
diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go
index 58a3df1..63f65b5 100644
--- a/internal/quic/conn_send.go
+++ b/internal/quic/conn_send.go
@@ -128,7 +128,7 @@
 			if logPackets {
 				logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload())
 			}
-			if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, c.keysAppData.w); sent != nil {
+			if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil {
 				c.loss.packetSent(now, appDataSpace, sent)
 			}
 		}
@@ -197,16 +197,23 @@
 			// All frames other than ACK and PADDING are ack-eliciting,
 			// so if the packet is ack-eliciting we've added additional
 			// frames to it.
-			if shouldSendAck || c.w.sent.ackEliciting {
-				// Either we are willing to send an ACK-only packet,
-				// or we've added additional frames.
-				c.acks[space].sentAck()
-			} else {
+			if !shouldSendAck && !c.w.sent.ackEliciting {
 				// There's nothing in this packet but ACK frames, and
 				// we don't want to send an ACK-only packet at this time.
 				// Abandoning the packet means we wrote an ACK frame for
 				// nothing, but constructing the frame is cheap.
 				c.w.abandonPacket()
+				return
+			}
+			// Either we are willing to send an ACK-only packet,
+			// or we've added additional frames.
+			c.acks[space].sentAck()
+			if !c.w.sent.ackEliciting && c.keysAppData.needAckEliciting() {
+				// The peer has initiated a key update.
+				// We haven't sent them any packets yet in the new phase.
+				// Make this an ack-eliciting packet.
+				// Their ack of this packet will complete the key update.
+				c.w.appendPingFrame()
 			}
 		}()
 	}
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index 3fef62d..76774cc 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -76,12 +76,14 @@
 }
 
 type testPacket struct {
-	ptype     packetType
-	version   uint32
-	num       packetNumber
-	dstConnID []byte
-	srcConnID []byte
-	frames    []debugFrame
+	ptype       packetType
+	version     uint32
+	num         packetNumber
+	keyPhaseBit bool
+	keyNumber   int
+	dstConnID   []byte
+	srcConnID   []byte
+	frames      []debugFrame
 }
 
 func (p testPacket) String() string {
@@ -102,6 +104,9 @@
 	return b.String()
 }
 
+// maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test.
+const maxTestKeyPhases = 3
+
 // A testConn is a Conn whose external interactions (sending and receiving packets,
 // setting timers) can be manipulated in tests.
 type testConn struct {
@@ -122,9 +127,10 @@
 	// the Initial packet.
 	keysInitial   fixedKeyPair
 	keysHandshake fixedKeyPair
-	keysAppData   fixedKeyPair
-	rsecrets      [numberSpaceCount]testKeySecret
-	wsecrets      [numberSpaceCount]testKeySecret
+	rkeyAppData   test1RTTKeys
+	wkeyAppData   test1RTTKeys
+	rsecrets      [numberSpaceCount]keySecret
+	wsecrets      [numberSpaceCount]keySecret
 
 	// testConn uses a test hook to snoop on the conn's TLS events.
 	// CRYPTO data produced by the conn's QUICConn is placed in
@@ -156,10 +162,19 @@
 	// Frame types to ignore in tests.
 	ignoreFrames map[byte]bool
 
+	// Values to set in packets sent to the conn.
+	sendKeyNumber   int
+	sendKeyPhaseBit bool
+
 	asyncTestState
 }
 
-type testKeySecret struct {
+type test1RTTKeys struct {
+	hdr headerKey
+	pkt [maxTestKeyPhases]packetKey
+}
+
+type keySecret struct {
 	suite  uint16
 	secret []byte
 }
@@ -333,12 +348,20 @@
 	}
 	tc.t.Logf("%v datagram%v", text, pad)
 	for _, p := range d.packets {
+		var s string
 		switch p.ptype {
 		case packetType1RTT:
-			tc.t.Logf("  %v pnum=%v", p.ptype, p.num)
+			s = fmt.Sprintf("  %v pnum=%v", p.ptype, p.num)
 		default:
-			tc.t.Logf("  %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
+			s = fmt.Sprintf("  %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
 		}
+		if p.keyPhaseBit {
+			s += fmt.Sprintf(" KeyPhase")
+		}
+		if p.keyNumber != 0 {
+			s += fmt.Sprintf(" keynum=%v", p.keyNumber)
+		}
+		tc.t.Log(s)
 		for _, f := range p.frames {
 			tc.t.Logf("    %v", f)
 		}
@@ -381,12 +404,14 @@
 	}
 	d := &testDatagram{
 		packets: []*testPacket{{
-			ptype:     ptype,
-			num:       tc.peerNextPacketNum[space],
-			frames:    frames,
-			version:   1,
-			dstConnID: dstConnID,
-			srcConnID: tc.peerConnID,
+			ptype:       ptype,
+			num:         tc.peerNextPacketNum[space],
+			keyNumber:   tc.sendKeyNumber,
+			keyPhaseBit: tc.sendKeyPhaseBit,
+			frames:      frames,
+			version:     1,
+			dstConnID:   dstConnID,
+			srcConnID:   tc.peerConnID,
 		}},
 	}
 	if ptype == packetTypeInitial && tc.conn.side == serverSide {
@@ -580,6 +605,22 @@
 	}
 }
 
+// wantFrameType indicates that we expect the Conn to send a frame,
+// although we don't care about the contents.
+func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
+	tc.t.Helper()
+	got, gotType := tc.readFrame()
+	if got == nil {
+		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
+	}
+	if gotType != wantType {
+		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame:  %v", expectation, gotType, wantType, got)
+	}
+	if reflect.TypeOf(got) != reflect.TypeOf(want) {
+		tc.t.Fatalf("%v:\ngot frame:  %v\nwant frame of type: %v", expectation, got, want)
+	}
+}
+
 // wantIdle indicates that we expect the Conn to not send any more frames.
 func (tc *testConn) wantIdle(expectation string) {
 	tc.t.Helper()
@@ -615,17 +656,17 @@
 	}
 	w.appendPaddingTo(pad)
 	if p.ptype != packetType1RTT {
-		var k fixedKeyPair
+		var k fixedKeys
 		switch p.ptype {
 		case packetTypeInitial:
-			k = tc.keysInitial
+			k = tc.keysInitial.w
 		case packetTypeHandshake:
-			k = tc.keysHandshake
+			k = tc.keysHandshake.w
 		}
-		if !k.canWrite() {
+		if !k.isSet() {
 			tc.t.Fatalf("sending %v packet with no write key", p.ptype)
 		}
-		w.finishProtectedLongHeaderPacket(pnumMaxAcked, k.w, longPacket{
+		w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{
 			ptype:     p.ptype,
 			version:   p.version,
 			num:       p.num,
@@ -633,10 +674,24 @@
 			srcConnID: p.srcConnID,
 		})
 	} else {
-		if !tc.keysAppData.canWrite() {
-			tc.t.Fatalf("sending %v packet with no write key", p.ptype)
+		if !tc.wkeyAppData.hdr.isSet() {
+			tc.t.Fatalf("sending 1-RTT packet with no write key")
 		}
-		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.keysAppData.w)
+		// Somewhat hackish: Generate a temporary updatingKeyPair that will
+		// always use our desired key phase.
+		k := &updatingKeyPair{
+			w: updatingKeys{
+				hdr: tc.wkeyAppData.hdr,
+				pkt: [2]packetKey{
+					tc.wkeyAppData.pkt[p.keyNumber],
+					tc.wkeyAppData.pkt[p.keyNumber],
+				},
+			},
+		}
+		if p.keyPhaseBit {
+			k.phase |= keyPhaseBit
+		}
+		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k)
 	}
 	return w.datagram()
 }
@@ -682,25 +737,45 @@
 			})
 			buf = buf[n:]
 		} else {
-			if !tc.keysAppData.canRead() {
+			if !tc.rkeyAppData.hdr.isSet() {
 				tc.t.Fatalf("reading 1-RTT packet with no read key")
 			}
 			var pnumMax packetNumber // TODO: Track packet numbers.
-			p, n := parse1RTTPacket(buf, tc.keysAppData.r, len(tc.peerConnID), pnumMax)
-			if n < 0 {
-				tc.t.Fatalf("packet parse error")
+			pnumOff := 1 + len(tc.peerConnID)
+			// Try unprotecting the packet with the first maxTestKeyPhases keys.
+			var phase int
+			var pnum packetNumber
+			var hdr []byte
+			var pay []byte
+			var err error
+			for phase = 0; phase < maxTestKeyPhases; phase++ {
+				b := append([]byte{}, buf...)
+				hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
+				if err != nil {
+					tc.t.Fatalf("1-RTT packet header parse error")
+				}
+				k := tc.rkeyAppData.pkt[phase]
+				pay, err = k.unprotect(hdr, pay, pnum)
+				if err == nil {
+					break
+				}
 			}
-			frames, err := tc.parseTestFrames(p.payload)
+			if err != nil {
+				tc.t.Fatalf("1-RTT packet payload parse error")
+			}
+			frames, err := tc.parseTestFrames(pay)
 			if err != nil {
 				tc.t.Fatal(err)
 			}
 			d.packets = append(d.packets, &testPacket{
-				ptype:     packetType1RTT,
-				num:       p.num,
-				dstConnID: buf[1:][:len(tc.peerConnID)],
-				frames:    frames,
+				ptype:       packetType1RTT,
+				num:         pnum,
+				dstConnID:   hdr[1:][:len(tc.peerConnID)],
+				keyPhaseBit: hdr[0]&keyPhaseBit != 0,
+				keyNumber:   phase,
+				frames:      frames,
 			})
-			buf = buf[n:]
+			buf = buf[len(buf):]
 		}
 	}
 	// This is rather hackish: If the last frame in the last packet
@@ -766,7 +841,7 @@
 // and verify that both sides of the connection are getting
 // matching keys.
 func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
-	checkKey := func(typ string, secrets *[numberSpaceCount]testKeySecret, e tls.QUICEvent) {
+	checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
 		var space numberSpace
 		switch {
 		case e.Level == tls.QUICEncryptionLevelHandshake:
@@ -781,25 +856,32 @@
 			secrets[space].suite = e.Suite
 			secrets[space].secret = append([]byte{}, e.Data...)
 		} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
-			tc.t.Errorf("%v key mismatch for level %v", typ, e.Level)
+			tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level)
+		}
+	}
+	setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
+		k.hdr.init(suite, secret)
+		for i := 0; i < len(k.pkt); i++ {
+			k.pkt[i].init(suite, secret)
+			secret = updateSecret(suite, secret)
 		}
 	}
 	switch e.Kind {
 	case tls.QUICSetReadSecret:
-		checkKey("read", &tc.rsecrets, e)
+		checkKey("write", &tc.wsecrets, e)
 		switch e.Level {
 		case tls.QUICEncryptionLevelHandshake:
 			tc.keysHandshake.w.init(e.Suite, e.Data)
 		case tls.QUICEncryptionLevelApplication:
-			tc.keysAppData.w.init(e.Suite, e.Data)
+			setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
 		}
 	case tls.QUICSetWriteSecret:
-		checkKey("write", &tc.wsecrets, e)
+		checkKey("read", &tc.rsecrets, e)
 		switch e.Level {
 		case tls.QUICEncryptionLevelHandshake:
 			tc.keysHandshake.r.init(e.Suite, e.Data)
 		case tls.QUICEncryptionLevelApplication:
-			tc.keysAppData.r.init(e.Suite, e.Data)
+			setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
 		}
 	case tls.QUICWriteData:
 		tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
@@ -811,20 +893,20 @@
 		case tls.QUICNoEvent:
 			return
 		case tls.QUICSetReadSecret:
-			checkKey("write", &tc.wsecrets, e)
+			checkKey("write", &tc.rsecrets, e)
 			switch e.Level {
 			case tls.QUICEncryptionLevelHandshake:
 				tc.keysHandshake.r.init(e.Suite, e.Data)
 			case tls.QUICEncryptionLevelApplication:
-				tc.keysAppData.r.init(e.Suite, e.Data)
+				setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
 			}
 		case tls.QUICSetWriteSecret:
-			checkKey("read", &tc.rsecrets, e)
+			checkKey("read", &tc.wsecrets, e)
 			switch e.Level {
 			case tls.QUICEncryptionLevelHandshake:
 				tc.keysHandshake.w.init(e.Suite, e.Data)
 			case tls.QUICEncryptionLevelApplication:
-				tc.keysAppData.w.init(e.Suite, e.Data)
+				setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
 			}
 		case tls.QUICWriteData:
 			tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
diff --git a/internal/quic/key_update_test.go b/internal/quic/key_update_test.go
new file mode 100644
index 0000000..6b6bb79
--- /dev/null
+++ b/internal/quic/key_update_test.go
@@ -0,0 +1,163 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"testing"
+)
+
+func TestKeyUpdatePeerUpdates(t *testing.T) {
+	tc := newTestConn(t, serverSide)
+	tc.handshake()
+	tc.ignoreFrames = nil // ignore nothing
+
+	// Peer initiates a key update.
+	tc.sendKeyNumber = 1
+	tc.sendKeyPhaseBit = true
+	tc.writeFrames(packetType1RTT, debugFramePing{})
+
+	// We update to the new key.
+	tc.advanceToTimer()
+	tc.wantFrameType("conn ACKs last packet",
+		packetType1RTT, debugFrameAck{})
+	tc.wantFrame("first packet after a key update is always ack-eliciting",
+		packetType1RTT, debugFramePing{})
+	if got, want := tc.lastPacket.keyNumber, 1; got != want {
+		t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+	}
+	if !tc.lastPacket.keyPhaseBit {
+		t.Errorf("after key rotation, conn failed to change Key Phase bit")
+	}
+	tc.wantIdle("conn has nothing to send")
+
+	// Peer's ACK of a packet we sent in the new phase completes the update.
+	tc.writeAckForAll()
+
+	// Peer initiates a second key update.
+	tc.sendKeyNumber = 2
+	tc.sendKeyPhaseBit = false
+	tc.writeFrames(packetType1RTT, debugFramePing{})
+
+	// We update to the new key.
+	tc.advanceToTimer()
+	tc.wantFrameType("conn ACKs last packet",
+		packetType1RTT, debugFrameAck{})
+	tc.wantFrame("first packet after a key update is always ack-eliciting",
+		packetType1RTT, debugFramePing{})
+	if got, want := tc.lastPacket.keyNumber, 2; got != want {
+		t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+	}
+	if tc.lastPacket.keyPhaseBit {
+		t.Errorf("after second key rotation, conn failed to change Key Phase bit")
+	}
+	tc.wantIdle("conn has nothing to send")
+}
+
+func TestKeyUpdateAcceptPreviousPhaseKeys(t *testing.T) {
+	// "An endpoint SHOULD retain old keys for some time after
+	// unprotecting a packet sent using the new keys."
+	// https://www.rfc-editor.org/rfc/rfc9001#section-6.1-8
+	tc := newTestConn(t, serverSide)
+	tc.handshake()
+	tc.ignoreFrames = nil // ignore nothing
+
+	// Peer initiates a key update, skipping one packet number.
+	pnum0 := tc.peerNextPacketNum[appDataSpace]
+	tc.peerNextPacketNum[appDataSpace]++
+	tc.sendKeyNumber = 1
+	tc.sendKeyPhaseBit = true
+	tc.writeFrames(packetType1RTT, debugFramePing{})
+
+	// We update to the new key.
+	// This ACK is not delayed, because we've skipped a packet number.
+	tc.wantFrame("conn ACKs last packet",
+		packetType1RTT, debugFrameAck{
+			ranges: []i64range[packetNumber]{
+				{0, pnum0},
+				{pnum0 + 1, pnum0 + 2},
+			},
+		})
+	tc.wantFrame("first packet after a key update is always ack-eliciting",
+		packetType1RTT, debugFramePing{})
+	if got, want := tc.lastPacket.keyNumber, 1; got != want {
+		t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+	}
+	if !tc.lastPacket.keyPhaseBit {
+		t.Errorf("after key rotation, conn failed to change Key Phase bit")
+	}
+	tc.wantIdle("conn has nothing to send")
+
+	// We receive the previously-skipped packet in the earlier key phase.
+	tc.peerNextPacketNum[appDataSpace] = pnum0
+	tc.sendKeyNumber = 0
+	tc.sendKeyPhaseBit = false
+	tc.writeFrames(packetType1RTT, debugFramePing{})
+
+	// We ack the reordered packet immediately, still in the new key phase.
+	tc.wantFrame("conn ACKs reordered packet",
+		packetType1RTT, debugFrameAck{
+			ranges: []i64range[packetNumber]{
+				{0, pnum0 + 2},
+			},
+		})
+	tc.wantIdle("packet is not ack-eliciting")
+	if got, want := tc.lastPacket.keyNumber, 1; got != want {
+		t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+	}
+	if !tc.lastPacket.keyPhaseBit {
+		t.Errorf("after key rotation, conn failed to change Key Phase bit")
+	}
+}
+
+func TestKeyUpdateRejectPacketFromPriorPhase(t *testing.T) {
+	// "Packets with higher packet numbers MUST be protected with either
+	// the same or newer packet protection keys than packets with lower packet numbers."
+	// https://www.rfc-editor.org/rfc/rfc9001#section-6.4-2
+	tc := newTestConn(t, serverSide)
+	tc.handshake()
+	tc.ignoreFrames = nil // ignore nothing
+
+	// Peer initiates a key update.
+	tc.sendKeyNumber = 1
+	tc.sendKeyPhaseBit = true
+	tc.writeFrames(packetType1RTT, debugFramePing{})
+
+	// We update to the new key.
+	tc.advanceToTimer()
+	tc.wantFrameType("conn ACKs last packet",
+		packetType1RTT, debugFrameAck{})
+	tc.wantFrame("first packet after a key update is always ack-eliciting",
+		packetType1RTT, debugFramePing{})
+	if got, want := tc.lastPacket.keyNumber, 1; got != want {
+		t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want)
+	}
+	if !tc.lastPacket.keyPhaseBit {
+		t.Errorf("after key rotation, conn failed to change Key Phase bit")
+	}
+	tc.wantIdle("conn has nothing to send")
+
+	// Peer sends an ack-eliciting packet using the prior phase keys.
+	// We fail to unprotect the packet and ignore it.
+	skipped := tc.peerNextPacketNum[appDataSpace]
+	tc.sendKeyNumber = 0
+	tc.sendKeyPhaseBit = false
+	tc.writeFrames(packetType1RTT, debugFramePing{})
+
+	// Peer sends an ack-eliciting packet using the current phase keys.
+	tc.sendKeyNumber = 1
+	tc.sendKeyPhaseBit = true
+	tc.writeFrames(packetType1RTT, debugFramePing{})
+
+	// We ack the peer's packets, not including the one sent with the wrong keys.
+	tc.wantFrame("conn ACKs packets, not including packet sent with wrong keys",
+		packetType1RTT, debugFrameAck{
+			ranges: []i64range[packetNumber]{
+				{0, skipped},
+				{skipped + 1, skipped + 2},
+			},
+		})
+}
diff --git a/internal/quic/packet.go b/internal/quic/packet.go
index a1bcead..8242bd0 100644
--- a/internal/quic/packet.go
+++ b/internal/quic/packet.go
@@ -45,6 +45,7 @@
 	fixedBit         = 0x40 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.4.1
 	reservedLongBits = 0x0c // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1
 	reserved1RTTBits = 0x18 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1
+	keyPhaseBit      = 0x04 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.10.1
 )
 
 // Long Packet Type bits.
diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go
index 7f0846f..c8b1f9b 100644
--- a/internal/quic/packet_codec_test.go
+++ b/internal/quic/packet_codec_test.go
@@ -146,10 +146,13 @@
 }
 
 func TestRoundtripEncodeShortPacket(t *testing.T) {
-	var aes128Keys, aes256Keys, chachaKeys fixedKeys
-	aes128Keys.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
-	aes256Keys.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
-	chachaKeys.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
+	var aes128Keys, aes256Keys, chachaKeys updatingKeyPair
+	aes128Keys.r.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret"))
+	aes256Keys.r.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret"))
+	chachaKeys.r.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"))
+	aes128Keys.w = aes128Keys.r
+	aes256Keys.w = aes256Keys.r
+	chachaKeys.w = chachaKeys.r
 	connID := make([]byte, connIDLen)
 	for i := range connID {
 		connID[i] = byte(i)
@@ -158,7 +161,7 @@
 		desc    string
 		num     packetNumber
 		payload []byte
-		k       fixedKeys
+		k       updatingKeyPair
 	}{{
 		desc:    "1-byte number, AES128",
 		num:     0, // 1-byte encoding,
@@ -185,9 +188,9 @@
 			w.reset(1200)
 			w.start1RTTPacket(test.num, 0, connID)
 			w.b = append(w.b, test.payload...)
-			w.finish1RTTPacket(test.num, 0, connID, test.k)
+			w.finish1RTTPacket(test.num, 0, connID, &test.k)
 			pkt := w.datagram()
-			p, n := parse1RTTPacket(pkt, test.k, connIDLen, 0)
+			p, n := parse1RTTPacket(pkt, &test.k, connIDLen, 0)
 			if n != len(pkt) {
 				t.Errorf("parse1RTTPacket: n=%v, want %v", n, len(pkt))
 			}
diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go
index 458cd3a..8bb3cae 100644
--- a/internal/quic/packet_parser.go
+++ b/internal/quic/packet_parser.go
@@ -143,12 +143,13 @@
 //
 // On input, pkt contains a short header packet, k the decryption keys for the packet,
 // and pnumMax the largest packet number seen in the number space of this packet.
-func parse1RTTPacket(pkt []byte, k fixedKeys, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, n int) {
-	var err error
-	p.payload, p.num, err = k.unprotect(pkt, 1+dstConnIDLen, pnumMax)
+func parse1RTTPacket(pkt []byte, k *updatingKeyPair, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, n int) {
+	pay, pnum, err := k.unprotect(pkt, 1+dstConnIDLen, pnumMax)
 	if err != nil {
 		return shortPacket{}, -1
 	}
+	p.num = pnum
+	p.payload = pay
 	return p, len(pkt)
 }
 
diff --git a/internal/quic/packet_protection.go b/internal/quic/packet_protection.go
index 2f9b9ce..aab1eaf 100644
--- a/internal/quic/packet_protection.go
+++ b/internal/quic/packet_protection.go
@@ -37,6 +37,10 @@
 	hp headerProtection
 }
 
+func (k headerKey) isSet() bool {
+	return k.hp != nil
+}
+
 func (k *headerKey) init(suite uint16, secret []byte) {
 	h, keySize := hashForSuite(suite)
 	hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keySize)
@@ -275,6 +279,148 @@
 	return k.w.isSet()
 }
 
+// An updatingKeys is a header protection key and updatable packet protection key.
+// updatingKeys are used for 1-RTT keys, where the packet protection key changes
+// over the lifetime of a connection.
+// https://www.rfc-editor.org/rfc/rfc9001#section-6
+type updatingKeys struct {
+	suite      uint16
+	hdr        headerKey
+	pkt        [2]packetKey // current, next
+	nextSecret []byte       // secret used to generate pkt[1]
+}
+
+func (k *updatingKeys) init(suite uint16, secret []byte) {
+	k.suite = suite
+	k.hdr.init(suite, secret)
+	// Initialize pkt[1] with secret_0, and then call update to generate secret_1.
+	k.pkt[1].init(suite, secret)
+	k.nextSecret = secret
+	k.update()
+}
+
+// update performs a key update.
+// The current key in pkt[0] is discarded.
+// The next key in pkt[1] becomes the current key.
+// A new next key is generated in pkt[1].
+func (k *updatingKeys) update() {
+	k.nextSecret = updateSecret(k.suite, k.nextSecret)
+	k.pkt[0] = k.pkt[1]
+	k.pkt[1].init(k.suite, k.nextSecret)
+}
+
+func updateSecret(suite uint16, secret []byte) (nextSecret []byte) {
+	h, _ := hashForSuite(suite)
+	return hkdfExpandLabel(h.New, secret, "quic ku", nil, len(secret))
+}
+
+// An updatingKeyPair is a read/write pair of updating keys.
+//
+// We keep two keys (current and next) in both read and write directions.
+// When an incoming packet's phase matches the current phase bit,
+// we unprotect it using the current keys; otherwise we use the next keys.
+//
+// When updating=false, outgoing packets are protected using the current phase.
+//
+// An update is initiated and updating is set to true when:
+//   - we decide to initiate a key update; or
+//   - we successfully unprotect a packet using the next keys,
+//     indicating the peer has initiated a key update.
+//
+// When updating=true, outgoing packets are protected using the next phase.
+// We do not change the current phase bit or generate new keys yet.
+//
+// The update concludes when we receive an ACK frame for a packet sent
+// with the next keys. At this time, we set updating to false, flip the
+// phase bit, and update the keys. This permits us to handle up to 1-RTT
+// of reordered packets before discarding the previous phase's keys after
+// an update.
+type updatingKeyPair struct {
+	phase       uint8 // current key phase (r.pkt[0], w.pkt[0])
+	updating    bool
+	minSent     packetNumber // min packet number sent since entering the updating state
+	minReceived packetNumber // min packet number received in the next phase
+	r, w        updatingKeys
+}
+
+func (k *updatingKeyPair) canRead() bool {
+	return k.r.hdr.hp != nil
+}
+
+func (k *updatingKeyPair) canWrite() bool {
+	return k.w.hdr.hp != nil
+}
+
+// handleAckFor finishes a key update after receiving an ACK for a packet in the next phase.
+func (k *updatingKeyPair) handleAckFor(pnum packetNumber) {
+	if k.updating && pnum >= k.minSent {
+		k.updating = false
+		k.phase ^= keyPhaseBit
+		k.r.update()
+		k.w.update()
+	}
+}
+
+// needAckEliciting reports whether we should send an ack-eliciting packet in the next phase.
+// The first packet sent in a phase is ack-eliciting, since the peer must acknowledge a
+// packet in the new phase for us to finish the update.
+func (k *updatingKeyPair) needAckEliciting() bool {
+	return k.updating && k.minSent == maxPacketNumber
+}
+
+// protect applies packet protection to a packet.
+// Parameters and returns are as for fixedKeyPair.protect.
+func (k *updatingKeyPair) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
+	// TODO: Initiate key updates as required to avoid the AEAD usage limit.
+	// https://www.rfc-editor.org/rfc/rfc9001#section-6.6
+	var pkt []byte
+	if k.updating {
+		hdr[0] |= k.phase ^ keyPhaseBit
+		pkt = k.w.pkt[1].protect(hdr, pay, pnum)
+		k.minSent = min(pnum, k.minSent)
+	} else {
+		hdr[0] |= k.phase
+		pkt = k.w.pkt[0].protect(hdr, pay, pnum)
+	}
+	k.w.hdr.protect(pkt, pnumOff)
+	return pkt
+}
+
+// unprotect removes packet protection from a packet.
+// Parameters and returns are as for fixedKeyPair.unprotect.
+func (k *updatingKeyPair) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, pnum packetNumber, err error) {
+	hdr, pay, pnum, err := k.r.hdr.unprotect(pkt, pnumOff, pnumMax)
+	if err != nil {
+		return nil, 0, err
+	}
+	// To avoid timing signals that might indicate the key phase bit is invalid,
+	// we always attempt to unprotect the packet with one key.
+	//
+	// If the key phase bit matches and the packet number doesn't come after
+	// the start of an in-progress update, use the current phase.
+	// Otherwise, use the next phase.
+	if hdr[0]&keyPhaseBit == k.phase && (!k.updating || pnum < k.minReceived) {
+		pay, err = k.r.pkt[0].unprotect(hdr, pay, pnum)
+		if err != nil {
+			return nil, 0, err
+		}
+	} else {
+		pay, err = k.r.pkt[1].unprotect(hdr, pay, pnum)
+		if err != nil {
+			return nil, 0, err
+		}
+		if !k.updating {
+			// The peer has initiated a key update.
+			k.updating = true
+			k.minSent = maxPacketNumber
+			k.minReceived = pnum
+		} else {
+			k.minReceived = min(pnum, k.minReceived)
+		}
+	}
+	return pay, pnum, nil
+}
+
 // https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2
 var initialSalt = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
 
diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go
index 2009895..0c2b2ee 100644
--- a/internal/quic/packet_writer.go
+++ b/internal/quic/packet_writer.go
@@ -163,14 +163,13 @@
 // finish1RTTPacket finishes writing a 1-RTT packet,
 // canceling the packet if it contains no payload.
 // It returns a sentPacket describing the packet, or nil if no packet was written.
-func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k fixedKeys) *sentPacket {
+func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k *updatingKeyPair) *sentPacket {
 	if len(w.b) == w.payOff {
 		// The payload is empty, so just abandon the packet.
 		w.b = w.b[:w.pktOff]
 		return nil
 	}
 	// TODO: Spin
-	// TODO: Key phase
 	pnumLen := packetNumberLength(pnum, pnumMaxAcked)
 	hdr := w.b[:w.pktOff]
 	hdr = append(hdr, 0x40|byte(pnumLen-1))