quic: packet protection

Encrypt and decrypt QUIC packets according to RFC 9001.

For golang/go#58547

Change-Id: Ib7f824cf08f8520400bd38d3b3ab89e8a968114e
Reviewed-on: https://go-review.googlesource.com/c/net/+/475438
Reviewed-by: Roland Shoemaker <roland@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/go.mod b/go.mod
index f661af4..b1d3e54 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@
 go 1.17
 
 require (
+	golang.org/x/crypto v0.9.0
 	golang.org/x/sys v0.8.0
 	golang.org/x/term v0.8.0
 	golang.org/x/text v0.9.0
diff --git a/go.sum b/go.sum
index 6408b66..af21d7c 100644
--- a/go.sum
+++ b/go.sum
@@ -1,12 +1,15 @@
 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
+golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
+golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
diff --git a/internal/quic/packet_protection.go b/internal/quic/packet_protection.go
new file mode 100644
index 0000000..7d96d69
--- /dev/null
+++ b/internal/quic/packet_protection.go
@@ -0,0 +1,266 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+import (
+	"crypto"
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/sha256"
+	"crypto/tls"
+	"errors"
+	"fmt"
+	"hash"
+
+	"golang.org/x/crypto/chacha20"
+	"golang.org/x/crypto/chacha20poly1305"
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/hkdf"
+)
+
+var errInvalidPacket = errors.New("quic: invalid packet")
+
+// keys holds the cryptographic material used to protect packets
+// at an encryption level and direction. (e.g., Initial client keys.)
+//
+// keys are not safe for concurrent use.
+type keys struct {
+	// AEAD function used for packet protection.
+	aead cipher.AEAD
+
+	// The header_protection function as defined in:
+	// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1
+	//
+	// This function takes a sample of the packet ciphertext
+	// and returns a 5-byte mask which will be applied to the
+	// protected portions of the packet header.
+	headerProtection func(sample []byte) (mask [5]byte)
+
+	// IV used to construct the AEAD nonce.
+	iv []byte
+}
+
+// newKeys creates keys for a given cipher suite and secret.
+//
+// It returns an error if the suite is unknown.
+func newKeys(suite uint16, secret []byte) (keys, error) {
+	switch suite {
+	case tls.TLS_AES_128_GCM_SHA256:
+		return newAESKeys(secret, crypto.SHA256, 128/8), nil
+	case tls.TLS_AES_256_GCM_SHA384:
+		return newAESKeys(secret, crypto.SHA384, 256/8), nil
+	case tls.TLS_CHACHA20_POLY1305_SHA256:
+		return newChaCha20Keys(secret), nil
+	}
+	return keys{}, fmt.Errorf("unknown cipher suite %x", suite)
+}
+
+func newAESKeys(secret []byte, h crypto.Hash, keyBytes int) keys {
+	// https://www.rfc-editor.org/rfc/rfc9001#section-5.1
+	key := hkdfExpandLabel(h.New, secret, "quic key", nil, keyBytes)
+	c, err := aes.NewCipher(key)
+	if err != nil {
+		panic(err)
+	}
+	aead, err := cipher.NewGCM(c)
+	if err != nil {
+		panic(err)
+	}
+	iv := hkdfExpandLabel(h.New, secret, "quic iv", nil, aead.NonceSize())
+	// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3
+	hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keyBytes)
+	hp, err := aes.NewCipher(hpKey)
+	if err != nil {
+		panic(err)
+	}
+	var scratch [aes.BlockSize]byte
+	headerProtection := func(sample []byte) (mask [5]byte) {
+		hp.Encrypt(scratch[:], sample)
+		copy(mask[:], scratch[:])
+		return mask
+	}
+	return keys{
+		aead:             aead,
+		iv:               iv,
+		headerProtection: headerProtection,
+	}
+}
+
+func newChaCha20Keys(secret []byte) keys {
+	// https://www.rfc-editor.org/rfc/rfc9001#section-5.1
+	key := hkdfExpandLabel(sha256.New, secret, "quic key", nil, chacha20poly1305.KeySize)
+	aead, err := chacha20poly1305.New(key)
+	if err != nil {
+		panic(err)
+	}
+	iv := hkdfExpandLabel(sha256.New, secret, "quic iv", nil, aead.NonceSize())
+	// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4
+	hpKey := hkdfExpandLabel(sha256.New, secret, "quic hp", nil, chacha20.KeySize)
+	headerProtection := func(sample []byte) [5]byte {
+		counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0])
+		nonce := sample[4:16]
+		c, err := chacha20.NewUnauthenticatedCipher(hpKey, nonce)
+		if err != nil {
+			panic(err)
+		}
+		c.SetCounter(counter)
+		var mask [5]byte
+		c.XORKeyStream(mask[:], mask[:])
+		return mask
+	}
+	return keys{
+		aead:             aead,
+		iv:               iv,
+		headerProtection: headerProtection,
+	}
+}
+
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2
+var initialSalt = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
+
+// initialKeys returns the keys used to protect Initial packets.
+//
+// The Initial packet keys are derived from the Destination Connection ID
+// field in the client's first Initial packet.
+//
+// https://www.rfc-editor.org/rfc/rfc9001#section-5.2
+func initialKeys(cid []byte) (clientKeys, serverKeys keys) {
+	initialSecret := hkdf.Extract(sha256.New, cid, initialSalt)
+	clientInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size)
+	clientKeys, err := newKeys(tls.TLS_AES_128_GCM_SHA256, clientInitialSecret)
+	if err != nil {
+		panic(err)
+	}
+
+	serverInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size)
+	serverKeys, err = newKeys(tls.TLS_AES_128_GCM_SHA256, serverInitialSecret)
+	if err != nil {
+		panic(err)
+	}
+
+	return clientKeys, serverKeys
+}
+
+const headerProtectionSampleSize = 16
+
+// aeadOverhead is the difference in size between the AEAD output and input.
+// All cipher suites defined for use with QUIC have 16 bytes of overhead.
+const aeadOverhead = 16
+
+// xorIV xors the packet protection IV with the packet number.
+func (k keys) xorIV(pnum packetNumber) {
+	k.iv[len(k.iv)-8] ^= uint8(pnum >> 56)
+	k.iv[len(k.iv)-7] ^= uint8(pnum >> 48)
+	k.iv[len(k.iv)-6] ^= uint8(pnum >> 40)
+	k.iv[len(k.iv)-5] ^= uint8(pnum >> 32)
+	k.iv[len(k.iv)-4] ^= uint8(pnum >> 24)
+	k.iv[len(k.iv)-3] ^= uint8(pnum >> 16)
+	k.iv[len(k.iv)-2] ^= uint8(pnum >> 8)
+	k.iv[len(k.iv)-1] ^= uint8(pnum)
+}
+
+// initialized returns true if valid keys are available.
+func (k keys) initialized() bool {
+	return k.aead != nil
+}
+
+// discard discards the keys (in the sense that we won't use them any more,
+// not that the keys are securely erased).
+//
+// https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9
+func (k *keys) discard() {
+	*k = keys{}
+}
+
+// protect applies packet protection to a packet.
+//
+// On input, hdr contains the packet header, pay the unencrypted payload,
+// pnumOff the offset of the packet number in the header, and pnum the untruncated
+// packet number.
+//
+// protect returns the result of appending the encrypted payload to hdr and
+// applying header protection.
+func (k keys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte {
+	k.xorIV(pnum)
+	hdr = k.aead.Seal(hdr, k.iv, pay, hdr)
+	k.xorIV(pnum)
+
+	// Apply header protection.
+	pnumSize := int(hdr[0]&0x03) + 1
+	sample := hdr[pnumOff+4:][:headerProtectionSampleSize]
+	mask := k.headerProtection(sample)
+	if isLongHeader(hdr[0]) {
+		hdr[0] ^= mask[0] & 0x0f
+	} else {
+		hdr[0] ^= mask[0] & 0x1f
+	}
+	for i := 0; i < pnumSize; i++ {
+		hdr[pnumOff+i] ^= mask[1+i]
+	}
+
+	return hdr
+}
+
+// unprotect removes packet protection from a packet.
+//
+// On input, pkt contains the full protected packet, pnumOff the offset of
+// the packet number in the header, and pnumMax the largest packet number
+// seen in the number space of this packet.
+//
+// unprotect removes header protection from the header in pkt, and returns
+// the unprotected payload and packet number.
+func (k keys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) {
+	if len(pkt) < pnumOff+4+headerProtectionSampleSize {
+		fmt.Println("too short")
+		return nil, 0, errInvalidPacket
+	}
+	numpay := pkt[pnumOff:]
+	sample := numpay[4:][:headerProtectionSampleSize]
+	mask := k.headerProtection(sample)
+	if isLongHeader(pkt[0]) {
+		pkt[0] ^= mask[0] & 0x0f
+	} else {
+		pkt[0] ^= mask[0] & 0x1f
+	}
+	pnumLen := int(pkt[0]&0x03) + 1
+	pnum := packetNumber(0)
+	for i := 0; i < pnumLen; i++ {
+		numpay[i] ^= mask[1+i]
+		pnum = (pnum << 8) | packetNumber(numpay[i])
+	}
+	pnum = decodePacketNumber(pnumMax, pnum, pnumLen)
+
+	hdr := pkt[:pnumOff+pnumLen]
+	pay = numpay[pnumLen:]
+	k.xorIV(pnum)
+	pay, err = k.aead.Open(pay[:0], k.iv, pay, hdr)
+	k.xorIV(pnum)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	return pay, pnum, nil
+}
+
+// hdkfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
+//
+// Copied from crypto/tls/key_schedule.go.
+func hkdfExpandLabel(hash func() hash.Hash, secret []byte, label string, context []byte, length int) []byte {
+	var hkdfLabel cryptobyte.Builder
+	hkdfLabel.AddUint16(uint16(length))
+	hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+		b.AddBytes([]byte("tls13 "))
+		b.AddBytes([]byte(label))
+	})
+	hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+		b.AddBytes(context)
+	})
+	out := make([]byte, length)
+	n, err := hkdf.Expand(hash, secret, hkdfLabel.BytesOrPanic()).Read(out)
+	if err != nil || n != length {
+		panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
+	}
+	return out
+}
diff --git a/internal/quic/packet_protection_test.go b/internal/quic/packet_protection_test.go
new file mode 100644
index 0000000..f1d353d
--- /dev/null
+++ b/internal/quic/packet_protection_test.go
@@ -0,0 +1,162 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+import (
+	"bytes"
+	"crypto/tls"
+	"testing"
+)
+
+func TestPacketProtection(t *testing.T) {
+	// Test cases from:
+	// https://www.rfc-editor.org/rfc/rfc9001#section-appendix.a
+	cid := unhex(`8394c8f03e515708`)
+	initialClientKeys, initialServerKeys := initialKeys(cid)
+	for _, test := range []struct {
+		name string
+		k    keys
+		pnum packetNumber
+		hdr  []byte
+		pay  []byte
+		prot []byte
+	}{{
+		name: "Client Initial",
+		k:    initialClientKeys,
+		pnum: 2,
+		hdr: unhex(`
+			c300000001088394c8f03e5157080000 449e00000002
+		`),
+		pay: pad(1162, unhex(`
+			060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868
+			04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578
+			616d706c652e636f6dff01000100000a 00080006001d00170018001000070005
+			04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba
+			baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400
+			0d0010000e0403050306030203080408 050806002d00020101001c0002400100
+			3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000
+			75300901100f088394c8f03e51570806 048000ffff
+		`)),
+		prot: unhex(`
+			c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11
+			d242b123dc9bd8bab936b47d92ec356c 0bab7df5976d27cd449f63300099f399
+			1c260ec4c60d17b31f8429157bb35a12 82a643a8d2262cad67500cadb8e7378c
+			8eb7539ec4d4905fed1bee1fc8aafba1 7c750e2c7ace01e6005f80fcb7df6212
+			30c83711b39343fa028cea7f7fb5ff89 eac2308249a02252155e2347b63d58c5
+			457afd84d05dfffdb20392844ae81215 4682e9cf012f9021a6f0be17ddd0c208
+			4dce25ff9b06cde535d0f920a2db1bf3 62c23e596d11a4f5a6cf3948838a3aec
+			4e15daf8500a6ef69ec4e3feb6b1d98e 610ac8b7ec3faf6ad760b7bad1db4ba3
+			485e8a94dc250ae3fdb41ed15fb6a8e5 eba0fc3dd60bc8e30c5c4287e53805db
+			059ae0648db2f64264ed5e39be2e20d8 2df566da8dd5998ccabdae053060ae6c
+			7b4378e846d29f37ed7b4ea9ec5d82e7 961b7f25a9323851f681d582363aa5f8
+			9937f5a67258bf63ad6f1a0b1d96dbd4 faddfcefc5266ba6611722395c906556
+			be52afe3f565636ad1b17d508b73d874 3eeb524be22b3dcbc2c7468d54119c74
+			68449a13d8e3b95811a198f3491de3e7 fe942b330407abf82a4ed7c1b311663a
+			c69890f4157015853d91e923037c227a 33cdd5ec281ca3f79c44546b9d90ca00
+			f064c99e3dd97911d39fe9c5d0b23a22 9a234cb36186c4819e8b9c5927726632
+			291d6a418211cc2962e20fe47feb3edf 330f2c603a9d48c0fcb5699dbfe58964
+			25c5bac4aee82e57a85aaf4e2513e4f0 5796b07ba2ee47d80506f8d2c25e50fd
+			14de71e6c418559302f939b0e1abd576 f279c4b2e0feb85c1f28ff18f58891ff
+			ef132eef2fa09346aee33c28eb130ff2 8f5b766953334113211996d20011a198
+			e3fc433f9f2541010ae17c1bf202580f 6047472fb36857fe843b19f5984009dd
+			c324044e847a4f4a0ab34f719595de37 252d6235365e9b84392b061085349d73
+			203a4a13e96f5432ec0fd4a1ee65accd d5e3904df54c1da510b0ff20dcc0c77f
+			cb2c0e0eb605cb0504db87632cf3d8b4 dae6e705769d1de354270123cb11450e
+			fc60ac47683d7b8d0f811365565fd98c 4c8eb936bcab8d069fc33bd801b03ade
+			a2e1fbc5aa463d08ca19896d2bf59a07 1b851e6c239052172f296bfb5e724047
+			90a2181014f3b94a4e97d117b4381303 68cc39dbb2d198065ae3986547926cd2
+			162f40a29f0c3c8745c0f50fba3852e5 66d44575c29d39a03f0cda721984b6f4
+			40591f355e12d439ff150aab7613499d bd49adabc8676eef023b15b65bfc5ca0
+			6948109f23f350db82123535eb8a7433 bdabcb909271a6ecbcb58b936a88cd4e
+			8f2e6ff5800175f113253d8fa9ca8885 c2f552e657dc603f252e1a8e308f76f0
+			be79e2fb8f5d5fbbe2e30ecadd220723 c8c0aea8078cdfcb3868263ff8f09400
+			54da48781893a7e49ad5aff4af300cd8 04a6b6279ab3ff3afb64491c85194aab
+			760d58a606654f9f4400e8b38591356f bf6425aca26dc85244259ff2b19c41b9
+			f96f3ca9ec1dde434da7d2d392b905dd f3d1f9af93d1af5950bd493f5aa731b4
+			056df31bd267b6b90a079831aaf579be 0a39013137aac6d404f518cfd4684064
+			7e78bfe706ca4cf5e9c5453e9f7cfd2b 8b4c8d169a44e55c88d4a9a7f9474241
+			e221af44860018ab0856972e194cd934
+		`),
+	}, {
+		name: "Server Initial",
+		k:    initialServerKeys,
+		pnum: 1,
+		hdr: unhex(`
+			c1000000010008f067a5502a4262b500 40750001
+		`),
+		pay: unhex(`
+			02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739
+			88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94
+			0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00
+			020304
+		`),
+		prot: unhex(`
+			cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a
+			5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3
+			dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84
+			022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4
+			2158407dd074ee
+		`),
+	}, {
+		name: "ChaCha20_Poly1305 Short Header",
+		k: func() keys {
+			secret := unhex(`
+				9ac312a7f877468ebe69422748ad00a1
+				5443f18203a07d6060f688f30f21632b
+			`)
+			k, err := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, secret)
+			if err != nil {
+				t.Fatal(err)
+			}
+			return k
+		}(),
+		pnum: 654360564,
+		hdr:  unhex(`4200bff4`),
+		pay:  unhex(`01`),
+		prot: unhex(`
+			4cfe4189655e5cd55c41f69080575d79 99c25a5bfb
+		`),
+	}} {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			pnumLen := int(test.hdr[0]&0x03) + 1
+			pnumOff := len(test.hdr) - pnumLen
+
+			b := append([]byte{}, test.hdr...)
+			gotProt := test.k.protect(b, test.pay, pnumOff, test.pnum)
+			if got, want := gotProt, test.prot; !bytes.Equal(got, want) {
+				t.Errorf("Protected payload does not match:")
+				t.Errorf("got:  %x", got)
+				t.Errorf("want: %x", want)
+			}
+
+			pkt := append([]byte{}, test.prot...)
+			gotPay, gotNum, err := test.k.unprotect(pkt, pnumOff, test.pnum-1)
+			if err != nil {
+				t.Fatalf("Unexpected error unprotecting packet: %v", err)
+			}
+			if got, want := pkt[:len(test.hdr)], test.hdr; !bytes.Equal(got, want) {
+				t.Errorf("Unprotected header does not match:")
+				t.Errorf("got:  %x", got)
+				t.Errorf("want: %x", want)
+			}
+			if got, want := gotPay, test.pay; !bytes.Equal(got, want) {
+				t.Errorf("Unprotected payload does not match:")
+				t.Errorf("got:  %x", got)
+				t.Errorf("want: %x", want)
+			}
+			if got, want := gotNum, test.pnum; got != want {
+				t.Errorf("Unprotected packet number does not match: got %v, want %v", got, want)
+			}
+		})
+	}
+}
+
+func pad(n int, b []byte) []byte {
+	for len(b) < n {
+		b = append(b, 0)
+	}
+	return b
+}