quic: tls handshake
Exchange TLS handshake data in CRYPTO frames.
Receive packet protection keys from the TLS layer.
Discard packet protection keys as the handshake progresses.
Send and receive HANDSHAKE_DONE frames (used by the server
to inform the client of the handshake completing).
Add a very minimal implementation of CONNECTION_CLOSE,
just enough to let us write tests that trigger immediate
close of connections.
For golang/go#58547
Change-Id: I77496ca65bd72977565733739d563eaa2bb7d8d3
Reviewed-on: https://go-review.googlesource.com/c/net/+/510915
Reviewed-by: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
diff --git a/internal/quic/config.go b/internal/quic/config.go
new file mode 100644
index 0000000..7d1b743
--- /dev/null
+++ b/internal/quic/config.go
@@ -0,0 +1,20 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/tls"
+)
+
+// A Config structure configures a QUIC endpoint.
+// A Config must not be modified after it has been passed to a QUIC function.
+// A Config may be reused; the quic package will also not modify it.
+type Config struct {
+ // TLSConfig is the endpoint's TLS configuration.
+ // It must be non-nil and include at least one certificate or else set GetCertificate.
+ TLSConfig *tls.Config
+}
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index e6375e8..8130c54 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -7,6 +7,7 @@
package quic
import (
+ "crypto/tls"
"errors"
"fmt"
"net/netip"
@@ -19,6 +20,7 @@
type Conn struct {
side connSide
listener connListener
+ config *Config
testHooks connTestHooks
peerAddr netip.AddrPort
@@ -29,14 +31,27 @@
w packetWriter
acks [numberSpaceCount]ackState // indexed by number space
connIDState connIDState
- tlsState tlsState
loss lossState
+ // errForPeer is set when the connection is being closed.
+ errForPeer error
+ connCloseSent [numberSpaceCount]bool
+
// idleTimeout is the time at which the connection will be closed due to inactivity.
// https://www.rfc-editor.org/rfc/rfc9000#section-10.1
maxIdleTimeout time.Duration
idleTimeout time.Time
+ // Packet protection keys, CRYPTO streams, and TLS state.
+ rkeys [numberSpaceCount]keys
+ wkeys [numberSpaceCount]keys
+ crypto [numberSpaceCount]cryptoStream
+ tls *tls.QUICConn
+
+ // handshakeConfirmed is set when the handshake is confirmed.
+ // For server connections, it tracks sending HANDSHAKE_DONE.
+ handshakeConfirmed sentVal
+
peerAckDelayExponent int8 // -1 when unknown
// Tests only: Send a PING in a specific number space.
@@ -53,12 +68,14 @@
// connTestHooks override conn behavior in tests.
type connTestHooks interface {
nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any)
+ handleTLSEvent(tls.QUICEvent)
}
-func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, l connListener, hooks connTestHooks) (*Conn, error) {
+func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) {
c := &Conn{
side: side,
listener: l,
+ config: config,
peerAddr: peerAddr,
msgc: make(chan any, 1),
donec: make(chan struct{}),
@@ -88,12 +105,58 @@
const maxDatagramSize = 1200
c.loss.init(c.side, maxDatagramSize, now)
- c.tlsState.init(c.side, initialConnID)
+ c.startTLS(now, initialConnID, transportParameters{
+ initialSrcConnID: c.connIDState.srcConnID(),
+ ackDelayExponent: ackDelayExponent,
+ maxUDPPayloadSize: maxUDPPayloadSize,
+ maxAckDelay: maxAckDelay,
+ })
go c.loop(now)
return c, nil
}
+// confirmHandshake is called when the handshake is confirmed.
+// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2
+func (c *Conn) confirmHandshake(now time.Time) {
+ // If handshakeConfirmed is unset, the handshake is not confirmed.
+ // If it is unsent, the handshake is confirmed and we need to send a HANDSHAKE_DONE.
+ // If it is sent, we have sent a HANDSHAKE_DONE.
+ // If it is received, the handshake is confirmed and we do not need to send anything.
+ if c.handshakeConfirmed.isSet() {
+ return // already confirmed
+ }
+ if c.side == serverSide {
+ // When the server confirms the handshake, it sends a HANDSHAKE_DONE.
+ c.handshakeConfirmed.setUnsent()
+ } else {
+ // The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed
+ // to the received state, indicating that the handshake is confirmed and we
+ // don't need to send anything.
+ c.handshakeConfirmed.setReceived()
+ }
+ c.loss.confirmHandshake()
+ // "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed"
+ // https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1
+ c.discardKeys(now, handshakeSpace)
+}
+
+// discardKeys discards unused packet protection keys.
+// https://www.rfc-editor.org/rfc/rfc9001#section-4.9
+func (c *Conn) discardKeys(now time.Time, space numberSpace) {
+ c.rkeys[space].discard()
+ c.wkeys[space].discard()
+ c.loss.discardKeys(now, space)
+}
+
+// receiveTransportParameters applies transport parameters sent by the peer.
+func (c *Conn) receiveTransportParameters(p transportParameters) {
+ c.peerAckDelayExponent = p.ackDelayExponent
+ c.loss.setMaxAckDelay(p.maxAckDelay)
+
+ // TODO: Many more transport parameters to come.
+}
+
type timerEvent struct{}
// loop is the connection main loop.
@@ -104,6 +167,7 @@
// Other goroutines may examine or modify conn state by sending the loop funcs to execute.
func (c *Conn) loop(now time.Time) {
defer close(c.donec)
+ defer c.tls.Close()
// The connection timer sends a message to the connection loop on expiry.
// We need to give it an expiry when creating it, so set the initial timeout to
@@ -201,8 +265,9 @@
// abort terminates a connection with an error.
func (c *Conn) abort(now time.Time, err error) {
- // TODO: Send CONNECTION_CLOSE frames.
- c.exit()
+ if c.errForPeer == nil {
+ c.errForPeer = err
+ }
}
// exit fully terminates a connection immediately.
diff --git a/internal/quic/conn_loss.go b/internal/quic/conn_loss.go
index 11ed42d..6cb459c 100644
--- a/internal/quic/conn_loss.go
+++ b/internal/quic/conn_loss.go
@@ -29,7 +29,7 @@
for !sent.done() {
switch f := sent.next(); f {
default:
- panic(fmt.Sprintf("BUG: unhandled lost frame type %x", f))
+ panic(fmt.Sprintf("BUG: unhandled acked/lost frame type %x", f))
case frameTypeAck:
// Unlike most information, loss of an ACK frame does not trigger
// retransmission. ACKs are sent in response to ack-eliciting packets,
@@ -41,6 +41,11 @@
if fate == packetAcked {
c.acks[space].handleAck(largest)
}
+ case frameTypeCrypto:
+ start, end := sent.nextRange()
+ c.crypto[space].ackOrLoss(start, end, fate)
+ case frameTypeHandshakeDone:
+ c.handshakeConfirmed.ackOrLoss(sent.num, fate)
}
}
}
diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go
new file mode 100644
index 0000000..be4f5fb
--- /dev/null
+++ b/internal/quic/conn_loss_test.go
@@ -0,0 +1,143 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/tls"
+ "testing"
+)
+
+// Frames may be retransmitted either when the packet containing the frame is lost, or on PTO.
+// lostFrameTest runs a test in both configurations.
+func lostFrameTest(t *testing.T, f func(t *testing.T, pto bool)) {
+ t.Run("lost", func(t *testing.T) {
+ f(t, false)
+ })
+ t.Run("pto", func(t *testing.T) {
+ f(t, true)
+ })
+}
+
+// triggerLossOrPTO causes the conn to declare the last sent packet lost,
+// or advances to the PTO timer.
+func (tc *testConn) triggerLossOrPTO(ptype packetType, pto bool) {
+ tc.t.Helper()
+ if pto {
+ if !tc.conn.loss.ptoTimerArmed {
+ tc.t.Fatalf("PTO timer not armed, expected it to be")
+ }
+ tc.advanceTo(tc.conn.loss.timer)
+ return
+ }
+ defer func(ignoreFrames map[byte]bool) {
+ tc.ignoreFrames = ignoreFrames
+ }(tc.ignoreFrames)
+ tc.ignoreFrames = map[byte]bool{
+ frameTypeAck: true,
+ frameTypePadding: true,
+ }
+ // Send three packets containing PINGs, and then respond with an ACK for the
+ // last one. This puts the last packet before the PINGs outside the packet
+ // reordering threshold, and it will be declared lost.
+ const lossThreshold = 3
+ var num packetNumber
+ for i := 0; i < lossThreshold; i++ {
+ tc.conn.ping(spaceForPacketType(ptype))
+ d := tc.readDatagram()
+ if d == nil {
+ tc.t.Fatalf("conn is idle; want PING frame")
+ }
+ if d.packets[0].ptype != ptype {
+ tc.t.Fatalf("conn sent %v packet; want %v", d.packets[0].ptype, ptype)
+ }
+ num = d.packets[0].num
+ }
+ tc.writeFrames(ptype, debugFrameAck{
+ ranges: []i64range[packetNumber]{
+ {num, num + 1},
+ },
+ })
+}
+
+func TestLostCRYPTOFrame(t *testing.T) {
+ // "Data sent in CRYPTO frames is retransmitted [...] until all data has been acknowledged."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.1
+ lostFrameTest(t, func(t *testing.T, pto bool) {
+ tc := newTestConn(t, clientSide)
+ tc.ignoreFrame(frameTypeAck)
+
+ tc.wantFrame("client sends Initial CRYPTO frame",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.triggerLossOrPTO(packetTypeInitial, pto)
+ tc.wantFrame("client resends Initial CRYPTO frame",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+
+ tc.wantFrame("client sends Handshake CRYPTO frame",
+ packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.triggerLossOrPTO(packetTypeHandshake, pto)
+ tc.wantFrame("client resends Handshake CRYPTO frame",
+ packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake],
+ })
+ })
+}
+
+func TestLostHandshakeDoneFrame(t *testing.T) {
+ // "The HANDSHAKE_DONE frame MUST be retransmitted until it is acknowledged."
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.16
+ lostFrameTest(t, func(t *testing.T, pto bool) {
+ tc := newTestConn(t, serverSide)
+ tc.ignoreFrame(frameTypeAck)
+
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantFrame("server sends Initial CRYPTO frame",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantFrame("server sends Handshake CRYPTO frame",
+ packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+
+ tc.wantFrame("server sends HANDSHAKE_DONE after handshake completes",
+ packetType1RTT, debugFrameHandshakeDone{})
+ tc.wantFrame("server sends session ticket in CRYPTO frame",
+ packetType1RTT, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelApplication],
+ })
+
+ tc.triggerLossOrPTO(packetType1RTT, pto)
+ tc.wantFrame("server resends HANDSHAKE_DONE",
+ packetType1RTT, debugFrameHandshakeDone{})
+ tc.wantFrame("server resends session ticket",
+ packetType1RTT, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelApplication],
+ })
+ })
+}
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index d5a3b8c..7eb03e7 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -41,12 +41,12 @@
}
func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, buf []byte) int {
- if !c.tlsState.rkeys[space].isSet() {
+ if !c.rkeys[space].isSet() {
return skipLongHeaderPacket(buf)
}
pnumMax := c.acks[space].largestSeen()
- p, n := parseLongHeaderPacket(buf, c.tlsState.rkeys[space], pnumMax)
+ p, n := parseLongHeaderPacket(buf, c.rkeys[space], pnumMax)
if n < 0 {
return -1
}
@@ -66,21 +66,23 @@
if p.ptype == packetTypeHandshake && c.side == serverSide {
c.loss.validateClientAddress()
- // TODO: Discard Initial keys.
+ // "[...] a server MUST discard Initial keys when it first successfully
+ // processes a Handshake packet [...]"
// https://www.rfc-editor.org/rfc/rfc9001#section-4.9.1-2
+ c.discardKeys(now, initialSpace)
}
return n
}
func (c *Conn) handle1RTT(now time.Time, buf []byte) int {
- if !c.tlsState.rkeys[appDataSpace].isSet() {
+ if !c.rkeys[appDataSpace].isSet() {
// 1-RTT packets extend to the end of the datagram,
// so skip the remainder of the datagram if we can't parse this.
return len(buf)
}
pnumMax := c.acks[appDataSpace].largestSeen()
- p, n := parse1RTTPacket(buf, c.tlsState.rkeys[appDataSpace], connIDLen, pnumMax)
+ p, n := parse1RTTPacket(buf, c.rkeys[appDataSpace], connIDLen, pnumMax)
if n < 0 {
return -1
}
@@ -163,7 +165,7 @@
if !frameOK(c, ptype, IH_1) {
return
}
- _, _, n = consumeCryptoFrame(payload)
+ n = c.handleCryptoFrame(now, space, payload)
case frameTypeNewToken:
if !frameOK(c, ptype, ___1) {
return
@@ -207,14 +209,18 @@
case frameTypeConnectionCloseTransport:
// CONNECTION_CLOSE is OK in all spaces.
_, _, _, n = consumeConnectionCloseTransportFrame(payload)
+ // TODO: https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2
+ c.abort(now, localTransportError(errNo))
case frameTypeConnectionCloseApplication:
// CONNECTION_CLOSE is OK in all spaces.
_, _, n = consumeConnectionCloseApplicationFrame(payload)
+ // TODO: https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2
+ c.abort(now, localTransportError(errNo))
case frameTypeHandshakeDone:
if !frameOK(c, ptype, ___1) {
return
}
- n = 1
+ n = c.handleHandshakeDoneFrame(now, space, payload)
}
if n < 0 {
c.abort(now, localTransportError(errFrameEncoding))
@@ -262,3 +268,24 @@
c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss)
return n
}
+
+func (c *Conn) handleCryptoFrame(now time.Time, space numberSpace, payload []byte) int {
+ off, data, n := consumeCryptoFrame(payload)
+ err := c.handleCrypto(now, space, off, data)
+ if err != nil {
+ c.abort(now, err)
+ return -1
+ }
+ return n
+}
+
+func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payload []byte) int {
+ if c.side == serverSide {
+ // Clients should never send HANDSHAKE_DONE.
+ // https://www.rfc-editor.org/rfc/rfc9000#section-19.20-4
+ c.abort(now, localTransportError(errProtocolViolation))
+ return -1
+ }
+ c.confirmHandshake(now)
+ return 1
+}
diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go
index 3a51ceb..71d24e6 100644
--- a/internal/quic/conn_send.go
+++ b/internal/quic/conn_send.go
@@ -7,6 +7,8 @@
package quic
import (
+ "crypto/tls"
+ "errors"
"time"
)
@@ -45,7 +47,7 @@
// Initial packet.
pad := false
var sentInitial *sentPacket
- if k := c.tlsState.wkeys[initialSpace]; k.isSet() {
+ if k := c.wkeys[initialSpace]; k.isSet() {
pnumMaxAcked := c.acks[initialSpace].largestSeen()
pnum := c.loss.nextNumber(initialSpace)
p := longPacket{
@@ -62,14 +64,14 @@
// Client initial packets need to be sent in a datagram padded to
// at least 1200 bytes. We can't add the padding yet, however,
// since we may want to coalesce additional packets with this one.
- if c.side == clientSide || sentInitial.ackEliciting {
+ if c.side == clientSide {
pad = true
}
}
}
// Handshake packet.
- if k := c.tlsState.wkeys[handshakeSpace]; k.isSet() {
+ if k := c.wkeys[handshakeSpace]; k.isSet() {
pnumMaxAcked := c.acks[handshakeSpace].largestSeen()
pnum := c.loss.nextNumber(handshakeSpace)
p := longPacket{
@@ -84,14 +86,16 @@
if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p); sent != nil {
c.loss.packetSent(now, handshakeSpace, sent)
if c.side == clientSide {
- // TODO: Discard the Initial keys.
- // https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9.1
+ // "[...] a client MUST discard Initial keys when it first
+ // sends a Handshake packet [...]"
+ // https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9.1-2
+ c.discardKeys(now, initialSpace)
}
}
}
// 1-RTT packet.
- if k := c.tlsState.wkeys[appDataSpace]; k.isSet() {
+ if k := c.wkeys[appDataSpace]; k.isSet() {
pnumMaxAcked := c.acks[appDataSpace].largestSeen()
pnum := c.loss.nextNumber(appDataSpace)
dstConnID := c.connIDState.dstConnID()
@@ -133,7 +137,7 @@
sentInitial.inFlight = true
}
}
- if k := c.tlsState.wkeys[initialSpace]; k.isSet() {
+ if k := c.wkeys[initialSpace]; k.isSet() {
c.loss.packetSent(now, initialSpace, sentInitial)
}
}
@@ -143,6 +147,26 @@
}
func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, limit ccLimit) {
+ if c.errForPeer != nil {
+ // This is the bare minimum required to send a CONNECTION_CLOSE frame
+ // when closing a connection immediately, for example in response to a
+ // protocol error.
+ //
+ // This does not handle the closing and draining states
+ // (https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2),
+ // but it's enough to let us write tests that result in a CONNECTION_CLOSE,
+ // and have those tests still pass when we finish implementing
+ // connection shutdown.
+ //
+ // TODO: Finish implementing connection shutdown.
+ if !c.connCloseSent[space] {
+ c.exited = true
+ c.appendConnectionCloseFrame(c.errForPeer)
+ c.connCloseSent[space] = true
+ }
+ return
+ }
+
shouldSendAck := c.acks[space].shouldSendAck(now)
if limit != ccOK {
// ACKs are not limited by congestion control.
@@ -185,6 +209,21 @@
// TODO: Add all the other frames we can send.
+ // HANDSHAKE_DONE
+ if c.handshakeConfirmed.shouldSendPTO(pto) {
+ if !c.w.appendHandshakeDoneFrame() {
+ return
+ }
+ c.handshakeConfirmed.setSent(pnum)
+ }
+
+ // CRYPTO
+ c.crypto[space].dataToSend(pto, func(off, size int64) int64 {
+ b, _ := c.w.appendCryptoFrame(off, int(size))
+ c.crypto[space].sendData(off, b)
+ return int64(len(b))
+ })
+
// Test-only PING frames.
if space == c.testSendPingSpace && c.testSendPing.shouldSendPTO(pto) {
if !c.w.appendPingFrame() {
@@ -253,3 +292,22 @@
d := unscaledAckDelayFromDuration(delay, ackDelayExponent)
return c.w.appendAckFrame(seen, d)
}
+
+func (c *Conn) appendConnectionCloseFrame(err error) {
+ // TODO: Send application errors.
+ switch e := err.(type) {
+ case localTransportError:
+ c.w.appendConnectionCloseTransportFrame(transportError(e), 0, "")
+ default:
+ // TLS alerts are sent using error codes [0x0100,0x01ff).
+ // https://www.rfc-editor.org/rfc/rfc9000#section-20.1-2.36.1
+ var alert tls.AlertError
+ if errors.As(err, &alert) {
+ // tls.AlertError is a uint8, so this can't exceed 0x01ff.
+ code := errTLSBase + transportError(alert)
+ c.w.appendConnectionCloseTransportFrame(code, 0, "")
+ return
+ }
+ c.w.appendConnectionCloseTransportFrame(errInternal, 0, "")
+ }
+}
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index fda1d4b..511fb97 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -7,6 +7,9 @@
package quic
import (
+ "bytes"
+ "context"
+ "crypto/tls"
"errors"
"fmt"
"math"
@@ -111,8 +114,22 @@
// we use Handshake keys to encrypt the packet.
// The client only acquires those keys when it processes
// the Initial packet.
- rkeys [numberSpaceCount]keys // for packets sent to the conn
- wkeys [numberSpaceCount]keys // for packets sent by the conn
+ rkeys [numberSpaceCount]keyData // for packets sent to the conn
+ wkeys [numberSpaceCount]keyData // for packets sent by the conn
+
+ // testConn uses a test hook to snoop on the conn's TLS events.
+ // CRYPTO data produced by the conn's QUICConn is placed in
+ // cryptoDataOut.
+ //
+ // The peerTLSConn is is a QUICConn representing the peer.
+ // CRYPTO data produced by the conn is written to peerTLSConn,
+ // and data produced by peerTLSConn is placed in cryptoDataIn.
+ cryptoDataOut map[tls.QUICEncryptionLevel][]byte
+ cryptoDataIn map[tls.QUICEncryptionLevel][]byte
+ peerTLSConn *tls.QUICConn
+
+ localConnID []byte
+ transientConnID []byte
// Information about the conn's (fake) peer.
peerConnID []byte // source conn id of peer's packets
@@ -129,12 +146,18 @@
ignoreFrames map[byte]bool
}
+type keyData struct {
+ suite uint16
+ secret []byte
+ k keys
+}
+
// newTestConn creates a Conn for testing.
//
// The Conn's event loop is controlled by the test,
// allowing test code to access Conn state directly
// by first ensuring the loop goroutine is idle.
-func newTestConn(t *testing.T, side connSide) *testConn {
+func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
t.Helper()
tc := &testConn{
t: t,
@@ -143,9 +166,24 @@
ignoreFrames: map[byte]bool{
frameTypePadding: true, // ignore PADDING by default
},
+ cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
+ cryptoDataIn: make(map[tls.QUICEncryptionLevel][]byte),
}
t.Cleanup(tc.cleanup)
+ config := &Config{
+ TLSConfig: newTestTLSConfig(side),
+ }
+ peerProvidedParams := defaultTransportParameters()
+ for _, o := range opts {
+ switch o := o.(type) {
+ case func(*tls.Config):
+ o(config.TLSConfig)
+ default:
+ t.Fatalf("unknown newTestConn option %T", o)
+ }
+ }
+
var initialConnID []byte
if side == serverSide {
// The initial connection ID for the server is chosen by the client.
@@ -157,11 +195,21 @@
}
}
+ peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(side.peer())}
+ if side == clientSide {
+ tc.peerTLSConn = tls.QUICServer(peerQUICConfig)
+ } else {
+ tc.peerTLSConn = tls.QUICClient(peerQUICConfig)
+ }
+ tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
+ tc.peerTLSConn.Start(context.Background())
+
conn, err := newConn(
tc.now,
side,
initialConnID,
netip.MustParseAddrPort("127.0.0.1:443"),
+ config,
(*testConnListener)(tc),
(*testConnHooks)(tc))
if err != nil {
@@ -169,8 +217,16 @@
}
tc.conn = conn
- tc.wkeys[initialSpace] = conn.tlsState.wkeys[initialSpace]
- tc.rkeys[initialSpace] = conn.tlsState.rkeys[initialSpace]
+ if side == serverSide {
+ tc.transientConnID = tc.conn.connIDState.local[0].cid
+ tc.localConnID = tc.conn.connIDState.local[1].cid
+ } else if side == clientSide {
+ tc.transientConnID = tc.conn.connIDState.remote[0].cid
+ tc.localConnID = tc.conn.connIDState.local[0].cid
+ }
+
+ tc.wkeys[initialSpace].k = conn.wkeys[initialSpace]
+ tc.rkeys[initialSpace].k = conn.rkeys[initialSpace]
tc.wait()
return tc
@@ -385,7 +441,7 @@
tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
}
if gotType != wantType {
- tc.t.Fatalf("%v:\ngot %v packet, want %v", expectation, wantType, want)
+ tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got)
}
if !reflect.DeepEqual(got, want) {
tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want)
@@ -426,12 +482,12 @@
f.write(&w)
}
space := spaceForPacketType(p.ptype)
- if !tc.rkeys[space].isSet() {
+ if !tc.rkeys[space].k.isSet() {
tc.t.Fatalf("sending packet with no %v keys available", space)
return nil
}
if p.ptype != packetType1RTT {
- w.finishProtectedLongHeaderPacket(pnumMaxAcked, tc.rkeys[space], longPacket{
+ w.finishProtectedLongHeaderPacket(pnumMaxAcked, tc.rkeys[space].k, longPacket{
ptype: p.ptype,
version: p.version,
num: p.num,
@@ -439,7 +495,7 @@
srcConnID: p.srcConnID,
})
} else {
- w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.rkeys[space])
+ w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.rkeys[space].k)
}
return w.datagram()
}
@@ -455,12 +511,12 @@
}
ptype := getPacketType(buf)
space := spaceForPacketType(ptype)
- if !tc.wkeys[space].isSet() {
+ if !tc.wkeys[space].k.isSet() {
tc.t.Fatalf("no keys for space %v, packet type %v", space, ptype)
}
if isLongHeader(buf[0]) {
var pnumMax packetNumber // TODO: Track packet numbers.
- p, n := parseLongHeaderPacket(buf, tc.wkeys[space], pnumMax)
+ p, n := parseLongHeaderPacket(buf, tc.wkeys[space].k, pnumMax)
if n < 0 {
tc.t.Fatalf("packet parse error")
}
@@ -479,11 +535,10 @@
buf = buf[n:]
} else {
var pnumMax packetNumber // TODO: Track packet numbers.
- p, n := parse1RTTPacket(buf, tc.wkeys[space], len(tc.peerConnID), pnumMax)
+ p, n := parse1RTTPacket(buf, tc.wkeys[space].k, len(tc.peerConnID), pnumMax)
if n < 0 {
tc.t.Fatalf("packet parse error")
}
- dstConnID, _ := dstConnIDForDatagram(buf)
frames, err := tc.parseTestFrames(p.payload)
if err != nil {
tc.t.Fatal(err)
@@ -491,7 +546,7 @@
d.packets = append(d.packets, &testPacket{
ptype: packetType1RTT,
num: p.num,
- dstConnID: dstConnID,
+ dstConnID: buf[1:][:len(tc.peerConnID)],
frames: frames,
})
buf = buf[n:]
@@ -535,6 +590,73 @@
// testConnHooks implements connTestHooks.
type testConnHooks testConn
+// handleTLSEvent processes TLS events generated by
+// the connection under test's tls.QUICConn.
+//
+// We maintain a second tls.QUICConn representing the peer,
+// and feed the TLS handshake data into it.
+//
+// We stash TLS handshake data from both sides in the testConn,
+// where it can be used by tests.
+//
+// We snoop packet protection keys out of the tls.QUICConns,
+// and verify that both sides of the connection are getting
+// matching keys.
+func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
+ setKey := func(keys *[numberSpaceCount]keyData, e tls.QUICEvent) {
+ k, err := newKeys(e.Suite, e.Data)
+ if err != nil {
+ tc.t.Errorf("newKeys: %v", err)
+ return
+ }
+ var space numberSpace
+ switch {
+ case e.Level == tls.QUICEncryptionLevelHandshake:
+ space = handshakeSpace
+ case e.Level == tls.QUICEncryptionLevelApplication:
+ space = appDataSpace
+ default:
+ tc.t.Errorf("unexpected encryption level %v", e.Level)
+ return
+ }
+ s := "read"
+ if keys == &tc.wkeys {
+ s = "write"
+ }
+ if keys[space].k.isSet() {
+ if keys[space].suite != e.Suite || !bytes.Equal(keys[space].secret, e.Data) {
+ tc.t.Errorf("%v key mismatch for level for level %v", s, e.Level)
+ }
+ return
+ }
+ keys[space].suite = e.Suite
+ keys[space].secret = append([]byte{}, e.Data...)
+ keys[space].k = k
+ }
+ switch e.Kind {
+ case tls.QUICSetReadSecret:
+ setKey(&tc.rkeys, e)
+ case tls.QUICSetWriteSecret:
+ setKey(&tc.wkeys, e)
+ case tls.QUICWriteData:
+ tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
+ tc.peerTLSConn.HandleData(e.Level, e.Data)
+ }
+ for {
+ e := tc.peerTLSConn.NextEvent()
+ switch e.Kind {
+ case tls.QUICNoEvent:
+ return
+ case tls.QUICSetReadSecret:
+ setKey(&tc.wkeys, e)
+ case tls.QUICSetWriteSecret:
+ setKey(&tc.rkeys, e)
+ case tls.QUICWriteData:
+ tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
+ }
+ }
+}
+
// nextMessage is called by the Conn's event loop to request its next event.
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
tc.timer = timer
diff --git a/internal/quic/ping_test.go b/internal/quic/ping_test.go
index 4a732ed..c370aaf 100644
--- a/internal/quic/ping_test.go
+++ b/internal/quic/ping_test.go
@@ -10,26 +10,34 @@
func TestPing(t *testing.T) {
tc := newTestConn(t, clientSide)
- tc.conn.ping(initialSpace)
+ tc.handshake()
+
+ tc.conn.ping(appDataSpace)
tc.wantFrame("connection should send a PING frame",
- packetTypeInitial, debugFramePing{})
+ packetType1RTT, debugFramePing{})
tc.advanceToTimer()
tc.wantFrame("on PTO, connection should send another PING frame",
- packetTypeInitial, debugFramePing{})
+ packetType1RTT, debugFramePing{})
tc.wantIdle("after sending PTO probe, no additional frames to send")
}
func TestAck(t *testing.T) {
tc := newTestConn(t, serverSide)
- tc.writeFrames(packetTypeInitial,
+ tc.handshake()
+
+ // Send two packets, to trigger an immediate ACK.
+ tc.writeFrames(packetType1RTT,
+ debugFramePing{},
+ )
+ tc.writeFrames(packetType1RTT,
debugFramePing{},
)
tc.wantFrame("connection should respond to ack-eliciting packet with an ACK frame",
- packetTypeInitial,
+ packetType1RTT,
debugFrameAck{
- ranges: []i64range[packetNumber]{{0, 1}},
+ ranges: []i64range[packetNumber]{{0, 3}},
},
)
}
diff --git a/internal/quic/quic.go b/internal/quic/quic.go
index 9df7f7e..a61c91f 100644
--- a/internal/quic/quic.go
+++ b/internal/quic/quic.go
@@ -64,6 +64,14 @@
}
}
+func (s connSide) peer() connSide {
+ if s == clientSide {
+ return serverSide
+ } else {
+ return clientSide
+ }
+}
+
// A numberSpace is the context in which a packet number applies.
// https://www.rfc-editor.org/rfc/rfc9000.html#section-12.3-7
type numberSpace byte
diff --git a/internal/quic/tls.go b/internal/quic/tls.go
index 1cdb727..4306a3e 100644
--- a/internal/quic/tls.go
+++ b/internal/quic/tls.go
@@ -6,18 +6,132 @@
package quic
-// tlsState encapsulates interactions with TLS.
-type tlsState struct {
- // Encryption keys indexed by number space.
- rkeys [numberSpaceCount]keys
- wkeys [numberSpaceCount]keys
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "time"
+)
+
+// startTLS starts the TLS handshake.
+func (c *Conn) startTLS(now time.Time, initialConnID []byte, params transportParameters) error {
+ clientKeys, serverKeys := initialKeys(initialConnID)
+ if c.side == clientSide {
+ c.wkeys[initialSpace], c.rkeys[initialSpace] = clientKeys, serverKeys
+ } else {
+ c.wkeys[initialSpace], c.rkeys[initialSpace] = serverKeys, clientKeys
+ }
+
+ qconfig := &tls.QUICConfig{TLSConfig: c.config.TLSConfig}
+ if c.side == clientSide {
+ c.tls = tls.QUICClient(qconfig)
+ } else {
+ c.tls = tls.QUICServer(qconfig)
+ }
+ c.tls.SetTransportParameters(marshalTransportParameters(params))
+ // TODO: We don't need or want a context for cancelation here,
+ // but users can use a context to plumb values through to hooks defined
+ // in the tls.Config. Pass through a context.
+ if err := c.tls.Start(context.TODO()); err != nil {
+ return err
+ }
+ return c.handleTLSEvents(now)
}
-func (s *tlsState) init(side connSide, initialConnID []byte) {
- clientKeys, serverKeys := initialKeys(initialConnID)
- if side == clientSide {
- s.wkeys[initialSpace], s.rkeys[initialSpace] = clientKeys, serverKeys
- } else {
- s.wkeys[initialSpace], s.rkeys[initialSpace] = serverKeys, clientKeys
+func (c *Conn) handleTLSEvents(now time.Time) error {
+ for {
+ e := c.tls.NextEvent()
+ if c.testHooks != nil {
+ c.testHooks.handleTLSEvent(e)
+ }
+ switch e.Kind {
+ case tls.QUICNoEvent:
+ return nil
+ case tls.QUICSetReadSecret:
+ space, k, err := tlsKey(e)
+ if err != nil {
+ return err
+ }
+ c.rkeys[space] = k
+ case tls.QUICSetWriteSecret:
+ space, k, err := tlsKey(e)
+ if err != nil {
+ return err
+ }
+ c.wkeys[space] = k
+ case tls.QUICWriteData:
+ space, err := spaceForLevel(e.Level)
+ if err != nil {
+ return err
+ }
+ c.crypto[space].write(e.Data)
+ case tls.QUICHandshakeDone:
+ if c.side == serverSide {
+ // "[...] the TLS handshake is considered confirmed
+ // at the server when the handshake completes."
+ // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2-1
+ c.confirmHandshake(now)
+ if !c.config.TLSConfig.SessionTicketsDisabled {
+ if err := c.tls.SendSessionTicket(false); err != nil {
+ return err
+ }
+ }
+ }
+ case tls.QUICTransportParameters:
+ params, err := unmarshalTransportParams(e.Data)
+ if err != nil {
+ return err
+ }
+ c.receiveTransportParameters(params)
+ }
}
}
+
+// tlsKey returns the keys in a QUICSetReadSecret or QUICSetWriteSecret event.
+func tlsKey(e tls.QUICEvent) (numberSpace, keys, error) {
+ space, err := spaceForLevel(e.Level)
+ if err != nil {
+ return 0, keys{}, err
+ }
+ k, err := newKeys(e.Suite, e.Data)
+ if err != nil {
+ return 0, keys{}, err
+ }
+ return space, k, nil
+}
+
+func spaceForLevel(level tls.QUICEncryptionLevel) (numberSpace, error) {
+ switch level {
+ case tls.QUICEncryptionLevelInitial:
+ return initialSpace, nil
+ case tls.QUICEncryptionLevelHandshake:
+ return handshakeSpace, nil
+ case tls.QUICEncryptionLevelApplication:
+ return appDataSpace, nil
+ default:
+ return 0, fmt.Errorf("quic: internal error: write handshake data at level %v", level)
+ }
+}
+
+// handleCrypto processes data received in a CRYPTO frame.
+func (c *Conn) handleCrypto(now time.Time, space numberSpace, off int64, data []byte) error {
+ var level tls.QUICEncryptionLevel
+ switch space {
+ case initialSpace:
+ level = tls.QUICEncryptionLevelInitial
+ case handshakeSpace:
+ level = tls.QUICEncryptionLevelHandshake
+ case appDataSpace:
+ level = tls.QUICEncryptionLevelApplication
+ default:
+ return errors.New("quic: internal error: received CRYPTO frame in unexpected number space")
+ }
+ err := c.crypto[space].handleCrypto(off, data, func(b []byte) error {
+ return c.tls.HandleData(level, b)
+ })
+ if err != nil {
+ return err
+ }
+ return c.handleTLSEvents(now)
+}
diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go
new file mode 100644
index 0000000..df07820
--- /dev/null
+++ b/internal/quic/tls_test.go
@@ -0,0 +1,421 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "reflect"
+ "testing"
+ "time"
+)
+
+// handshake executes the handshake.
+func (tc *testConn) handshake() {
+ tc.t.Helper()
+ defer func(saved map[byte]bool) {
+ tc.ignoreFrames = saved
+ }(tc.ignoreFrames)
+ tc.ignoreFrames = nil
+ t := tc.t
+ dgrams := handshakeDatagrams(tc)
+ i := 0
+ for {
+ if i == len(dgrams)-1 {
+ if tc.conn.side == clientSide {
+ want := tc.now.Add(maxAckDelay - timerGranularity)
+ if !tc.timer.Equal(want) {
+ t.Fatalf("want timer = %v (max_ack_delay), got %v", want, tc.timer)
+ }
+ if got := tc.readDatagram(); got != nil {
+ t.Fatalf("client unexpectedly sent: %v", got)
+ }
+ }
+ tc.advance(maxAckDelay)
+ }
+
+ // Check that we're sending exactly the data we expect.
+ // Any variation from the norm here should be intentional.
+ got := tc.readDatagram()
+ var want *testDatagram
+ if !(tc.conn.side == serverSide && i == 0) && i < len(dgrams) {
+ want = dgrams[i]
+ fillCryptoFrames(want, tc.cryptoDataOut)
+ i++
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("dgram %v:\ngot %v\n\nwant %v", i, got, want)
+ }
+ if i >= len(dgrams) {
+ break
+ }
+
+ fillCryptoFrames(dgrams[i], tc.cryptoDataIn)
+ tc.write(dgrams[i])
+ i++
+ }
+}
+
+func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) {
+ var (
+ clientConnID []byte
+ serverConnID []byte
+ )
+ if tc.conn.side == clientSide {
+ clientConnID = tc.localConnID
+ serverConnID = tc.peerConnID
+ } else {
+ clientConnID = tc.peerConnID
+ serverConnID = tc.localConnID
+ }
+ return []*testDatagram{{
+ // Client Initial
+ packets: []*testPacket{{
+ ptype: packetTypeInitial,
+ num: 0,
+ version: 1,
+ srcConnID: clientConnID,
+ dstConnID: tc.transientConnID,
+ frames: []debugFrame{
+ debugFrameCrypto{},
+ },
+ }},
+ paddedSize: 1200,
+ }, {
+ // Server Initial + Handshake
+ packets: []*testPacket{{
+ ptype: packetTypeInitial,
+ num: 0,
+ version: 1,
+ srcConnID: serverConnID,
+ dstConnID: clientConnID,
+ frames: []debugFrame{
+ debugFrameAck{
+ ranges: []i64range[packetNumber]{{0, 1}},
+ },
+ debugFrameCrypto{},
+ },
+ }, {
+ ptype: packetTypeHandshake,
+ num: 0,
+ version: 1,
+ srcConnID: serverConnID,
+ dstConnID: clientConnID,
+ frames: []debugFrame{
+ debugFrameCrypto{},
+ },
+ }},
+ }, {
+ // Client Handshake
+ packets: []*testPacket{{
+ ptype: packetTypeInitial,
+ num: 1,
+ version: 1,
+ srcConnID: clientConnID,
+ dstConnID: serverConnID,
+ frames: []debugFrame{
+ debugFrameAck{
+ ranges: []i64range[packetNumber]{{0, 1}},
+ },
+ },
+ }, {
+ ptype: packetTypeHandshake,
+ num: 0,
+ version: 1,
+ srcConnID: clientConnID,
+ dstConnID: serverConnID,
+ frames: []debugFrame{
+ debugFrameAck{
+ ranges: []i64range[packetNumber]{{0, 1}},
+ },
+ debugFrameCrypto{},
+ },
+ }},
+ paddedSize: 1200,
+ }, {
+ // Server HANDSHAKE_DONE and session ticket
+ packets: []*testPacket{{
+ ptype: packetType1RTT,
+ num: 0,
+ dstConnID: clientConnID,
+ frames: []debugFrame{
+ debugFrameHandshakeDone{},
+ debugFrameCrypto{},
+ },
+ }},
+ }, {
+ // Client ack (after max_ack_delay)
+ packets: []*testPacket{{
+ ptype: packetType1RTT,
+ num: 0,
+ dstConnID: serverConnID,
+ frames: []debugFrame{
+ debugFrameAck{
+ ackDelay: unscaledAckDelayFromDuration(
+ maxAckDelay, ackDelayExponent),
+ ranges: []i64range[packetNumber]{{0, 1}},
+ },
+ },
+ }},
+ }}
+}
+
+func fillCryptoFrames(d *testDatagram, data map[tls.QUICEncryptionLevel][]byte) {
+ for _, p := range d.packets {
+ var level tls.QUICEncryptionLevel
+ switch p.ptype {
+ case packetTypeInitial:
+ level = tls.QUICEncryptionLevelInitial
+ case packetTypeHandshake:
+ level = tls.QUICEncryptionLevelHandshake
+ case packetType1RTT:
+ level = tls.QUICEncryptionLevelApplication
+ default:
+ continue
+ }
+ for i := range p.frames {
+ c, ok := p.frames[i].(debugFrameCrypto)
+ if !ok {
+ continue
+ }
+ c.data = data[level]
+ data[level] = nil
+ p.frames[i] = c
+ }
+ }
+}
+
+func TestConnClientHandshake(t *testing.T) {
+ tc := newTestConn(t, clientSide)
+ tc.handshake()
+ tc.advance(1 * time.Second)
+ tc.wantIdle("no packets should be sent by an idle conn after the handshake")
+}
+
+func TestConnServerHandshake(t *testing.T) {
+ tc := newTestConn(t, serverSide)
+ tc.handshake()
+ tc.advance(1 * time.Second)
+ tc.wantIdle("no packets should be sent by an idle conn after the handshake")
+}
+
+func TestConnKeysDiscardedClient(t *testing.T) {
+ tc := newTestConn(t, clientSide)
+ tc.ignoreFrame(frameTypeAck)
+
+ tc.wantFrame("client sends Initial CRYPTO frame",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.wantFrame("client sends Handshake CRYPTO frame",
+ packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake],
+ })
+
+ // The client discards Initial keys after sending a Handshake packet.
+ tc.writeFrames(packetTypeInitial,
+ debugFrameConnectionCloseTransport{code: errInternal})
+ tc.wantIdle("client has discarded Initial keys, cannot read CONNECTION_CLOSE")
+
+ // The client discards Handshake keys after receiving a HANDSHAKE_DONE frame.
+ tc.writeFrames(packetType1RTT,
+ debugFrameHandshakeDone{})
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameConnectionCloseTransport{code: errInternal})
+ tc.wantIdle("client has discarded Handshake keys, cannot read CONNECTION_CLOSE")
+
+ tc.writeFrames(packetType1RTT,
+ debugFrameConnectionCloseTransport{code: errInternal})
+ tc.wantFrame("client closes connection after 1-RTT CONNECTION_CLOSE",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errNo,
+ })
+}
+
+func TestConnKeysDiscardedServer(t *testing.T) {
+ tc := newTestConn(t, serverSide, func(c *tls.Config) {
+ c.SessionTicketsDisabled = true
+ })
+ tc.ignoreFrame(frameTypeAck)
+
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantFrame("server sends Initial CRYPTO frame",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantFrame("server sends Handshake CRYPTO frame",
+ packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake],
+ })
+
+ // The server discards Initial keys after receiving a Handshake packet.
+ // The Handshake packet contains only the start of the client's CRYPTO flight here,
+ // to avoids completing the handshake yet.
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][:1],
+ })
+ tc.writeFrames(packetTypeInitial,
+ debugFrameConnectionCloseTransport{code: errInternal})
+ tc.wantIdle("server has discarded Initial keys, cannot read CONNECTION_CLOSE")
+
+ // The server discards Handshake keys after sending a HANDSHAKE_DONE frame.
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ off: 1,
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][1:],
+ })
+ tc.wantFrame("server sends HANDSHAKE_DONE after handshake completes",
+ packetType1RTT, debugFrameHandshakeDone{})
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameConnectionCloseTransport{code: errInternal})
+ tc.wantIdle("server has discarded Handshake keys, cannot read CONNECTION_CLOSE")
+
+ tc.writeFrames(packetType1RTT,
+ debugFrameConnectionCloseTransport{code: errInternal})
+ tc.wantFrame("server closes connection after 1-RTT CONNECTION_CLOSE",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errNo,
+ })
+}
+
+func TestConnInvalidCryptoData(t *testing.T) {
+ tc := newTestConn(t, clientSide)
+ tc.ignoreFrame(frameTypeAck)
+
+ tc.wantFrame("client sends Initial CRYPTO frame",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+
+ // Render the server's response invalid.
+ //
+ // The client closes the connection with CRYPTO_ERROR.
+ //
+ // Changing the first byte will change the TLS message type,
+ // so we can reasonably assume that this is an unexpected_message alert (10).
+ tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][0] ^= 0x1
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.wantFrame("client closes connection due to TLS handshake error",
+ packetTypeInitial, debugFrameConnectionCloseTransport{
+ code: errTLSBase + 10,
+ })
+}
+
+func TestConnInvalidPeerCertificate(t *testing.T) {
+ tc := newTestConn(t, clientSide, func(c *tls.Config) {
+ c.VerifyPeerCertificate = func([][]byte, [][]*x509.Certificate) error {
+ return errors.New("I will not buy this certificate. It is scratched.")
+ }
+ })
+ tc.ignoreFrame(frameTypeAck)
+
+ tc.wantFrame("client sends Initial CRYPTO frame",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.wantFrame("client closes connection due to rejecting server certificate",
+ packetTypeInitial, debugFrameConnectionCloseTransport{
+ code: errTLSBase + 42, // 42: bad_certificate
+ })
+}
+
+func TestConnHandshakeDoneSentToServer(t *testing.T) {
+ tc := newTestConn(t, serverSide)
+ tc.handshake()
+
+ tc.writeFrames(packetType1RTT,
+ debugFrameHandshakeDone{})
+ tc.wantFrame("server closes connection when client sends a HANDSHAKE_DONE frame",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errProtocolViolation,
+ })
+}
+
+func TestConnCryptoDataOutOfOrder(t *testing.T) {
+ tc := newTestConn(t, clientSide)
+ tc.ignoreFrame(frameTypeAck)
+
+ tc.wantFrame("client sends Initial CRYPTO frame",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantIdle("client is idle, server Handshake flight has not arrived")
+
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ off: 15,
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][15:],
+ })
+ tc.wantIdle("client is idle, server Handshake flight is not complete")
+
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ off: 1,
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][1:20],
+ })
+ tc.wantIdle("client is idle, server Handshake flight is still not complete")
+
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][0:1],
+ })
+ tc.wantFrame("client sends Handshake CRYPTO frame",
+ packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake],
+ })
+}
+
+func TestConnCryptoBufferSizeExceeded(t *testing.T) {
+ tc := newTestConn(t, clientSide)
+ tc.ignoreFrame(frameTypeAck)
+
+ tc.wantFrame("client sends Initial CRYPTO frame",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ off: cryptoBufferSize,
+ data: []byte{0},
+ })
+ tc.wantFrame("client closes connection after server exceeds CRYPTO buffer",
+ packetTypeInitial, debugFrameConnectionCloseTransport{
+ code: errCryptoBufferExceeded,
+ })
+}
diff --git a/internal/quic/tlsconfig_test.go b/internal/quic/tlsconfig_test.go
new file mode 100644
index 0000000..47bfb05
--- /dev/null
+++ b/internal/quic/tlsconfig_test.go
@@ -0,0 +1,62 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/tls"
+ "strings"
+)
+
+func newTestTLSConfig(side connSide) *tls.Config {
+ config := &tls.Config{
+ InsecureSkipVerify: true,
+ CipherSuites: []uint16{
+ tls.TLS_AES_128_GCM_SHA256,
+ tls.TLS_AES_256_GCM_SHA384,
+ tls.TLS_CHACHA20_POLY1305_SHA256,
+ },
+ MinVersion: tls.VersionTLS13,
+ }
+ if side == serverSide {
+ config.Certificates = []tls.Certificate{testCert}
+ }
+ return config
+}
+
+var testCert = func() tls.Certificate {
+ cert, err := tls.X509KeyPair(localhostCert, localhostKey)
+ if err != nil {
+ panic(err)
+ }
+ return cert
+}()
+
+// localhostCert is a PEM-encoded TLS cert with SAN IPs
+// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
+// generated from src/crypto/tls:
+// go run generate_cert.go --ecdsa-curve P256 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
+MIIBrDCCAVKgAwIBAgIPCvPhO+Hfv+NW76kWxULUMAoGCCqGSM49BAMCMBIxEDAO
+BgNVBAoTB0FjbWUgQ28wIBcNNzAwMTAxMDAwMDAwWhgPMjA4NDAxMjkxNjAwMDBa
+MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARh
+WRF8p8X9scgW7JjqAwI9nYV8jtkdhqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGms
+PyfMPe5Jrha/LmjgR1G9o4GIMIGFMA4GA1UdDwEB/wQEAwIChDATBgNVHSUEDDAK
+BggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBSOJri/wLQxq6oC
+Y6ZImms/STbTljAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAA
+AAAAAAAAAAAAATAKBggqhkjOPQQDAgNIADBFAiBUguxsW6TGhixBAdORmVNnkx40
+HjkKwncMSDbUaeL9jQIhAJwQ8zV9JpQvYpsiDuMmqCuW35XXil3cQ6Drz82c+fvE
+-----END CERTIFICATE-----`)
+
+// localhostKey is the private key for localhostCert.
+var localhostKey = []byte(testingKey(`-----BEGIN TESTING KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgY1B1eL/Bbwf/MDcs
+rnvvWhFNr1aGmJJR59PdCN9lVVqhRANCAARhWRF8p8X9scgW7JjqAwI9nYV8jtkd
+hqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGmsPyfMPe5Jrha/LmjgR1G9
+-----END TESTING KEY-----`))
+
+// testingKey helps keep security scanners from getting excited about a private key in this file.
+func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
diff --git a/internal/quic/transport_params.go b/internal/quic/transport_params.go
index 416bfb8..89ea69f 100644
--- a/internal/quic/transport_params.go
+++ b/internal/quic/transport_params.go
@@ -25,7 +25,7 @@
initialMaxStreamDataUni int64
initialMaxStreamsBidi int64
initialMaxStreamsUni int64
- ackDelayExponent uint8
+ ackDelayExponent int8
maxAckDelay time.Duration
disableActiveMigration bool
preferredAddrV4 netip.AddrPort
@@ -220,7 +220,7 @@
if v > 20 {
return p, localTransportError(errTransportParameter)
}
- p.ackDelayExponent = uint8(v)
+ p.ackDelayExponent = int8(v)
case paramMaxAckDelay:
var v uint64
v, n = consumeVarint(val)