blob: f4f1818a6485f7105fbd56f44d395ea26c7a920b [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"log/slog"
"math"
"net/netip"
"reflect"
"strings"
"testing"
"time"
"golang.org/x/net/quic/qlog"
)
var (
testVV = flag.Bool("vv", false, "even more verbose test output")
qlogdir = flag.String("qlog", "", "write qlog logs to directory")
)
// TestConnTestConn exercises the test harness itself: it verifies that a new
// conn arms the idle timer, that runOnLoop callbacks execute on the conn's
// event loop at the endpoint's synthetic time, and that advancing the fake
// clock past the idle timeout moves the conn to the done state.
func TestConnTestConn(t *testing.T) {
	tc := newTestConn(t, serverSide)
	tc.handshake()
	// A freshly handshaked conn should be waiting on the idle timeout.
	if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want {
		t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
	}
	// Functions scheduled with runOnLoop observe the synthetic clock.
	ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
		tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
			when = now
		})
		return
	}).result()
	if !ranAt.Equal(tc.endpoint.now) {
		t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
	}
	tc.wait()
	// After advancing the clock (but not past the idle timeout),
	// the loop should report the new time.
	nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
	tc.advanceTo(nextTime)
	ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
		tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
			when = now
		})
		return
	}).result()
	if !ranAt.Equal(nextTime) {
		t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
	}
	tc.wait()
	// Advancing to the conn's next timer (the idle timeout) should
	// terminate the connection.
	tc.advanceToTimer()
	if got := tc.conn.lifetime.state; got != connStateDone {
		t.Errorf("after advancing to idle timeout, conn state = %v, want done", got)
	}
}
// A testDatagram is the parsed form of a UDP datagram exchanged with a
// conn under test. A datagram may contain several coalesced packets.
type testDatagram struct {
	packets    []*testPacket
	paddedSize int            // if non-zero, the datagram is padded to this many bytes
	addr       netip.AddrPort // address the datagram is sent to or from
}
// String renders the datagram as a multi-line summary:
// a header line describing the packet count and padding,
// followed by one entry per contained packet.
func (d testDatagram) String() string {
	var sb strings.Builder
	fmt.Fprintf(&sb, "datagram with %v packets", len(d.packets))
	if d.paddedSize > 0 {
		fmt.Fprintf(&sb, " (padded to %v bytes)", d.paddedSize)
	}
	sb.WriteString(":")
	for _, pkt := range d.packets {
		fmt.Fprintf(&sb, "\n%v", pkt)
	}
	return sb.String()
}
// A testPacket is the parsed form of a single QUIC packet,
// as constructed or decoded by the test harness.
type testPacket struct {
	ptype             packetType
	header            byte   // first byte of the packet on the wire
	version           uint32 // long header packets only
	num               packetNumber
	keyPhaseBit       bool // 1-RTT packets: key phase bit in the header
	keyNumber         int  // 1-RTT packets: which generation of keys to use
	dstConnID         []byte
	srcConnID         []byte
	token             []byte
	originalDstConnID []byte // used for encoding Retry packets
	frames            []debugFrame
}
// String renders the packet as a short human-readable description:
// type and number, then any non-zero identifying fields, then one
// line per frame.
func (p testPacket) String() string {
	sb := new(strings.Builder)
	fmt.Fprintf(sb, " %v %v", p.ptype, p.num)
	if p.version != 0 {
		fmt.Fprintf(sb, " version=%v", p.version)
	}
	if p.srcConnID != nil {
		fmt.Fprintf(sb, " src={%x}", p.srcConnID)
	}
	if p.dstConnID != nil {
		fmt.Fprintf(sb, " dst={%x}", p.dstConnID)
	}
	if p.token != nil {
		fmt.Fprintf(sb, " token={%x}", p.token)
	}
	for _, frame := range p.frames {
		fmt.Fprintf(sb, "\n  %v", frame)
	}
	return sb.String()
}
// maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test.
// test1RTTKeys precomputes a packet key for each phase.
const maxTestKeyPhases = 3
// A testConn is a Conn whose external interactions (sending and receiving packets,
// setting timers) can be manipulated in tests.
type testConn struct {
	t              *testing.T
	conn           *Conn
	endpoint       *testEndpoint
	timer          time.Time     // the conn's currently scheduled timer, as last reported to nextMessage
	timerLastFired time.Time     // used to detect a timer that fires without making progress
	idlec          chan struct{} // only accessed on the conn's loop

	// Keys are distinct from the conn's keys,
	// because the test may know about keys before the conn does.
	// For example, when sending a datagram with coalesced
	// Initial and Handshake packets to a client conn,
	// we use Handshake keys to encrypt the packet.
	// The client only acquires those keys when it processes
	// the Initial packet.
	keysInitial   fixedKeyPair
	keysHandshake fixedKeyPair
	rkeyAppData   test1RTTKeys
	wkeyAppData   test1RTTKeys
	rsecrets      [numberSpaceCount]keySecret
	wsecrets      [numberSpaceCount]keySecret

	// testConn uses a test hook to snoop on the conn's TLS events.
	// CRYPTO data produced by the conn's QUICConn is placed in
	// cryptoDataOut.
	//
	// The peerTLSConn is a QUICConn representing the peer.
	// CRYPTO data produced by the conn is written to peerTLSConn,
	// and data produced by peerTLSConn is placed in cryptoDataIn.
	cryptoDataOut map[tls.QUICEncryptionLevel][]byte
	cryptoDataIn  map[tls.QUICEncryptionLevel][]byte
	peerTLSConn   *tls.QUICConn

	// Information about the conn's (fake) peer.
	peerConnID        []byte                         // source conn id of peer's packets
	peerNextPacketNum [numberSpaceCount]packetNumber // next packet number to use

	// Datagrams, packets, and frames sent by the conn,
	// but not yet processed by the test.
	sentDatagrams [][]byte
	sentPackets   []*testPacket
	sentFrames    []debugFrame
	lastDatagram  *testDatagram
	lastPacket    *testPacket

	recvDatagram chan *datagram

	// Transport parameters sent by the conn.
	sentTransportParameters *transportParameters

	// Frame types to ignore in tests.
	ignoreFrames map[byte]bool

	// Values to set in packets sent to the conn.
	sendKeyNumber   int
	sendKeyPhaseBit bool

	asyncTestState
}
// test1RTTKeys holds a 1-RTT header key plus one packet key per key phase,
// letting the test encrypt or decrypt packets at any phase.
type test1RTTKeys struct {
	hdr headerKey
	pkt [maxTestKeyPhases]packetKey
}
// keySecret records a TLS traffic secret and its cipher suite,
// used to verify both sides of the connection derive matching keys.
type keySecret struct {
	suite  uint16
	secret []byte
}
// newTestConn creates a Conn for testing.
//
// The Conn's event loop is controlled by the test,
// allowing test code to access Conn state directly
// by first ensuring the loop goroutine is idle.
//
// Each opt is dispatched on its type: functions over *Config, *tls.Config,
// *newServerConnIDs, *transportParameters, or *testConn customize the
// corresponding piece of setup. Any other option type is a test failure.
func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
	t.Helper()
	config := &Config{
		TLSConfig:         newTestTLSConfig(side),
		StatelessResetKey: testStatelessResetKey,
		QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
			Level: QLogLevelFrame,
			Dir:   *qlogdir,
		})),
	}
	var cids newServerConnIDs
	if side == serverSide {
		// The initial connection ID for the server is chosen by the client.
		cids.srcConnID = testPeerConnID(0)
		cids.dstConnID = testPeerConnID(-1)
		cids.originalDstConnID = cids.dstConnID
	}
	var configTransportParams []func(*transportParameters)
	var configTestConn []func(*testConn)
	for _, o := range opts {
		switch o := o.(type) {
		case func(*Config):
			o(config)
		case func(*tls.Config):
			o(config.TLSConfig)
		case func(cids *newServerConnIDs):
			o(&cids)
		case func(p *transportParameters):
			// Deferred: applied to the peer's transport parameters
			// in newTestConnForConn.
			configTransportParams = append(configTransportParams, o)
		case func(p *testConn):
			// Deferred: applied to the testConn in newTestConnForConn.
			configTestConn = append(configTestConn, o)
		default:
			t.Fatalf("unknown newTestConn option %T", o)
		}
	}
	endpoint := newTestEndpoint(t, config)
	endpoint.configTransportParams = configTransportParams
	endpoint.configTestConn = configTestConn
	conn, err := endpoint.e.newConn(
		endpoint.now,
		config,
		side,
		cids,
		"",
		netip.MustParseAddrPort("127.0.0.1:443"))
	if err != nil {
		t.Fatal(err)
	}
	// The endpoint's conn-creation hook built a testConn for us.
	tc := endpoint.conns[conn]
	tc.wait()
	return tc
}
// newTestConnForConn wraps an already-created Conn in a testConn,
// installing the test hooks and creating a tls.QUICConn to act as the
// conn's peer for the TLS handshake.
func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn {
	t.Helper()
	tc := &testConn{
		t:          t,
		endpoint:   endpoint,
		conn:       conn,
		peerConnID: testPeerConnID(0),
		ignoreFrames: map[byte]bool{
			frameTypePadding: true, // ignore PADDING by default
		},
		cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
		cryptoDataIn:  make(map[tls.QUICEncryptionLevel][]byte),
		recvDatagram:  make(chan *datagram),
	}
	t.Cleanup(tc.cleanup)
	for _, f := range endpoint.configTestConn {
		f(tc)
	}
	conn.testHooks = (*testConnHooks)(tc)

	// If the endpoint has a pre-built peer TLS conn waiting, adopt it
	// instead of creating a fresh one.
	if endpoint.peerTLSConn != nil {
		tc.peerTLSConn = endpoint.peerTLSConn
		endpoint.peerTLSConn = nil
		return tc
	}

	peerProvidedParams := defaultTransportParameters()
	peerProvidedParams.initialSrcConnID = testPeerConnID(0)
	if conn.side == clientSide {
		peerProvidedParams.originalDstConnID = testLocalConnID(-1)
	}
	for _, f := range endpoint.configTransportParams {
		f(&peerProvidedParams)
	}

	// The peer plays the opposite role in the handshake.
	peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(conn.side.peer())}
	if conn.side == clientSide {
		tc.peerTLSConn = tls.QUICServer(peerQUICConfig)
	} else {
		tc.peerTLSConn = tls.QUICClient(peerQUICConfig)
	}
	tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
	// NOTE(review): Start's error is ignored here; a failure would surface
	// later as a handshake error — consider checking it.
	tc.peerTLSConn.Start(context.Background())
	t.Cleanup(func() {
		tc.peerTLSConn.Close()
	})
	return tc
}
// advance causes time to pass.
// It advances the endpoint's synthetic clock, which drives all conns on it.
func (tc *testConn) advance(d time.Duration) {
	tc.t.Helper()
	tc.endpoint.advance(d)
}
// advanceTo sets the current time.
// The new time is applied via the endpoint's synthetic clock.
func (tc *testConn) advanceTo(now time.Time) {
	tc.t.Helper()
	tc.endpoint.advanceTo(now)
}
// advanceToTimer sets the current time to the time of the Conn's next timer event.
// It is a fatal test error to call this when no timer is scheduled.
func (tc *testConn) advanceToTimer() {
	if tc.timer.IsZero() {
		tc.t.Fatalf("advancing to timer, but timer is not set")
	}
	tc.advanceTo(tc.timer)
}
// timerDelay returns the amount of time until the Conn's next timer event:
// infiniteDuration if no timer is set, 0 if the timer is already due, and
// the remaining duration otherwise.
//
// It was previously a duplicate of timeUntilEvent (differing only in
// spelling math.MaxInt64 directly instead of infiniteDuration); it now
// delegates to keep the two in sync.
func (tc *testConn) timerDelay() time.Duration {
	return tc.timeUntilEvent()
}
// infiniteDuration represents an unbounded wait (no timer scheduled).
const infiniteDuration = time.Duration(math.MaxInt64)
// timeUntilEvent returns the amount of time until the next connection event.
// It reports infiniteDuration when no timer is set, and zero when the
// timer is already due.
func (tc *testConn) timeUntilEvent() time.Duration {
	now := tc.endpoint.now
	switch {
	case tc.timer.IsZero():
		return infiniteDuration
	case !tc.timer.After(now):
		return 0
	}
	return tc.timer.Sub(now)
}
// wait blocks until the conn becomes idle.
// The conn is idle when it is blocked waiting for a packet to arrive or a timer to expire.
// Tests shouldn't need to call wait directly.
// testConn methods that wake the Conn event loop will call wait for them.
func (tc *testConn) wait() {
	tc.t.Helper()
	idlec := make(chan struct{})
	fail := false
	// Hand the idle channel to the conn's loop; nextMessage closes it
	// once the loop has drained its message queue.
	tc.conn.sendMsg(func(now time.Time, c *Conn) {
		if tc.idlec != nil {
			// Another wait is already in flight; that's a test bug.
			tc.t.Errorf("testConn.wait called concurrently")
			fail = true
			close(idlec)
		} else {
			// nextMessage will close idlec.
			tc.idlec = idlec
		}
	})
	select {
	case <-idlec:
	case <-tc.conn.donec:
		// We may have async ops that can proceed now that the conn is done.
		tc.wakeAsync()
	}
	if fail {
		// Stop the test immediately rather than continuing with a
		// corrupted wait protocol.
		panic(fail)
	}
}
// cleanup shuts down the conn's event loop and waits for it to exit.
// Registered via t.Cleanup in newTestConnForConn.
func (tc *testConn) cleanup() {
	if tc.conn == nil {
		return
	}
	tc.conn.exit()
	<-tc.conn.donec
}
// acceptStream accepts a stream from the conn, failing the test if none is
// immediately available. The returned stream's read/write contexts are
// pre-canceled so subsequent operations on it are non-blocking.
func (tc *testConn) acceptStream() *Stream {
	tc.t.Helper()
	s, err := tc.conn.AcceptStream(canceledContext())
	if err != nil {
		tc.t.Fatalf("conn.AcceptStream() = %v, want stream", err)
	}
	s.SetReadContext(canceledContext())
	s.SetWriteContext(canceledContext())
	return s
}
// logDatagram logs a description of d to the test log, prefixed with text
// to indicate its direction. It is a no-op unless the -vv flag is set.
func logDatagram(t *testing.T, text string, d *testDatagram) {
	t.Helper()
	if !*testVV {
		return
	}
	pad := ""
	if d.paddedSize > 0 {
		pad = fmt.Sprintf(" (padded to %v)", d.paddedSize)
	}
	t.Logf("%v datagram%v", text, pad)
	for _, p := range d.packets {
		var s string
		switch p.ptype {
		case packetType1RTT:
			// 1-RTT packets have no version or conn IDs in the header.
			s = fmt.Sprintf("  %v pnum=%v", p.ptype, p.num)
		default:
			s = fmt.Sprintf("  %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID)
		}
		if p.token != nil {
			s += fmt.Sprintf(" token={%x}", p.token)
		}
		if p.keyPhaseBit {
			// Fixed: was fmt.Sprintf(" KeyPhase") — a Sprintf call with no
			// formatting verbs (staticcheck S1039).
			s += " KeyPhase"
		}
		if p.keyNumber != 0 {
			s += fmt.Sprintf(" keynum=%v", p.keyNumber)
		}
		t.Log(s)
		for _, f := range p.frames {
			t.Logf("    %v", f)
		}
	}
}
// write sends the Conn a datagram.
// The datagram is encoded and delivered through the test endpoint.
func (tc *testConn) write(d *testDatagram) {
	tc.t.Helper()
	tc.endpoint.writeDatagram(d)
}
// writeFrames sends the Conn a datagram containing the given frames,
// in a single packet of the given type. Packet numbers, key phase, and
// connection IDs come from the testConn's tracked peer state.
func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
	tc.t.Helper()
	space := spaceForPacketType(ptype)
	dstConnID := tc.conn.connIDState.local[0].cid
	if tc.conn.connIDState.local[0].seq == -1 && ptype != packetTypeInitial {
		// Only use the transient connection ID in Initial packets.
		dstConnID = tc.conn.connIDState.local[1].cid
	}
	d := &testDatagram{
		packets: []*testPacket{{
			ptype:       ptype,
			num:         tc.peerNextPacketNum[space],
			keyNumber:   tc.sendKeyNumber,
			keyPhaseBit: tc.sendKeyPhaseBit,
			frames:      frames,
			version:     quicVersion1,
			dstConnID:   dstConnID,
			srcConnID:   tc.peerConnID,
		}},
		addr: tc.conn.peerAddr,
	}
	if ptype == packetTypeInitial && tc.conn.side == serverSide {
		// Clients must pad Initial datagrams to at least 1200 bytes.
		d.paddedSize = 1200
	}
	tc.write(d)
}
// writeAckForAll sends the Conn a datagram containing an ack for all packets up to the
// last one received. It is a no-op if no packet has been received yet.
func (tc *testConn) writeAckForAll() {
	tc.t.Helper()
	if tc.lastPacket == nil {
		return
	}
	tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
		ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}},
	})
}
// writeAckForLatest sends the Conn a datagram containing an ack for the
// most recent packet received (a single-packet range).
// It is a no-op if no packet has been received yet.
func (tc *testConn) writeAckForLatest() {
	tc.t.Helper()
	if tc.lastPacket == nil {
		return
	}
	tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
		ranges: []i64range[packetNumber]{{tc.lastPacket.num, tc.lastPacket.num + 1}},
	})
}
// ignoreFrame hides frames of the given type sent by the Conn.
// Ignored frames are filtered out of datagrams read by readDatagram.
func (tc *testConn) ignoreFrame(frameType byte) {
	tc.ignoreFrames[frameType] = true
}
// readDatagram reads the next datagram sent by the Conn.
// It returns nil if the Conn has no more datagrams to send at this time.
// Frames whose types have been registered with ignoreFrame are stripped
// from the returned datagram (after logging).
func (tc *testConn) readDatagram() *testDatagram {
	tc.t.Helper()
	// Let the conn's loop run to quiescence so that any pending
	// sends have been flushed to the endpoint.
	tc.wait()
	tc.sentPackets = nil
	tc.sentFrames = nil
	buf := tc.endpoint.read()
	if buf == nil {
		return nil
	}
	d := parseTestDatagram(tc.t, tc.endpoint, tc, buf)
	// Log the datagram before removing ignored frames.
	// When things go wrong, it's useful to see all the frames.
	logDatagram(tc.t, "-> conn under test sends", d)
	// typeForFrame maps a parsed debugFrame back to its wire frame type,
	// so we can match it against tc.ignoreFrames.
	typeForFrame := func(f debugFrame) byte {
		// This is very clunky, and points at a problem
		// in how we specify what frames to ignore in tests.
		//
		// We mark frames to ignore using the frame type,
		// but we've got a debugFrame data structure here.
		// Perhaps we should be ignoring frames by debugFrame
		// type instead: tc.ignoreFrame[debugFrameAck]().
		switch f := f.(type) {
		case debugFramePadding:
			return frameTypePadding
		case debugFramePing:
			return frameTypePing
		case debugFrameAck:
			return frameTypeAck
		case debugFrameResetStream:
			return frameTypeResetStream
		case debugFrameStopSending:
			return frameTypeStopSending
		case debugFrameCrypto:
			return frameTypeCrypto
		case debugFrameNewToken:
			return frameTypeNewToken
		case debugFrameStream:
			return frameTypeStreamBase
		case debugFrameMaxData:
			return frameTypeMaxData
		case debugFrameMaxStreamData:
			return frameTypeMaxStreamData
		case debugFrameMaxStreams:
			// MAX_STREAMS uses separate frame types per stream direction.
			if f.streamType == bidiStream {
				return frameTypeMaxStreamsBidi
			} else {
				return frameTypeMaxStreamsUni
			}
		case debugFrameDataBlocked:
			return frameTypeDataBlocked
		case debugFrameStreamDataBlocked:
			return frameTypeStreamDataBlocked
		case debugFrameStreamsBlocked:
			// STREAMS_BLOCKED also uses separate types per direction.
			if f.streamType == bidiStream {
				return frameTypeStreamsBlockedBidi
			} else {
				return frameTypeStreamsBlockedUni
			}
		case debugFrameNewConnectionID:
			return frameTypeNewConnectionID
		case debugFrameRetireConnectionID:
			return frameTypeRetireConnectionID
		case debugFramePathChallenge:
			return frameTypePathChallenge
		case debugFramePathResponse:
			return frameTypePathResponse
		case debugFrameConnectionCloseTransport:
			return frameTypeConnectionCloseTransport
		case debugFrameConnectionCloseApplication:
			return frameTypeConnectionCloseApplication
		case debugFrameHandshakeDone:
			return frameTypeHandshakeDone
		}
		panic(fmt.Errorf("unhandled frame type %T", f))
	}
	// Filter out ignored frames in place.
	for _, p := range d.packets {
		var frames []debugFrame
		for _, f := range p.frames {
			if !tc.ignoreFrames[typeForFrame(f)] {
				frames = append(frames, f)
			}
		}
		p.frames = frames
	}
	tc.lastDatagram = d
	return d
}
// readPacket reads the next packet sent by the Conn.
// It returns nil if the Conn has no more packets to send at this time.
// Packets left with no frames after ignore-filtering are skipped
// (but still recorded as the last packet seen).
func (tc *testConn) readPacket() *testPacket {
	tc.t.Helper()
	for len(tc.sentPackets) == 0 {
		d := tc.readDatagram()
		if d == nil {
			return nil
		}
		for _, p := range d.packets {
			if len(p.frames) == 0 {
				// All of this packet's frames were ignored; note it as
				// the latest packet but don't queue it for the test.
				tc.lastPacket = p
				continue
			}
			tc.sentPackets = append(tc.sentPackets, p)
		}
	}
	p := tc.sentPackets[0]
	tc.sentPackets = tc.sentPackets[1:]
	tc.lastPacket = p
	return p
}
// readFrame reads the next frame sent by the Conn,
// along with the type of the packet that carried it.
// It returns (nil, packetTypeInvalid) if the Conn has no more frames to send at this time.
func (tc *testConn) readFrame() (debugFrame, packetType) {
	tc.t.Helper()
	for len(tc.sentFrames) == 0 {
		p := tc.readPacket()
		if p == nil {
			return nil, packetTypeInvalid
		}
		tc.sentFrames = p.frames
	}
	f := tc.sentFrames[0]
	tc.sentFrames = tc.sentFrames[1:]
	return f, tc.lastPacket.ptype
}
// wantDatagram indicates that we expect the Conn to send a datagram.
// The expectation string is included in the failure message.
func (tc *testConn) wantDatagram(expectation string, want *testDatagram) {
	tc.t.Helper()
	got := tc.readDatagram()
	if !datagramEqual(got, want) {
		tc.t.Fatalf("%v:\ngot datagram:  %v\nwant datagram: %v", expectation, got, want)
	}
}
// datagramEqual reports whether two datagrams are equivalent:
// same padding, same address, and pairwise-equal packets.
// Two nil datagrams are equal; a nil and non-nil datagram are not.
func datagramEqual(a, b *testDatagram) bool {
	if a == nil || b == nil {
		return a == b
	}
	if a.paddedSize != b.paddedSize ||
		a.addr != b.addr ||
		len(a.packets) != len(b.packets) {
		return false
	}
	for i, pkt := range a.packets {
		if !packetEqual(pkt, b.packets[i]) {
			return false
		}
	}
	return true
}
// wantPacket indicates that we expect the Conn to send a packet.
// The expectation string is included in the failure message.
func (tc *testConn) wantPacket(expectation string, want *testPacket) {
	tc.t.Helper()
	got := tc.readPacket()
	if !packetEqual(got, want) {
		tc.t.Fatalf("%v:\ngot packet:  %v\nwant packet: %v", expectation, got, want)
	}
}
// packetEqual reports whether two packets are equivalent.
// The raw header byte is ignored, and frames are compared with
// frameEqual rather than strict equality.
// Two nil packets are equal; a nil and non-nil packet are not.
func packetEqual(a, b *testPacket) bool {
	if a == nil || b == nil {
		return a == b
	}
	// Compare all scalar fields via copies with the header byte and
	// frame list zeroed out.
	ac, bc := *a, *b
	ac.frames, bc.frames = nil, nil
	ac.header, bc.header = 0, 0
	if !reflect.DeepEqual(ac, bc) {
		return false
	}
	if len(a.frames) != len(b.frames) {
		return false
	}
	for i, f := range a.frames {
		if !frameEqual(f, b.frames[i]) {
			return false
		}
	}
	return true
}
// wantFrame indicates that we expect the Conn to send a frame
// in a packet of the given type, with the given contents.
func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) {
	tc.t.Helper()
	got, gotType := tc.readFrame()
	if got == nil {
		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
	}
	if gotType != wantType {
		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame:  %v", expectation, gotType, wantType, got)
	}
	if !frameEqual(got, want) {
		tc.t.Fatalf("%v:\ngot frame:  %v\nwant frame: %v", expectation, got, want)
	}
}
// frameEqual reports whether two frames are equivalent.
// CONNECTION_CLOSE (transport) frames compare by error code only;
// all other frames use deep equality.
func frameEqual(a, b debugFrame) bool {
	if af, ok := a.(debugFrameConnectionCloseTransport); ok {
		bf, ok := b.(debugFrameConnectionCloseTransport)
		return ok && af.code == bf.code
	}
	return reflect.DeepEqual(a, b)
}
// wantFrameType indicates that we expect the Conn to send a frame,
// although we don't care about the contents.
// Only the packet type and the frame's dynamic type are checked.
func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) {
	tc.t.Helper()
	got, gotType := tc.readFrame()
	if got == nil {
		tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want)
	}
	if gotType != wantType {
		tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame:  %v", expectation, gotType, wantType, got)
	}
	if reflect.TypeOf(got) != reflect.TypeOf(want) {
		tc.t.Fatalf("%v:\ngot frame:  %v\nwant frame of type: %v", expectation, got, want)
	}
}
// wantIdle indicates that we expect the Conn to not send any more frames.
// It fails the test if any queued or newly produced frames remain.
func (tc *testConn) wantIdle(expectation string) {
	tc.t.Helper()
	switch {
	case len(tc.sentFrames) > 0:
		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentFrames[0])
	case len(tc.sentPackets) > 0:
		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentPackets[0])
	}
	// Nothing queued; check the conn doesn't produce anything new either.
	if f, _ := tc.readFrame(); f != nil {
		tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, f)
	}
}
// encodeTestPacket encodes and encrypts p into wire format,
// padding the payload to at least pad bytes.
//
// tc supplies the packet protection keys. It may be nil only when encoding
// an Initial packet (whose keys can be derived from the destination conn ID)
// or a Retry packet (which carries no protected payload).
func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte {
	t.Helper()
	var w packetWriter
	w.reset(1200)
	var pnumMaxAcked packetNumber
	switch p.ptype {
	case packetTypeRetry:
		// Retry packets carry no frames; encode and return immediately.
		return encodeRetryPacket(p.originalDstConnID, retryPacket{
			srcConnID: p.srcConnID,
			dstConnID: p.dstConnID,
			token:     p.token,
		})
	case packetType1RTT:
		w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID)
	default:
		w.startProtectedLongHeaderPacket(pnumMaxAcked, longPacket{
			ptype:     p.ptype,
			version:   p.version,
			num:       p.num,
			dstConnID: p.dstConnID,
			srcConnID: p.srcConnID,
			extra:     p.token,
		})
	}
	for _, f := range p.frames {
		f.write(&w)
	}
	w.appendPaddingTo(pad)
	if p.ptype != packetType1RTT {
		// Long header packet: pick the write key for the packet's space.
		var k fixedKeys
		if tc == nil {
			if p.ptype == packetTypeInitial {
				// Initial keys derive from the destination conn ID,
				// so no conn is required.
				k = initialKeys(p.dstConnID, serverSide).r
			} else {
				t.Fatalf("sending %v packet with no conn", p.ptype)
			}
		} else {
			switch p.ptype {
			case packetTypeInitial:
				k = tc.keysInitial.w
			case packetTypeHandshake:
				k = tc.keysHandshake.w
			}
		}
		if !k.isSet() {
			t.Fatalf("sending %v packet with no write key", p.ptype)
		}
		w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{
			ptype:     p.ptype,
			version:   p.version,
			num:       p.num,
			dstConnID: p.dstConnID,
			srcConnID: p.srcConnID,
			extra:     p.token,
		})
	} else {
		if tc == nil || !tc.wkeyAppData.hdr.isSet() {
			t.Fatalf("sending 1-RTT packet with no write key")
		}
		// Somewhat hackish: Generate a temporary updatingKeyPair that will
		// always use our desired key phase.
		k := &updatingKeyPair{
			w: updatingKeys{
				hdr: tc.wkeyAppData.hdr,
				pkt: [2]packetKey{
					tc.wkeyAppData.pkt[p.keyNumber],
					tc.wkeyAppData.pkt[p.keyNumber],
				},
			},
			updateAfter: maxPacketNumber,
		}
		if p.keyPhaseBit {
			k.phase |= keyPhaseBit
		}
		w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k)
	}
	return w.datagram()
}
// parseTestDatagram decrypts and parses a datagram produced by a conn
// under test, returning its packets and frames in testDatagram form.
//
// tc supplies the read keys. It may be nil only when the datagram
// contains just an Initial packet (whose keys derive from the packet's
// source conn ID) or a Retry packet.
func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram {
	t.Helper()
	bufSize := len(buf)
	d := &testDatagram{}
	size := len(buf)
	for len(buf) > 0 {
		if buf[0] == 0 {
			// A zero first byte here is trailing padding;
			// record the datagram's padded size and stop.
			d.paddedSize = bufSize
			break
		}
		ptype := getPacketType(buf)
		switch ptype {
		case packetTypeRetry:
			retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID)
			if !ok {
				t.Fatalf("could not parse %v packet", ptype)
			}
			// A Retry packet is always alone in its datagram.
			return &testDatagram{
				packets: []*testPacket{{
					ptype:     packetTypeRetry,
					dstConnID: retry.dstConnID,
					srcConnID: retry.srcConnID,
					token:     retry.token,
				}},
			}
		case packetTypeInitial, packetTypeHandshake:
			// Select the read key for the packet's number space.
			var k fixedKeys
			if tc == nil {
				if ptype == packetTypeInitial {
					// Initial keys derive from the packet's source conn ID.
					p, _ := parseGenericLongHeaderPacket(buf)
					k = initialKeys(p.srcConnID, serverSide).w
				} else {
					t.Fatalf("reading %v packet with no conn", ptype)
				}
			} else {
				switch ptype {
				case packetTypeInitial:
					k = tc.keysInitial.r
				case packetTypeHandshake:
					k = tc.keysHandshake.r
				}
			}
			if !k.isSet() {
				t.Fatalf("reading %v packet with no read key", ptype)
			}
			var pnumMax packetNumber // TODO: Track packet numbers.
			p, n := parseLongHeaderPacket(buf, k, pnumMax)
			if n < 0 {
				t.Fatalf("packet parse error")
			}
			frames, err := parseTestFrames(t, p.payload)
			if err != nil {
				t.Fatal(err)
			}
			// Only Initial packets carry a token.
			var token []byte
			if ptype == packetTypeInitial && len(p.extra) > 0 {
				token = p.extra
			}
			d.packets = append(d.packets, &testPacket{
				ptype:     p.ptype,
				header:    buf[0],
				version:   p.version,
				num:       p.num,
				dstConnID: p.dstConnID,
				srcConnID: p.srcConnID,
				token:     token,
				frames:    frames,
			})
			// Move past this packet; more may be coalesced after it.
			buf = buf[n:]
		case packetType1RTT:
			if tc == nil || !tc.rkeyAppData.hdr.isSet() {
				t.Fatalf("reading 1-RTT packet with no read key")
			}
			var pnumMax packetNumber // TODO: Track packet numbers.
			// Header byte plus the destination conn ID precede the
			// packet number.
			pnumOff := 1 + len(tc.peerConnID)
			// Try unprotecting the packet with the first maxTestKeyPhases keys.
			var phase int
			var pnum packetNumber
			var hdr []byte
			var pay []byte
			var err error
			for phase = 0; phase < maxTestKeyPhases; phase++ {
				// Unprotection mutates the buffer; work on a copy so a
				// failed attempt doesn't corrupt the next one.
				b := append([]byte{}, buf...)
				hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax)
				if err != nil {
					t.Fatalf("1-RTT packet header parse error")
				}
				k := tc.rkeyAppData.pkt[phase]
				pay, err = k.unprotect(hdr, pay, pnum)
				if err == nil {
					break
				}
			}
			if err != nil {
				t.Fatalf("1-RTT packet payload parse error")
			}
			frames, err := parseTestFrames(t, pay)
			if err != nil {
				t.Fatal(err)
			}
			d.packets = append(d.packets, &testPacket{
				ptype:       packetType1RTT,
				header:      hdr[0],
				num:         pnum,
				dstConnID:   hdr[1:][:len(tc.peerConnID)],
				keyPhaseBit: hdr[0]&keyPhaseBit != 0,
				keyNumber:   phase,
				frames:      frames,
			})
			// A 1-RTT packet always extends to the end of the datagram.
			buf = buf[len(buf):]
		default:
			t.Fatalf("unhandled packet type %v", ptype)
		}
	}
	// This is rather hackish: If the last frame in the last packet
	// in the datagram is PADDING, then remove it and record
	// the padded size in the testDatagram.paddedSize.
	//
	// This makes it easier to write a test that expects a datagram
	// padded to 1200 bytes.
	if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 {
		p := d.packets[len(d.packets)-1]
		f := p.frames[len(p.frames)-1]
		if _, ok := f.(debugFramePadding); ok {
			p.frames = p.frames[:len(p.frames)-1]
			d.paddedSize = size
		}
	}
	return d
}
// parseTestFrames parses a decrypted packet payload into its frames,
// returning an error if any frame fails to parse.
func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) {
	t.Helper()
	var frames []debugFrame
	for rest := payload; len(rest) > 0; {
		frame, n := parseDebugFrame(rest)
		if n < 0 {
			return nil, errors.New("error parsing frames")
		}
		frames = append(frames, frame)
		rest = rest[n:]
	}
	return frames, nil
}
// spaceForPacketType returns the packet number space that packets of the
// given type belong to. It panics on types with no number space (Retry),
// on unimplemented types (0-RTT), and on unknown types.
func spaceForPacketType(ptype packetType) numberSpace {
	switch ptype {
	case packetTypeInitial:
		return initialSpace
	case packetTypeHandshake:
		return handshakeSpace
	case packetType1RTT:
		return appDataSpace
	case packetType0RTT:
		panic("TODO: packetType0RTT")
	case packetTypeRetry:
		panic("retry packets have no number space")
	}
	panic("unknown packet type")
}
// testConnHooks implements connTestHooks.
// It is a distinct named type so the hook methods don't appear on
// testConn itself; convert back with (*testConn)(tc) where needed.
type testConnHooks testConn
// init is called when the Conn starts.
// It disables 1-RTT key updates, captures the Initial keys, and for
// server conns queues the testConn on the endpoint's accept queue.
func (tc *testConnHooks) init() {
	tc.conn.keysAppData.updateAfter = maxPacketNumber // disable key updates
	// The test acts as the conn's peer, so the test reads with the
	// conn's write key and writes with the conn's read key.
	tc.keysInitial.r = tc.conn.keysInitial.w
	tc.keysInitial.w = tc.conn.keysInitial.r
	if tc.conn.side == serverSide {
		tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc))
	}
}
// handleTLSEvent processes TLS events generated by
// the connection under test's tls.QUICConn.
//
// We maintain a second tls.QUICConn representing the peer,
// and feed the TLS handshake data into it.
//
// We stash TLS handshake data from both sides in the testConn,
// where it can be used by tests.
//
// We snoop packet protection keys out of the tls.QUICConns,
// and verify that both sides of the connection are getting
// matching keys.
func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) {
	// checkKey records the secret for a number space the first time it is
	// seen, and verifies subsequent reports of the same secret match.
	checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) {
		var space numberSpace
		switch {
		case e.Level == tls.QUICEncryptionLevelHandshake:
			space = handshakeSpace
		case e.Level == tls.QUICEncryptionLevelApplication:
			space = appDataSpace
		default:
			tc.t.Errorf("unexpected encryption level %v", e.Level)
			return
		}
		if secrets[space].secret == nil {
			secrets[space].suite = e.Suite
			secrets[space].secret = append([]byte{}, e.Data...)
		} else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) {
			// Fixed message: previously read "key mismatch for level for level".
			tc.t.Errorf("%v key mismatch for level %v", typ, e.Level)
		}
	}
	// setAppDataKey initializes the 1-RTT header key and one packet key
	// per test key phase, advancing the secret between phases.
	setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) {
		k.hdr.init(suite, secret)
		for i := 0; i < len(k.pkt); i++ {
			k.pkt[i].init(suite, secret)
			secret = updateSecret(suite, secret)
		}
	}
	// The conn's read secret is the test's write secret, and vice versa,
	// since the test plays the role of the peer.
	switch e.Kind {
	case tls.QUICSetReadSecret:
		checkKey("write", &tc.wsecrets, e)
		switch e.Level {
		case tls.QUICEncryptionLevelHandshake:
			tc.keysHandshake.w.init(e.Suite, e.Data)
		case tls.QUICEncryptionLevelApplication:
			setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
		}
	case tls.QUICSetWriteSecret:
		checkKey("read", &tc.rsecrets, e)
		switch e.Level {
		case tls.QUICEncryptionLevelHandshake:
			tc.keysHandshake.r.init(e.Suite, e.Data)
		case tls.QUICEncryptionLevelApplication:
			setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
		}
	case tls.QUICWriteData:
		tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...)
		tc.peerTLSConn.HandleData(e.Level, e.Data)
	}
	// Drain the peer's event queue; its secrets are oriented the
	// opposite way (its read secret is the test's read secret).
	for {
		e := tc.peerTLSConn.NextEvent()
		switch e.Kind {
		case tls.QUICNoEvent:
			return
		case tls.QUICSetReadSecret:
			checkKey("write", &tc.rsecrets, e)
			switch e.Level {
			case tls.QUICEncryptionLevelHandshake:
				tc.keysHandshake.r.init(e.Suite, e.Data)
			case tls.QUICEncryptionLevelApplication:
				setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData)
			}
		case tls.QUICSetWriteSecret:
			checkKey("read", &tc.wsecrets, e)
			switch e.Level {
			case tls.QUICEncryptionLevelHandshake:
				tc.keysHandshake.w.init(e.Suite, e.Data)
			case tls.QUICEncryptionLevelApplication:
				setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData)
			}
		case tls.QUICWriteData:
			tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...)
		case tls.QUICTransportParameters:
			p, err := unmarshalTransportParams(e.Data)
			if err != nil {
				tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err)
			} else {
				tc.sentTransportParameters = &p
			}
		}
	}
}
// nextMessage is called by the Conn's event loop to request its next event.
// It records the conn's scheduled timer, delivers a timerEvent when the
// timer is due, and otherwise blocks for the next queued message.
// When the loop has nothing to do, it closes the pending idle channel
// (if any) so testConn.wait can return.
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) {
	tc.timer = timer
	for {
		if !timer.IsZero() && !timer.After(tc.endpoint.now) {
			if timer.Equal(tc.timerLastFired) {
				// If the connection timer fires at time T, the Conn should take some
				// action to advance the timer into the future. If the Conn reschedules
				// the timer for the same time, it isn't making progress and we have a bug.
				tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer)
			} else {
				tc.timerLastFired = timer
				return tc.endpoint.now, timerEvent{}
			}
		}
		// Deliver a queued message if one is immediately available.
		select {
		case m := <-msgc:
			return tc.endpoint.now, m
		default:
		}
		// No message ready; give blocked async operations a chance to
		// make progress before declaring the conn idle.
		if !tc.wakeAsync() {
			break
		}
	}
	// If the message queue is empty, then the conn is idle.
	if tc.idlec != nil {
		idlec := tc.idlec
		tc.idlec = nil
		close(idlec)
	}
	m = <-msgc
	return tc.endpoint.now, m
}
// newConnID generates a deterministic connection ID for the conn under
// test, replacing the default (random) generation.
func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) {
	return testLocalConnID(seq), nil
}
// timeNow reports the endpoint's synthetic clock as the current time.
func (tc *testConnHooks) timeNow() time.Time {
	return tc.endpoint.now
}
// testLocalConnID returns the connection ID with a given sequence number
// used by a Conn under test: a 0xc0ffee prefix, zero padding, and the
// sequence number in the final byte.
func testLocalConnID(seq int64) []byte {
	prefix := []byte{0xc0, 0xff, 0xee}
	cid := make([]byte, connIDLen)
	copy(cid, prefix)
	cid[connIDLen-1] = byte(seq)
	return cid
}
// testPeerConnID returns the connection ID with a given sequence number
// used by the fake peer of a Conn under test.
func testPeerConnID(seq int64) []byte {
	// Use a different length than we choose for our own conn ids,
	// to help catch any bad assumptions.
	cid := []byte{0xbe, 0xee, 0xff}
	return append(cid, byte(seq))
}
// testPeerStatelessResetToken returns a deterministic stateless reset
// token for the fake peer: all 0xee bytes, with the sequence number in
// the final byte.
func testPeerStatelessResetToken(seq int64) statelessResetToken {
	return statelessResetToken{
		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee,
		0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq),
	}
}
// canceledContext returns a canceled Context.
//
// Functions which take a context preference progress over cancelation.
// For example, a read with a canceled context will return data if any is available.
// Tests use canceled contexts to perform non-blocking operations.
func canceledContext() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	return ctx
}