| // Copyright 2023 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| //go:build go1.21 |
| |
| package quic |
| |
| import ( |
| "bytes" |
| "context" |
| "crypto/tls" |
| "errors" |
| "flag" |
| "fmt" |
| "math" |
| "net/netip" |
| "reflect" |
| "strings" |
| "testing" |
| "time" |
| ) |
| |
| var testVV = flag.Bool("vv", false, "even more verbose test output") |
| |
| func TestConnTestConn(t *testing.T) { |
| tc := newTestConn(t, serverSide) |
| tc.handshake() |
| if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want { |
| t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want) |
| } |
| |
| var ranAt time.Time |
| tc.conn.runOnLoop(func(now time.Time, c *Conn) { |
| ranAt = now |
| }) |
| if !ranAt.Equal(tc.endpoint.now) { |
| t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now) |
| } |
| tc.wait() |
| |
| nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2) |
| tc.advanceTo(nextTime) |
| tc.conn.runOnLoop(func(now time.Time, c *Conn) { |
| ranAt = now |
| }) |
| if !ranAt.Equal(nextTime) { |
| t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime) |
| } |
| tc.wait() |
| |
| tc.advanceToTimer() |
| if got := tc.conn.lifetime.state; got != connStateDone { |
| t.Errorf("after advancing to idle timeout, conn state = %v, want done", got) |
| } |
| } |
| |
| type testDatagram struct { |
| packets []*testPacket |
| paddedSize int |
| addr netip.AddrPort |
| } |
| |
| func (d testDatagram) String() string { |
| var b strings.Builder |
| fmt.Fprintf(&b, "datagram with %v packets", len(d.packets)) |
| if d.paddedSize > 0 { |
| fmt.Fprintf(&b, " (padded to %v bytes)", d.paddedSize) |
| } |
| b.WriteString(":") |
| for _, p := range d.packets { |
| b.WriteString("\n") |
| b.WriteString(p.String()) |
| } |
| return b.String() |
| } |
| |
| type testPacket struct { |
| ptype packetType |
| version uint32 |
| num packetNumber |
| keyPhaseBit bool |
| keyNumber int |
| dstConnID []byte |
| srcConnID []byte |
| token []byte |
| originalDstConnID []byte // used for encoding Retry packets |
| frames []debugFrame |
| } |
| |
| func (p testPacket) String() string { |
| var b strings.Builder |
| fmt.Fprintf(&b, " %v %v", p.ptype, p.num) |
| if p.version != 0 { |
| fmt.Fprintf(&b, " version=%v", p.version) |
| } |
| if p.srcConnID != nil { |
| fmt.Fprintf(&b, " src={%x}", p.srcConnID) |
| } |
| if p.dstConnID != nil { |
| fmt.Fprintf(&b, " dst={%x}", p.dstConnID) |
| } |
| if p.token != nil { |
| fmt.Fprintf(&b, " token={%x}", p.token) |
| } |
| for _, f := range p.frames { |
| fmt.Fprintf(&b, "\n %v", f) |
| } |
| return b.String() |
| } |
| |
| // maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test. |
| const maxTestKeyPhases = 3 |
| |
| // A testConn is a Conn whose external interactions (sending and receiving packets, |
| // setting timers) can be manipulated in tests. |
| type testConn struct { |
| t *testing.T |
| conn *Conn |
| endpoint *testEndpoint |
| timer time.Time |
| timerLastFired time.Time |
| idlec chan struct{} // only accessed on the conn's loop |
| |
| // Keys are distinct from the conn's keys, |
| // because the test may know about keys before the conn does. |
| // For example, when sending a datagram with coalesced |
| // Initial and Handshake packets to a client conn, |
| // we use Handshake keys to encrypt the packet. |
| // The client only acquires those keys when it processes |
| // the Initial packet. |
| keysInitial fixedKeyPair |
| keysHandshake fixedKeyPair |
| rkeyAppData test1RTTKeys |
| wkeyAppData test1RTTKeys |
| rsecrets [numberSpaceCount]keySecret |
| wsecrets [numberSpaceCount]keySecret |
| |
| // testConn uses a test hook to snoop on the conn's TLS events. |
| // CRYPTO data produced by the conn's QUICConn is placed in |
| // cryptoDataOut. |
| // |
| // The peerTLSConn is is a QUICConn representing the peer. |
| // CRYPTO data produced by the conn is written to peerTLSConn, |
| // and data produced by peerTLSConn is placed in cryptoDataIn. |
| cryptoDataOut map[tls.QUICEncryptionLevel][]byte |
| cryptoDataIn map[tls.QUICEncryptionLevel][]byte |
| peerTLSConn *tls.QUICConn |
| |
| // Information about the conn's (fake) peer. |
| peerConnID []byte // source conn id of peer's packets |
| peerNextPacketNum [numberSpaceCount]packetNumber // next packet number to use |
| |
| // Datagrams, packets, and frames sent by the conn, |
| // but not yet processed by the test. |
| sentDatagrams [][]byte |
| sentPackets []*testPacket |
| sentFrames []debugFrame |
| lastPacket *testPacket |
| |
| recvDatagram chan *datagram |
| |
| // Transport parameters sent by the conn. |
| sentTransportParameters *transportParameters |
| |
| // Frame types to ignore in tests. |
| ignoreFrames map[byte]bool |
| |
| // Values to set in packets sent to the conn. |
| sendKeyNumber int |
| sendKeyPhaseBit bool |
| |
| asyncTestState |
| } |
| |
| type test1RTTKeys struct { |
| hdr headerKey |
| pkt [maxTestKeyPhases]packetKey |
| } |
| |
| type keySecret struct { |
| suite uint16 |
| secret []byte |
| } |
| |
| // newTestConn creates a Conn for testing. |
| // |
| // The Conn's event loop is controlled by the test, |
| // allowing test code to access Conn state directly |
| // by first ensuring the loop goroutine is idle. |
| func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { |
| t.Helper() |
| config := &Config{ |
| TLSConfig: newTestTLSConfig(side), |
| StatelessResetKey: testStatelessResetKey, |
| } |
| var cids newServerConnIDs |
| if side == serverSide { |
| // The initial connection ID for the server is chosen by the client. |
| cids.srcConnID = testPeerConnID(0) |
| cids.dstConnID = testPeerConnID(-1) |
| cids.originalDstConnID = cids.dstConnID |
| } |
| var configTransportParams []func(*transportParameters) |
| var configTestConn []func(*testConn) |
| for _, o := range opts { |
| switch o := o.(type) { |
| case func(*Config): |
| o(config) |
| case func(*tls.Config): |
| o(config.TLSConfig) |
| case func(cids *newServerConnIDs): |
| o(&cids) |
| case func(p *transportParameters): |
| configTransportParams = append(configTransportParams, o) |
| case func(p *testConn): |
| configTestConn = append(configTestConn, o) |
| default: |
| t.Fatalf("unknown newTestConn option %T", o) |
| } |
| } |
| |
| endpoint := newTestEndpoint(t, config) |
| endpoint.configTransportParams = configTransportParams |
| endpoint.configTestConn = configTestConn |
| conn, err := endpoint.e.newConn( |
| endpoint.now, |
| side, |
| cids, |
| netip.MustParseAddrPort("127.0.0.1:443")) |
| if err != nil { |
| t.Fatal(err) |
| } |
| tc := endpoint.conns[conn] |
| tc.wait() |
| return tc |
| } |
| |
| func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testConn { |
| t.Helper() |
| tc := &testConn{ |
| t: t, |
| endpoint: endpoint, |
| conn: conn, |
| peerConnID: testPeerConnID(0), |
| ignoreFrames: map[byte]bool{ |
| frameTypePadding: true, // ignore PADDING by default |
| }, |
| cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte), |
| cryptoDataIn: make(map[tls.QUICEncryptionLevel][]byte), |
| recvDatagram: make(chan *datagram), |
| } |
| t.Cleanup(tc.cleanup) |
| for _, f := range endpoint.configTestConn { |
| f(tc) |
| } |
| conn.testHooks = (*testConnHooks)(tc) |
| |
| if endpoint.peerTLSConn != nil { |
| tc.peerTLSConn = endpoint.peerTLSConn |
| endpoint.peerTLSConn = nil |
| return tc |
| } |
| |
| peerProvidedParams := defaultTransportParameters() |
| peerProvidedParams.initialSrcConnID = testPeerConnID(0) |
| if conn.side == clientSide { |
| peerProvidedParams.originalDstConnID = testLocalConnID(-1) |
| } |
| for _, f := range endpoint.configTransportParams { |
| f(&peerProvidedParams) |
| } |
| |
| peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(conn.side.peer())} |
| if conn.side == clientSide { |
| tc.peerTLSConn = tls.QUICServer(peerQUICConfig) |
| } else { |
| tc.peerTLSConn = tls.QUICClient(peerQUICConfig) |
| } |
| tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams)) |
| tc.peerTLSConn.Start(context.Background()) |
| |
| return tc |
| } |
| |
| // advance causes time to pass. |
| func (tc *testConn) advance(d time.Duration) { |
| tc.t.Helper() |
| tc.endpoint.advance(d) |
| } |
| |
| // advanceTo sets the current time. |
| func (tc *testConn) advanceTo(now time.Time) { |
| tc.t.Helper() |
| tc.endpoint.advanceTo(now) |
| } |
| |
| // advanceToTimer sets the current time to the time of the Conn's next timer event. |
| func (tc *testConn) advanceToTimer() { |
| if tc.timer.IsZero() { |
| tc.t.Fatalf("advancing to timer, but timer is not set") |
| } |
| tc.advanceTo(tc.timer) |
| } |
| |
| func (tc *testConn) timerDelay() time.Duration { |
| if tc.timer.IsZero() { |
| return math.MaxInt64 // infinite |
| } |
| if tc.timer.Before(tc.endpoint.now) { |
| return 0 |
| } |
| return tc.timer.Sub(tc.endpoint.now) |
| } |
| |
| const infiniteDuration = time.Duration(math.MaxInt64) |
| |
| // timeUntilEvent returns the amount of time until the next connection event. |
| func (tc *testConn) timeUntilEvent() time.Duration { |
| if tc.timer.IsZero() { |
| return infiniteDuration |
| } |
| if tc.timer.Before(tc.endpoint.now) { |
| return 0 |
| } |
| return tc.timer.Sub(tc.endpoint.now) |
| } |
| |
| // wait blocks until the conn becomes idle. |
| // The conn is idle when it is blocked waiting for a packet to arrive or a timer to expire. |
| // Tests shouldn't need to call wait directly. |
| // testConn methods that wake the Conn event loop will call wait for them. |
| func (tc *testConn) wait() { |
| tc.t.Helper() |
| idlec := make(chan struct{}) |
| fail := false |
| tc.conn.sendMsg(func(now time.Time, c *Conn) { |
| if tc.idlec != nil { |
| tc.t.Errorf("testConn.wait called concurrently") |
| fail = true |
| close(idlec) |
| } else { |
| // nextMessage will close idlec. |
| tc.idlec = idlec |
| } |
| }) |
| select { |
| case <-idlec: |
| case <-tc.conn.donec: |
| // We may have async ops that can proceed now that the conn is done. |
| tc.wakeAsync() |
| } |
| if fail { |
| panic(fail) |
| } |
| } |
| |
| func (tc *testConn) cleanup() { |
| if tc.conn == nil { |
| return |
| } |
| tc.conn.exit() |
| <-tc.conn.donec |
| } |
| |
| func logDatagram(t *testing.T, text string, d *testDatagram) { |
| t.Helper() |
| if !*testVV { |
| return |
| } |
| pad := "" |
| if d.paddedSize > 0 { |
| pad = fmt.Sprintf(" (padded to %v)", d.paddedSize) |
| } |
| t.Logf("%v datagram%v", text, pad) |
| for _, p := range d.packets { |
| var s string |
| switch p.ptype { |
| case packetType1RTT: |
| s = fmt.Sprintf(" %v pnum=%v", p.ptype, p.num) |
| default: |
| s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID) |
| } |
| if p.token != nil { |
| s += fmt.Sprintf(" token={%x}", p.token) |
| } |
| if p.keyPhaseBit { |
| s += fmt.Sprintf(" KeyPhase") |
| } |
| if p.keyNumber != 0 { |
| s += fmt.Sprintf(" keynum=%v", p.keyNumber) |
| } |
| t.Log(s) |
| for _, f := range p.frames { |
| t.Logf(" %v", f) |
| } |
| } |
| } |
| |
| // write sends the Conn a datagram. |
| func (tc *testConn) write(d *testDatagram) { |
| tc.t.Helper() |
| tc.endpoint.writeDatagram(d) |
| } |
| |
| // writeFrame sends the Conn a datagram containing the given frames. |
| func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) { |
| tc.t.Helper() |
| space := spaceForPacketType(ptype) |
| dstConnID := tc.conn.connIDState.local[0].cid |
| if tc.conn.connIDState.local[0].seq == -1 && ptype != packetTypeInitial { |
| // Only use the transient connection ID in Initial packets. |
| dstConnID = tc.conn.connIDState.local[1].cid |
| } |
| d := &testDatagram{ |
| packets: []*testPacket{{ |
| ptype: ptype, |
| num: tc.peerNextPacketNum[space], |
| keyNumber: tc.sendKeyNumber, |
| keyPhaseBit: tc.sendKeyPhaseBit, |
| frames: frames, |
| version: quicVersion1, |
| dstConnID: dstConnID, |
| srcConnID: tc.peerConnID, |
| }}, |
| } |
| if ptype == packetTypeInitial && tc.conn.side == serverSide { |
| d.paddedSize = 1200 |
| } |
| tc.write(d) |
| } |
| |
| // writeAckForAll sends the Conn a datagram containing an ack for all packets up to the |
| // last one received. |
| func (tc *testConn) writeAckForAll() { |
| tc.t.Helper() |
| if tc.lastPacket == nil { |
| return |
| } |
| tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{ |
| ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}}, |
| }) |
| } |
| |
| // writeAckForLatest sends the Conn a datagram containing an ack for the |
| // most recent packet received. |
| func (tc *testConn) writeAckForLatest() { |
| tc.t.Helper() |
| if tc.lastPacket == nil { |
| return |
| } |
| tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{ |
| ranges: []i64range[packetNumber]{{tc.lastPacket.num, tc.lastPacket.num + 1}}, |
| }) |
| } |
| |
| // ignoreFrame hides frames of the given type sent by the Conn. |
| func (tc *testConn) ignoreFrame(frameType byte) { |
| tc.ignoreFrames[frameType] = true |
| } |
| |
| // readDatagram reads the next datagram sent by the Conn. |
| // It returns nil if the Conn has no more datagrams to send at this time. |
| func (tc *testConn) readDatagram() *testDatagram { |
| tc.t.Helper() |
| tc.wait() |
| tc.sentPackets = nil |
| tc.sentFrames = nil |
| buf := tc.endpoint.read() |
| if buf == nil { |
| return nil |
| } |
| d := parseTestDatagram(tc.t, tc.endpoint, tc, buf) |
| // Log the datagram before removing ignored frames. |
| // When things go wrong, it's useful to see all the frames. |
| logDatagram(tc.t, "-> conn under test sends", d) |
| typeForFrame := func(f debugFrame) byte { |
| // This is very clunky, and points at a problem |
| // in how we specify what frames to ignore in tests. |
| // |
| // We mark frames to ignore using the frame type, |
| // but we've got a debugFrame data structure here. |
| // Perhaps we should be ignoring frames by debugFrame |
| // type instead: tc.ignoreFrame[debugFrameAck](). |
| switch f := f.(type) { |
| case debugFramePadding: |
| return frameTypePadding |
| case debugFramePing: |
| return frameTypePing |
| case debugFrameAck: |
| return frameTypeAck |
| case debugFrameResetStream: |
| return frameTypeResetStream |
| case debugFrameStopSending: |
| return frameTypeStopSending |
| case debugFrameCrypto: |
| return frameTypeCrypto |
| case debugFrameNewToken: |
| return frameTypeNewToken |
| case debugFrameStream: |
| return frameTypeStreamBase |
| case debugFrameMaxData: |
| return frameTypeMaxData |
| case debugFrameMaxStreamData: |
| return frameTypeMaxStreamData |
| case debugFrameMaxStreams: |
| if f.streamType == bidiStream { |
| return frameTypeMaxStreamsBidi |
| } else { |
| return frameTypeMaxStreamsUni |
| } |
| case debugFrameDataBlocked: |
| return frameTypeDataBlocked |
| case debugFrameStreamDataBlocked: |
| return frameTypeStreamDataBlocked |
| case debugFrameStreamsBlocked: |
| if f.streamType == bidiStream { |
| return frameTypeStreamsBlockedBidi |
| } else { |
| return frameTypeStreamsBlockedUni |
| } |
| case debugFrameNewConnectionID: |
| return frameTypeNewConnectionID |
| case debugFrameRetireConnectionID: |
| return frameTypeRetireConnectionID |
| case debugFramePathChallenge: |
| return frameTypePathChallenge |
| case debugFramePathResponse: |
| return frameTypePathResponse |
| case debugFrameConnectionCloseTransport: |
| return frameTypeConnectionCloseTransport |
| case debugFrameConnectionCloseApplication: |
| return frameTypeConnectionCloseApplication |
| case debugFrameHandshakeDone: |
| return frameTypeHandshakeDone |
| } |
| panic(fmt.Errorf("unhandled frame type %T", f)) |
| } |
| for _, p := range d.packets { |
| var frames []debugFrame |
| for _, f := range p.frames { |
| if !tc.ignoreFrames[typeForFrame(f)] { |
| frames = append(frames, f) |
| } |
| } |
| p.frames = frames |
| } |
| return d |
| } |
| |
| // readPacket reads the next packet sent by the Conn. |
| // It returns nil if the Conn has no more packets to send at this time. |
| func (tc *testConn) readPacket() *testPacket { |
| tc.t.Helper() |
| for len(tc.sentPackets) == 0 { |
| d := tc.readDatagram() |
| if d == nil { |
| return nil |
| } |
| for _, p := range d.packets { |
| if len(p.frames) == 0 { |
| tc.lastPacket = p |
| continue |
| } |
| tc.sentPackets = append(tc.sentPackets, p) |
| } |
| } |
| p := tc.sentPackets[0] |
| tc.sentPackets = tc.sentPackets[1:] |
| tc.lastPacket = p |
| return p |
| } |
| |
| // readFrame reads the next frame sent by the Conn. |
| // It returns nil if the Conn has no more frames to send at this time. |
| func (tc *testConn) readFrame() (debugFrame, packetType) { |
| tc.t.Helper() |
| for len(tc.sentFrames) == 0 { |
| p := tc.readPacket() |
| if p == nil { |
| return nil, packetTypeInvalid |
| } |
| tc.sentFrames = p.frames |
| } |
| f := tc.sentFrames[0] |
| tc.sentFrames = tc.sentFrames[1:] |
| return f, tc.lastPacket.ptype |
| } |
| |
| // wantDatagram indicates that we expect the Conn to send a datagram. |
| func (tc *testConn) wantDatagram(expectation string, want *testDatagram) { |
| tc.t.Helper() |
| got := tc.readDatagram() |
| if !reflect.DeepEqual(got, want) { |
| tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) |
| } |
| } |
| |
| func datagramEqual(a, b *testDatagram) bool { |
| if a.paddedSize != b.paddedSize || |
| a.addr != b.addr || |
| len(a.packets) != len(b.packets) { |
| return false |
| } |
| for i := range a.packets { |
| if !packetEqual(a.packets[i], b.packets[i]) { |
| return false |
| } |
| } |
| return true |
| } |
| |
| // wantPacket indicates that we expect the Conn to send a packet. |
| func (tc *testConn) wantPacket(expectation string, want *testPacket) { |
| tc.t.Helper() |
| got := tc.readPacket() |
| if !reflect.DeepEqual(got, want) { |
| tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want) |
| } |
| } |
| |
| func packetEqual(a, b *testPacket) bool { |
| ac := *a |
| ac.frames = nil |
| bc := *b |
| bc.frames = nil |
| if !reflect.DeepEqual(ac, bc) { |
| return false |
| } |
| if len(a.frames) != len(b.frames) { |
| return false |
| } |
| for i := range a.frames { |
| if !frameEqual(a.frames[i], b.frames[i]) { |
| return false |
| } |
| } |
| return true |
| } |
| |
| // wantFrame indicates that we expect the Conn to send a frame. |
| func (tc *testConn) wantFrame(expectation string, wantType packetType, want debugFrame) { |
| tc.t.Helper() |
| got, gotType := tc.readFrame() |
| if got == nil { |
| tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want) |
| } |
| if gotType != wantType { |
| tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got) |
| } |
| if !frameEqual(got, want) { |
| tc.t.Fatalf("%v:\ngot frame: %v\nwant frame: %v", expectation, got, want) |
| } |
| } |
| |
| func frameEqual(a, b debugFrame) bool { |
| switch af := a.(type) { |
| case debugFrameConnectionCloseTransport: |
| bf, ok := b.(debugFrameConnectionCloseTransport) |
| return ok && af.code == bf.code |
| } |
| return reflect.DeepEqual(a, b) |
| } |
| |
| // wantFrameType indicates that we expect the Conn to send a frame, |
| // although we don't care about the contents. |
| func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) { |
| tc.t.Helper() |
| got, gotType := tc.readFrame() |
| if got == nil { |
| tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want) |
| } |
| if gotType != wantType { |
| tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got) |
| } |
| if reflect.TypeOf(got) != reflect.TypeOf(want) { |
| tc.t.Fatalf("%v:\ngot frame: %v\nwant frame of type: %v", expectation, got, want) |
| } |
| } |
| |
| // wantIdle indicates that we expect the Conn to not send any more frames. |
| func (tc *testConn) wantIdle(expectation string) { |
| tc.t.Helper() |
| switch { |
| case len(tc.sentFrames) > 0: |
| tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentFrames[0]) |
| case len(tc.sentPackets) > 0: |
| tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, tc.sentPackets[0]) |
| } |
| if f, _ := tc.readFrame(); f != nil { |
| tc.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, f) |
| } |
| } |
| |
| func encodeTestPacket(t *testing.T, tc *testConn, p *testPacket, pad int) []byte { |
| t.Helper() |
| var w packetWriter |
| w.reset(1200) |
| var pnumMaxAcked packetNumber |
| switch p.ptype { |
| case packetTypeRetry: |
| return encodeRetryPacket(p.originalDstConnID, retryPacket{ |
| srcConnID: p.srcConnID, |
| dstConnID: p.dstConnID, |
| token: p.token, |
| }) |
| case packetType1RTT: |
| w.start1RTTPacket(p.num, pnumMaxAcked, p.dstConnID) |
| default: |
| w.startProtectedLongHeaderPacket(pnumMaxAcked, longPacket{ |
| ptype: p.ptype, |
| version: p.version, |
| num: p.num, |
| dstConnID: p.dstConnID, |
| srcConnID: p.srcConnID, |
| extra: p.token, |
| }) |
| } |
| for _, f := range p.frames { |
| f.write(&w) |
| } |
| w.appendPaddingTo(pad) |
| if p.ptype != packetType1RTT { |
| var k fixedKeys |
| if tc == nil { |
| if p.ptype == packetTypeInitial { |
| k = initialKeys(p.dstConnID, serverSide).r |
| } else { |
| t.Fatalf("sending %v packet with no conn", p.ptype) |
| } |
| } else { |
| switch p.ptype { |
| case packetTypeInitial: |
| k = tc.keysInitial.w |
| case packetTypeHandshake: |
| k = tc.keysHandshake.w |
| } |
| } |
| if !k.isSet() { |
| t.Fatalf("sending %v packet with no write key", p.ptype) |
| } |
| w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{ |
| ptype: p.ptype, |
| version: p.version, |
| num: p.num, |
| dstConnID: p.dstConnID, |
| srcConnID: p.srcConnID, |
| extra: p.token, |
| }) |
| } else { |
| if tc == nil || !tc.wkeyAppData.hdr.isSet() { |
| t.Fatalf("sending 1-RTT packet with no write key") |
| } |
| // Somewhat hackish: Generate a temporary updatingKeyPair that will |
| // always use our desired key phase. |
| k := &updatingKeyPair{ |
| w: updatingKeys{ |
| hdr: tc.wkeyAppData.hdr, |
| pkt: [2]packetKey{ |
| tc.wkeyAppData.pkt[p.keyNumber], |
| tc.wkeyAppData.pkt[p.keyNumber], |
| }, |
| }, |
| updateAfter: maxPacketNumber, |
| } |
| if p.keyPhaseBit { |
| k.phase |= keyPhaseBit |
| } |
| w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k) |
| } |
| return w.datagram() |
| } |
| |
| func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) *testDatagram { |
| t.Helper() |
| bufSize := len(buf) |
| d := &testDatagram{} |
| size := len(buf) |
| for len(buf) > 0 { |
| if buf[0] == 0 { |
| d.paddedSize = bufSize |
| break |
| } |
| ptype := getPacketType(buf) |
| switch ptype { |
| case packetTypeRetry: |
| retry, ok := parseRetryPacket(buf, te.lastInitialDstConnID) |
| if !ok { |
| t.Fatalf("could not parse %v packet", ptype) |
| } |
| return &testDatagram{ |
| packets: []*testPacket{{ |
| ptype: packetTypeRetry, |
| dstConnID: retry.dstConnID, |
| srcConnID: retry.srcConnID, |
| token: retry.token, |
| }}, |
| } |
| case packetTypeInitial, packetTypeHandshake: |
| var k fixedKeys |
| if tc == nil { |
| if ptype == packetTypeInitial { |
| p, _ := parseGenericLongHeaderPacket(buf) |
| k = initialKeys(p.srcConnID, serverSide).w |
| } else { |
| t.Fatalf("reading %v packet with no conn", ptype) |
| } |
| } else { |
| switch ptype { |
| case packetTypeInitial: |
| k = tc.keysInitial.r |
| case packetTypeHandshake: |
| k = tc.keysHandshake.r |
| } |
| } |
| if !k.isSet() { |
| t.Fatalf("reading %v packet with no read key", ptype) |
| } |
| var pnumMax packetNumber // TODO: Track packet numbers. |
| p, n := parseLongHeaderPacket(buf, k, pnumMax) |
| if n < 0 { |
| t.Fatalf("packet parse error") |
| } |
| frames, err := parseTestFrames(t, p.payload) |
| if err != nil { |
| t.Fatal(err) |
| } |
| var token []byte |
| if ptype == packetTypeInitial && len(p.extra) > 0 { |
| token = p.extra |
| } |
| d.packets = append(d.packets, &testPacket{ |
| ptype: p.ptype, |
| version: p.version, |
| num: p.num, |
| dstConnID: p.dstConnID, |
| srcConnID: p.srcConnID, |
| token: token, |
| frames: frames, |
| }) |
| buf = buf[n:] |
| case packetType1RTT: |
| if tc == nil || !tc.rkeyAppData.hdr.isSet() { |
| t.Fatalf("reading 1-RTT packet with no read key") |
| } |
| var pnumMax packetNumber // TODO: Track packet numbers. |
| pnumOff := 1 + len(tc.peerConnID) |
| // Try unprotecting the packet with the first maxTestKeyPhases keys. |
| var phase int |
| var pnum packetNumber |
| var hdr []byte |
| var pay []byte |
| var err error |
| for phase = 0; phase < maxTestKeyPhases; phase++ { |
| b := append([]byte{}, buf...) |
| hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax) |
| if err != nil { |
| t.Fatalf("1-RTT packet header parse error") |
| } |
| k := tc.rkeyAppData.pkt[phase] |
| pay, err = k.unprotect(hdr, pay, pnum) |
| if err == nil { |
| break |
| } |
| } |
| if err != nil { |
| t.Fatalf("1-RTT packet payload parse error") |
| } |
| frames, err := parseTestFrames(t, pay) |
| if err != nil { |
| t.Fatal(err) |
| } |
| d.packets = append(d.packets, &testPacket{ |
| ptype: packetType1RTT, |
| num: pnum, |
| dstConnID: hdr[1:][:len(tc.peerConnID)], |
| keyPhaseBit: hdr[0]&keyPhaseBit != 0, |
| keyNumber: phase, |
| frames: frames, |
| }) |
| buf = buf[len(buf):] |
| default: |
| t.Fatalf("unhandled packet type %v", ptype) |
| } |
| } |
| // This is rather hackish: If the last frame in the last packet |
| // in the datagram is PADDING, then remove it and record |
| // the padded size in the testDatagram.paddedSize. |
| // |
| // This makes it easier to write a test that expects a datagram |
| // padded to 1200 bytes. |
| if len(d.packets) > 0 && len(d.packets[len(d.packets)-1].frames) > 0 { |
| p := d.packets[len(d.packets)-1] |
| f := p.frames[len(p.frames)-1] |
| if _, ok := f.(debugFramePadding); ok { |
| p.frames = p.frames[:len(p.frames)-1] |
| d.paddedSize = size |
| } |
| } |
| return d |
| } |
| |
| func parseTestFrames(t *testing.T, payload []byte) ([]debugFrame, error) { |
| t.Helper() |
| var frames []debugFrame |
| for len(payload) > 0 { |
| f, n := parseDebugFrame(payload) |
| if n < 0 { |
| return nil, errors.New("error parsing frames") |
| } |
| frames = append(frames, f) |
| payload = payload[n:] |
| } |
| return frames, nil |
| } |
| |
| func spaceForPacketType(ptype packetType) numberSpace { |
| switch ptype { |
| case packetTypeInitial: |
| return initialSpace |
| case packetType0RTT: |
| panic("TODO: packetType0RTT") |
| case packetTypeHandshake: |
| return handshakeSpace |
| case packetTypeRetry: |
| panic("retry packets have no number space") |
| case packetType1RTT: |
| return appDataSpace |
| } |
| panic("unknown packet type") |
| } |
| |
| // testConnHooks implements connTestHooks. |
| type testConnHooks testConn |
| |
| func (tc *testConnHooks) init() { |
| tc.conn.keysAppData.updateAfter = maxPacketNumber // disable key updates |
| tc.keysInitial.r = tc.conn.keysInitial.w |
| tc.keysInitial.w = tc.conn.keysInitial.r |
| if tc.conn.side == serverSide { |
| tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc)) |
| } |
| } |
| |
| // handleTLSEvent processes TLS events generated by |
| // the connection under test's tls.QUICConn. |
| // |
| // We maintain a second tls.QUICConn representing the peer, |
| // and feed the TLS handshake data into it. |
| // |
| // We stash TLS handshake data from both sides in the testConn, |
| // where it can be used by tests. |
| // |
| // We snoop packet protection keys out of the tls.QUICConns, |
| // and verify that both sides of the connection are getting |
| // matching keys. |
| func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { |
| checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) { |
| var space numberSpace |
| switch { |
| case e.Level == tls.QUICEncryptionLevelHandshake: |
| space = handshakeSpace |
| case e.Level == tls.QUICEncryptionLevelApplication: |
| space = appDataSpace |
| default: |
| tc.t.Errorf("unexpected encryption level %v", e.Level) |
| return |
| } |
| if secrets[space].secret == nil { |
| secrets[space].suite = e.Suite |
| secrets[space].secret = append([]byte{}, e.Data...) |
| } else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) { |
| tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level) |
| } |
| } |
| setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) { |
| k.hdr.init(suite, secret) |
| for i := 0; i < len(k.pkt); i++ { |
| k.pkt[i].init(suite, secret) |
| secret = updateSecret(suite, secret) |
| } |
| } |
| switch e.Kind { |
| case tls.QUICSetReadSecret: |
| checkKey("write", &tc.wsecrets, e) |
| switch e.Level { |
| case tls.QUICEncryptionLevelHandshake: |
| tc.keysHandshake.w.init(e.Suite, e.Data) |
| case tls.QUICEncryptionLevelApplication: |
| setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData) |
| } |
| case tls.QUICSetWriteSecret: |
| checkKey("read", &tc.rsecrets, e) |
| switch e.Level { |
| case tls.QUICEncryptionLevelHandshake: |
| tc.keysHandshake.r.init(e.Suite, e.Data) |
| case tls.QUICEncryptionLevelApplication: |
| setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData) |
| } |
| case tls.QUICWriteData: |
| tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...) |
| tc.peerTLSConn.HandleData(e.Level, e.Data) |
| } |
| for { |
| e := tc.peerTLSConn.NextEvent() |
| switch e.Kind { |
| case tls.QUICNoEvent: |
| return |
| case tls.QUICSetReadSecret: |
| checkKey("write", &tc.rsecrets, e) |
| switch e.Level { |
| case tls.QUICEncryptionLevelHandshake: |
| tc.keysHandshake.r.init(e.Suite, e.Data) |
| case tls.QUICEncryptionLevelApplication: |
| setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData) |
| } |
| case tls.QUICSetWriteSecret: |
| checkKey("read", &tc.wsecrets, e) |
| switch e.Level { |
| case tls.QUICEncryptionLevelHandshake: |
| tc.keysHandshake.w.init(e.Suite, e.Data) |
| case tls.QUICEncryptionLevelApplication: |
| setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData) |
| } |
| case tls.QUICWriteData: |
| tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...) |
| case tls.QUICTransportParameters: |
| p, err := unmarshalTransportParams(e.Data) |
| if err != nil { |
| tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err) |
| } else { |
| tc.sentTransportParameters = &p |
| } |
| } |
| } |
| } |
| |
| // nextMessage is called by the Conn's event loop to request its next event. |
| func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) { |
| tc.timer = timer |
| for { |
| if !timer.IsZero() && !timer.After(tc.endpoint.now) { |
| if timer.Equal(tc.timerLastFired) { |
| // If the connection timer fires at time T, the Conn should take some |
| // action to advance the timer into the future. If the Conn reschedules |
| // the timer for the same time, it isn't making progress and we have a bug. |
| tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer) |
| } else { |
| tc.timerLastFired = timer |
| return tc.endpoint.now, timerEvent{} |
| } |
| } |
| select { |
| case m := <-msgc: |
| return tc.endpoint.now, m |
| default: |
| } |
| if !tc.wakeAsync() { |
| break |
| } |
| } |
| // If the message queue is empty, then the conn is idle. |
| if tc.idlec != nil { |
| idlec := tc.idlec |
| tc.idlec = nil |
| close(idlec) |
| } |
| m = <-msgc |
| return tc.endpoint.now, m |
| } |
| |
| func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) { |
| return testLocalConnID(seq), nil |
| } |
| |
| func (tc *testConnHooks) timeNow() time.Time { |
| return tc.endpoint.now |
| } |
| |
| // testLocalConnID returns the connection ID with a given sequence number |
| // used by a Conn under test. |
| func testLocalConnID(seq int64) []byte { |
| cid := make([]byte, connIDLen) |
| copy(cid, []byte{0xc0, 0xff, 0xee}) |
| cid[len(cid)-1] = byte(seq) |
| return cid |
| } |
| |
| // testPeerConnID returns the connection ID with a given sequence number |
| // used by the fake peer of a Conn under test. |
| func testPeerConnID(seq int64) []byte { |
| // Use a different length than we choose for our own conn ids, |
| // to help catch any bad assumptions. |
| return []byte{0xbe, 0xee, 0xff, byte(seq)} |
| } |
| |
| func testPeerStatelessResetToken(seq int64) statelessResetToken { |
| return statelessResetToken{ |
| 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, |
| 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, 0xee, byte(seq), |
| } |
| } |
| |
| // canceledContext returns a canceled Context. |
| // |
| // Functions which take a context preference progress over cancelation. |
| // For example, a read with a canceled context will return data if any is available. |
| // Tests use canceled contexts to perform non-blocking operations. |
| func canceledContext() context.Context { |
| ctx, cancel := context.WithCancel(context.Background()) |
| cancel() |
| return ctx |
| } |