blob: a5cc690ac4dca1a6b232eef8cfa4874fe3cabc6c [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21
package quic
import (
"bytes"
"context"
"crypto/tls"
"io"
"net"
"net/netip"
"reflect"
"testing"
"time"
)
func TestConnect(t *testing.T) {
newLocalConnPair(t, &Config{}, &Config{})
}
// TestStreamTransfer writes a makeTestData(1 << 20) payload to a client-opened
// stream and checks that the server reads back exactly those bytes, and that
// both sides can close their stream without error.
func TestStreamTransfer(t *testing.T) {
	ctx := context.Background()
	cli, srv := newLocalConnPair(t, &Config{}, &Config{})
	data := makeTestData(1 << 20)
	srvdone := make(chan struct{})
	go func() {
		defer close(srvdone)
		s, err := srv.AcceptStream(ctx)
		if err != nil {
			t.Errorf("AcceptStream: %v", err)
			return
		}
		b, err := io.ReadAll(s)
		if err != nil {
			t.Errorf("io.ReadAll(s): %v", err)
			return
		}
		if !bytes.Equal(b, data) {
			t.Errorf("read data mismatch (got %v bytes, want %v)", len(b), len(data))
		}
		if err := s.Close(); err != nil {
			t.Errorf("s.Close() = %v", err)
		}
	}()
	s, err := cli.NewStream(ctx)
	if err != nil {
		t.Fatalf("NewStream: %v", err)
	}
	n, err := io.Copy(s, bytes.NewBuffer(data))
	if n != int64(len(data)) || err != nil {
		t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
	}
	if err := s.Close(); err != nil {
		t.Fatalf("s.Close() = %v", err)
	}
	// The server goroutine reports failures with t.Errorf, which must not
	// run after the test function has returned; wait for it to finish so
	// its checks are actually part of this test.
	<-srvdone
}
// newLocalConnPair starts a server-side and a client-side Listener on
// loopback, dials the server from the client, accepts on the server, and
// returns the (client, server) pair of connections.
//
// conf1 configures the server Listener and conf2 the client Listener. A nil
// TLSConfig in either one is replaced in place with a test configuration.
func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
	t.Helper()
	ctx := context.Background()
	srvListener := newLocalListener(t, serverSide, conf1)
	cliListener := newLocalListener(t, clientSide, conf2)
	serverAddr := srvListener.LocalAddr().String()
	dialed, err := cliListener.Dial(ctx, "udp", serverAddr)
	if err != nil {
		t.Fatal(err)
	}
	accepted, err := srvListener.Accept(ctx)
	if err != nil {
		t.Fatal(err)
	}
	clientConn, serverConn = dialed, accepted
	return clientConn, serverConn
}
// newLocalListener starts a Listener on an OS-chosen UDP port on 127.0.0.1
// and registers a test cleanup that closes it.
//
// If conf.TLSConfig is nil, it is set (mutating conf) to the test TLS
// configuration for side. conf must not be nil.
func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener {
	t.Helper()
	if conf.TLSConfig == nil {
		conf.TLSConfig = newTestTLSConfig(side)
	}
	const network, address = "udp", "127.0.0.1:0"
	lis, err := Listen(network, address, conf)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { lis.Close(context.Background()) })
	return lis
}
// testListener is a test harness around a Listener. It stands in for the
// network (through its testListenerUDPConn view) and for the Listener's test
// hooks, including the clock (through its testListenerHooks view). Tests use
// it to inject datagrams, inspect the datagrams the Listener sends, and
// control the passage of time.
type testListener struct {
	t *testing.T
	// l is the Listener under test.
	l *Listener
	// now is the fake current time returned to the Listener by timeNow.
	now time.Time
	// recvc carries datagrams from the test to ReadMsgUDPAddrPort; closing
	// it makes ReadMsgUDPAddrPort return io.EOF.
	recvc chan *datagram
	// idlec is sent on by wait. The send can only complete while
	// ReadMsgUDPAddrPort is blocked waiting for input.
	idlec chan struct{}
	// conns holds a testConn for each Conn reported through the newConn hook.
	conns map[*Conn]*testConn
	// acceptQueue holds server connections not yet returned by accept.
	// NOTE(review): nothing in this section appends to it; presumably it is
	// filled when a testConn is created — confirm.
	acceptQueue []*testConn
	// configTransportParams are applied, in order, to the transport
	// parameters advertised by the simulated peer in newClientTLS.
	configTransportParams []func(*transportParameters)
	// sentDatagrams holds copies of the datagrams the Listener has written,
	// oldest first; read removes them from the front.
	sentDatagrams [][]byte
	// peerTLSConn is the simulated client's TLS state, created by newClientTLS.
	peerTLSConn *tls.QUICConn
	// lastInitialDstConnID is the destination connection ID of the most
	// recent Initial packet passed to writeDatagram.
	lastInitialDstConnID []byte // for parsing Retry packets
}
// newTestListener creates a testListener whose Listener does its I/O through
// the harness (the testListenerUDPConn view) and uses the harness as its test
// hooks (the testListenerHooks view). The fake clock starts at
// 2020-01-01 00:00:00 UTC, and the Listener is closed during test cleanup.
func newTestListener(t *testing.T, config *Config) *testListener {
	tl := new(testListener)
	tl.t = t
	tl.now = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
	tl.recvc = make(chan *datagram)
	tl.idlec = make(chan struct{})
	tl.conns = make(map[*Conn]*testConn)
	lis, err := newListener((*testListenerUDPConn)(tl), config, (*testListenerHooks)(tl))
	tl.l = lis
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(tl.cleanup)
	return tl
}
// cleanup closes the Listener under test; newTestListener registers it with
// t.Cleanup. Close receives the context from canceledContext, presumably so
// that it returns promptly instead of waiting for a graceful shutdown —
// confirm against Listener.Close.
func (tl *testListener) cleanup() {
	ctx := canceledContext()
	tl.l.Close(ctx)
}
// wait blocks until the listener's read loop has come back to
// ReadMsgUDPAddrPort, or the listener has been closed. It then calls wait
// on every test conn.
func (tl *testListener) wait() {
	// idlec is unbuffered and is only received from inside
	// ReadMsgUDPAddrPort, so this send cannot complete while the listener
	// is still busy handling a previous datagram.
	select {
	case <-tl.l.closec:
	case tl.idlec <- struct{}{}:
	}
	for _, conn := range tl.conns {
		conn.wait()
	}
}
// accept returns a server connection from the listener.
// Unlike Listener.Accept, connections are available as soon as they are created.
// Connections are returned oldest first. accept fails the test if no
// connection is waiting.
func (tl *testListener) accept() *testConn {
	// Mark this as a helper, like the other testListener helpers, so that a
	// failure is reported at the caller's line.
	tl.t.Helper()
	if len(tl.acceptQueue) == 0 {
		tl.t.Fatalf("accept: expected available conn, but found none")
	}
	tc := tl.acceptQueue[0]
	tl.acceptQueue = tl.acceptQueue[1:]
	return tc
}
// write hands d to the listener's read loop through recvc, then calls wait
// before returning so the listener and its conns have had a chance to handle it.
func (tl *testListener) write(d *datagram) {
	tl.recvc <- d
	tl.wait()
}
// testClientAddr is the source address writeDatagram uses for datagrams that
// do not carry a valid address of their own.
var testClientAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{10, 0, 0, 1}), 8000)
// writeDatagram encodes the test datagram d and delivers it to the listener
// under test, as if it had arrived from the network.
//
// For each non-Retry packet whose destination connection ID belongs to a test
// conn, that conn's peerNextPacketNum for the packet's number space is moved
// past the packet's number. The destination connection ID of each Initial
// packet is saved in tl.lastInitialDstConnID, so the last one wins. 1-RTT
// packets are encoded with d.paddedSize as their padding argument, and the
// assembled datagram is then zero-filled up to d.paddedSize. A datagram
// without a valid d.addr is delivered as coming from testClientAddr.
func (tl *testListener) writeDatagram(d *testDatagram) {
	tl.t.Helper()
	logDatagram(tl.t, "<- listener under test receives", d)
	var buf []byte
	for _, p := range d.packets {
		tc := tl.connForDestination(p.dstConnID)
		if tc != nil && p.ptype != packetTypeRetry {
			space := spaceForPacketType(p.ptype)
			if p.num >= tc.peerNextPacketNum[space] {
				tc.peerNextPacketNum[space] = p.num + 1
			}
		}
		if p.ptype == packetTypeInitial {
			tl.lastInitialDstConnID = p.dstConnID
		}
		var pad int
		if p.ptype == packetType1RTT {
			pad = d.paddedSize
		}
		buf = append(buf, encodeTestPacket(tl.t, tc, p, pad)...)
	}
	if short := d.paddedSize - len(buf); short > 0 {
		buf = append(buf, make([]byte, short)...)
	}
	from := testClientAddr
	if d.addr.IsValid() {
		from = d.addr
	}
	tl.write(&datagram{b: buf, addr: from})
}
// connForDestination returns the test conn whose Conn has dstConnID among its
// local connection IDs, or nil if there is none. If several conns match,
// which one is returned depends on map iteration order.
func (tl *testListener) connForDestination(dstConnID []byte) *testConn {
	for _, tc := range tl.conns {
		ids := tc.conn.connIDState.local
		for i := range ids {
			if bytes.Equal(ids[i].cid, dstConnID) {
				return tc
			}
		}
	}
	return nil
}
// connForSource returns the test conn whose Conn has srcConnID among its
// remote connection IDs, or nil if there is none. If several conns match,
// which one is returned depends on map iteration order.
func (tl *testListener) connForSource(srcConnID []byte) *testConn {
	for _, tc := range tl.conns {
		ids := tc.conn.connIDState.remote
		for i := range ids {
			if bytes.Equal(ids[i].cid, srcConnID) {
				return tc
			}
		}
	}
	return nil
}
// read waits for the listener to go idle, then removes and returns the
// oldest datagram the listener has sent. It returns nil if no sent datagram
// is pending.
func (tl *testListener) read() []byte {
	tl.t.Helper()
	tl.wait()
	if len(tl.sentDatagrams) > 0 {
		oldest := tl.sentDatagrams[0]
		tl.sentDatagrams = tl.sentDatagrams[1:]
		return oldest
	}
	return nil
}
// readDatagram returns the next datagram the listener has sent, parsed into
// a testDatagram, or nil if nothing is pending. The test conn passed to
// parseTestDatagram is the one whose remote connection IDs include the
// first packet's destination connection ID.
//
// NOTE(review): the second result of parseGenericLongHeaderPacket is
// discarded. If the datagram does not begin with a long-header packet, the
// lookup uses the zero value's dstConnID — confirm this is intended.
func (tl *testListener) readDatagram() *testDatagram {
	tl.t.Helper()
	raw := tl.read()
	if raw == nil {
		return nil
	}
	hdr, _ := parseGenericLongHeaderPacket(raw)
	d := parseTestDatagram(tl.t, tl, tl.connForSource(hdr.dstConnID), raw)
	logDatagram(tl.t, "-> listener under test sends", d)
	return d
}
// wantDatagram indicates that we expect the Listener to send a datagram.
// It reads the next sent datagram and fails the test, labelled with
// expectation, unless that datagram is reflect.DeepEqual to want.
func (tl *testListener) wantDatagram(expectation string, want *testDatagram) {
	tl.t.Helper()
	if got := tl.readDatagram(); !reflect.DeepEqual(got, want) {
		tl.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want)
	}
}
// wantIdle indicates that we expect the Listener to not send any more datagrams.
// It fails the test, labelled with expectation, if a sent datagram is pending.
func (tl *testListener) wantIdle(expectation string) {
	// Like wantDatagram, report failures at the caller's line.
	tl.t.Helper()
	if got := tl.readDatagram(); got != nil {
		tl.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got)
	}
}
// newClientTLS starts the client side of a TLS handshake for a simulated peer
// and returns all handshake bytes it produces, which must be at the Initial
// encryption level. The QUICConn is stored in tl.peerTLSConn.
//
// The peer advertises defaultTransportParameters with initialSrcConnID set to
// srcConnID and originalDstConnID set to dstConnID, after which each function
// in tl.configTransportParams is applied in order.
//
// The test fails if the handshake cannot be started, or if data is produced
// at any level other than Initial.
func (tl *testListener) newClientTLS(srcConnID, dstConnID []byte) []byte {
	tl.t.Helper()
	peerProvidedParams := defaultTransportParameters()
	peerProvidedParams.initialSrcConnID = srcConnID
	peerProvidedParams.originalDstConnID = dstConnID
	for _, f := range tl.configTransportParams {
		f(&peerProvidedParams)
	}
	config := &tls.QUICConfig{TLSConfig: newTestTLSConfig(clientSide)}
	tl.peerTLSConn = tls.QUICClient(config)
	tl.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
	// Start reports configuration and handshake-setup errors; previously
	// this was dropped, leaving the test to fail later with a confusing
	// empty handshake.
	if err := tl.peerTLSConn.Start(context.Background()); err != nil {
		tl.t.Fatalf("peerTLSConn.Start: %v", err)
	}
	var data []byte
	for {
		e := tl.peerTLSConn.NextEvent()
		switch e.Kind {
		case tls.QUICNoEvent:
			return data
		case tls.QUICWriteData:
			if e.Level != tls.QUICEncryptionLevelInitial {
				tl.t.Fatal("initial data at unexpected level")
			}
			data = append(data, e.Data...)
		}
	}
}
// advance causes time to pass.
// It moves the fake clock forward by d via advanceTo. A negative d fails the
// test, because advanceTo rejects time moving backwards.
func (tl *testListener) advance(d time.Duration) {
	tl.t.Helper()
	target := tl.now.Add(d)
	tl.advanceTo(target)
}
// advanceTo sets the current time.
// It fails the test if now is earlier than the current fake time. After the
// clock is updated, every test conn whose timer is not after the new time is
// sent a timerEvent, and advanceTo waits for that conn to go idle.
//
// NOTE(review): a zero tc.timer is never after now, so conns with no timer
// set also receive a timerEvent — confirm that is harmless.
func (tl *testListener) advanceTo(now time.Time) {
	tl.t.Helper()
	if now.Before(tl.now) {
		tl.t.Fatalf("time moved backwards: %v -> %v", tl.now, now)
	}
	tl.now = now
	for _, tc := range tl.conns {
		if tc.timer.After(tl.now) {
			continue
		}
		tc.conn.sendMsg(timerEvent{})
		tc.wait()
	}
}
// testListenerHooks implements listenerTestHooks.
// It is a conversion of testListener. The hook methods (timeNow, newConn) are
// defined on this type, so they are not part of testListener's own method set.
type testListenerHooks testListener
// timeNow reports the harness's fake clock, tl.now, rather than the wall
// clock. Tests move it forward with advance and advanceTo.
func (tl *testListenerHooks) timeNow() time.Time {
	current := tl.now
	return current
}
// newConn is the hook invoked (presumably by the Listener) for each new Conn.
// It wraps c in a testConn bound to this harness and records it in tl.conns.
func (tl *testListenerHooks) newConn(c *Conn) {
	wrapped := newTestConnForConn(tl.t, (*testListener)(tl), c)
	tl.conns[c] = wrapped
}
// testListenerUDPConn implements UDPConn.
// It is a conversion of testListener that uses no real socket: reads are
// served from recvc, and writes are recorded in sentDatagrams.
type testListenerUDPConn testListener
// Close closes recvc, so a blocked or later ReadMsgUDPAddrPort call returns
// io.EOF. It always returns nil. Calling it twice panics, because a channel
// cannot be closed twice.
func (tl *testListenerUDPConn) Close() error {
	close(tl.recvc)
	return nil
}
// LocalAddr reports the fixed address 127.0.0.1:443. No socket is actually
// bound to it.
func (tl *testListenerUDPConn) LocalAddr() net.Addr {
	fixed := netip.MustParseAddrPort("127.0.0.1:443")
	return net.UDPAddrFromAddrPort(fixed)
}
// ReadMsgUDPAddrPort blocks until the test delivers a datagram on recvc. It
// copies the datagram into b and reports the datagram's address.
//
// Signals arriving on idlec (sent by wait) are consumed and otherwise
// ignored. Once recvc is closed, it returns io.EOF. Control data and flags
// are never produced, and a datagram longer than b is silently truncated.
func (tl *testListenerUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) {
	for {
		select {
		case <-tl.idlec:
			// A probe from wait: there is nothing to deliver, so loop
			// back and keep waiting.
		case d, ok := <-tl.recvc:
			if !ok {
				return 0, 0, 0, netip.AddrPort{}, io.EOF
			}
			return copy(b, d.b), 0, 0, d.addr, nil
		}
	}
}
// WriteToUDPAddrPort appends a copy of b to sentDatagrams, so later changes
// to the caller's buffer do not affect the recorded datagram. It reports the
// whole datagram as written. addr is ignored.
func (tl *testListenerUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
	saved := append([]byte(nil), b...)
	tl.sentDatagrams = append(tl.sentDatagrams, saved)
	return len(b), nil
}