quic: send and receive UDP datagrams
Add the Listener type, which manages a UDP socket.
For golang/go#58547
Change-Id: Ia23a8b726ef46f8f84c9e052aa4dfc10eab034d6
Reviewed-on: https://go-review.googlesource.com/c/net/+/527758
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index c24e790..0063965 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -20,13 +20,14 @@
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn struct {
side connSide
- listener connListener
+ listener *Listener
config *Config
testHooks connTestHooks
peerAddr netip.AddrPort
msgc chan any
donec chan struct{} // closed when conn loop exits
+ readyc chan struct{} // closed when TLS handshake completes
exited bool // set to make the conn loop exit immediately
w packetWriter
@@ -61,21 +62,16 @@
testSendPing sentVal
}
-// The connListener is the Conn's Listener.
-// Defined as an interface so we can swap it out in tests.
-type connListener interface {
- sendDatagram(p []byte, addr netip.AddrPort) error
-}
-
// connTestHooks override conn behavior in tests.
type connTestHooks interface {
nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any)
handleTLSEvent(tls.QUICEvent)
newConnID(seq int64) ([]byte, error)
waitUntil(ctx context.Context, until func() bool) error
+ timeNow() time.Time
}
-func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) {
+func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l *Listener, hooks connTestHooks) (*Conn, error) {
c := &Conn{
side: side,
listener: l,
@@ -83,6 +79,7 @@
peerAddr: peerAddr,
msgc: make(chan any, 1),
donec: make(chan struct{}),
+ readyc: make(chan struct{}),
testHooks: hooks,
maxIdleTimeout: defaultMaxIdleTimeout,
idleTimeout: now.Add(defaultMaxIdleTimeout),
@@ -94,12 +91,12 @@
c.msgc = make(chan any, 1)
if c.side == clientSide {
- if err := c.connIDState.initClient(c.newConnIDFunc()); err != nil {
+ if err := c.connIDState.initClient(c); err != nil {
return nil, err
}
initialConnID, _ = c.connIDState.dstConnID()
} else {
- if err := c.connIDState.initServer(c.newConnIDFunc(), initialConnID); err != nil {
+ if err := c.connIDState.initServer(c, initialConnID); err != nil {
return nil, err
}
}
@@ -134,6 +131,14 @@
return fmt.Sprintf("quic.Conn(%v,->%v)", c.side, c.peerAddr)
}
+func (c *Conn) Close() error {
+ // TODO: Implement shutdown for real.
+ c.runOnLoop(func(now time.Time, c *Conn) {
+ c.exited = true
+ })
+ return nil
+}
+
// confirmHandshake is called when the handshake is confirmed.
// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2
func (c *Conn) confirmHandshake(now time.Time) {
@@ -147,6 +152,7 @@
if c.side == serverSide {
// When the server confirms the handshake, it sends a HANDSHAKE_DONE.
c.handshakeConfirmed.setUnsent()
+ c.listener.serverConnEstablished(c)
} else {
// The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed
// to the received state, indicating that the handshake is confirmed and we
@@ -177,7 +183,7 @@
c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni
c.peerAckDelayExponent = p.ackDelayExponent
c.loss.setMaxAckDelay(p.maxAckDelay)
- if err := c.connIDState.setPeerActiveConnIDLimit(p.activeConnIDLimit, c.newConnIDFunc()); err != nil {
+ if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil {
return err
}
if p.preferredAddrConnID != nil {
@@ -211,6 +217,7 @@
func (c *Conn) loop(now time.Time) {
defer close(c.donec)
defer c.tls.Close()
+ defer c.listener.connDrained(c)
// The connection timer sends a message to the connection loop on expiry.
// We need to give it an expiry when creating it, so set the initial timeout to
@@ -371,10 +378,3 @@
return b
}
}
-
-func (c *Conn) newConnIDFunc() newConnIDFunc {
- if c.testHooks != nil {
- return c.testHooks.newConnID
- }
- return newRandomConnID
-}
diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go
index 561dea2..eb2f3ec 100644
--- a/internal/quic/conn_id.go
+++ b/internal/quic/conn_id.go
@@ -55,10 +55,10 @@
send sentVal
}
-func (s *connIDState) initClient(newID newConnIDFunc) error {
+func (s *connIDState) initClient(c *Conn) error {
// Client chooses its initial connection ID, and sends it
// in the Source Connection ID field of the first Initial packet.
- locid, err := newID(0)
+ locid, err := c.newConnID(0)
if err != nil {
return err
}
@@ -70,7 +70,7 @@
// Client chooses an initial, transient connection ID for the server,
// and sends it in the Destination Connection ID field of the first Initial packet.
- remid, err := newID(-1)
+ remid, err := c.newConnID(-1)
if err != nil {
return err
}
@@ -78,10 +78,12 @@
seq: -1,
cid: remid,
})
+ const retired = false
+ c.listener.connIDsChanged(c, retired, s.local[:])
return nil
}
-func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error {
+func (s *connIDState) initServer(c *Conn, dstConnID []byte) error {
// Client-chosen, transient connection ID received in the first Initial packet.
// The server will not use this as the Source Connection ID of packets it sends,
// but remembers it because it may receive packets sent to this destination.
@@ -92,7 +94,7 @@
// Server chooses a connection ID, and sends it in the Source Connection ID of
// the response to the clent.
- locid, err := newID(0)
+ locid, err := c.newConnID(0)
if err != nil {
return err
}
@@ -101,6 +103,8 @@
cid: locid,
})
s.nextLocalSeq = 1
+ const retired = false
+ c.listener.connIDsChanged(c, retired, s.local[:])
return nil
}
@@ -125,20 +129,21 @@
// setPeerActiveConnIDLimit sets the active_connection_id_limit
// transport parameter received from the peer.
-func (s *connIDState) setPeerActiveConnIDLimit(lim int64, newID newConnIDFunc) error {
+func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
s.peerActiveConnIDLimit = lim
- return s.issueLocalIDs(newID)
+ return s.issueLocalIDs(c)
}
-func (s *connIDState) issueLocalIDs(newID newConnIDFunc) error {
+func (s *connIDState) issueLocalIDs(c *Conn) error {
toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
for i := range s.local {
if s.local[i].seq != -1 && !s.local[i].retired {
toIssue--
}
}
+ prev := len(s.local)
for toIssue > 0 {
- cid, err := newID(s.nextLocalSeq)
+ cid, err := c.newConnID(s.nextLocalSeq)
if err != nil {
return err
}
@@ -151,14 +156,16 @@
s.needSend = true
toIssue--
}
+ const retired = false
+ c.listener.connIDsChanged(c, retired, s.local[prev:])
return nil
}
// handlePacket updates the connection ID state during the handshake
// (Initial and Handshake packets).
-func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID []byte) {
+func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
switch {
- case ptype == packetTypeInitial && side == clientSide:
+ case ptype == packetTypeInitial && c.side == clientSide:
if len(s.remote) == 1 && s.remote[0].seq == -1 {
// We're a client connection processing the first Initial packet
// from the server. Replace the transient remote connection ID
@@ -168,7 +175,7 @@
cid: cloneBytes(srcConnID),
}
}
- case ptype == packetTypeInitial && side == serverSide:
+ case ptype == packetTypeInitial && c.side == serverSide:
if len(s.remote) == 0 {
// We're a server connection processing the first Initial packet
// from the client. Set the client's connection ID.
@@ -177,11 +184,13 @@
cid: cloneBytes(srcConnID),
})
}
- case ptype == packetTypeHandshake && side == serverSide:
+ case ptype == packetTypeHandshake && c.side == serverSide:
if len(s.local) > 0 && s.local[0].seq == -1 {
// We're a server connection processing the first Handshake packet from
// the client. Discard the transient, client-chosen connection ID used
// for Initial packets; the client will never send it again.
+ const retired = true
+ c.listener.connIDsChanged(c, retired, s.local[0:1])
s.local = append(s.local[:0], s.local[1:]...)
}
}
@@ -263,17 +272,19 @@
s.needSend = true
}
-func (s *connIDState) handleRetireConnID(seq int64, newID newConnIDFunc) error {
+func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
if seq >= s.nextLocalSeq {
return localTransportError(errProtocolViolation)
}
for i := range s.local {
if s.local[i].seq == seq {
+ const retired = true
+ c.listener.connIDsChanged(c, retired, s.local[i:i+1])
s.local = append(s.local[:i], s.local[i+1:]...)
break
}
}
- s.issueLocalIDs(newID)
+ s.issueLocalIDs(c)
return nil
}
@@ -355,7 +366,12 @@
return n
}
-type newConnIDFunc func(seq int64) ([]byte, error)
+func (c *Conn) newConnID(seq int64) ([]byte, error) {
+ if c.testHooks != nil {
+ return c.testHooks.newConnID(seq)
+ }
+ return newRandomConnID(seq)
+}
func newRandomConnID(_ int64) ([]byte, error) {
// It is not necessary for connection IDs to be cryptographically secure,
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
index d479cd4..c528958 100644
--- a/internal/quic/conn_id_test.go
+++ b/internal/quic/conn_id_test.go
@@ -11,100 +11,135 @@
"crypto/tls"
"fmt"
"net/netip"
- "reflect"
+ "strings"
"testing"
)
func TestConnIDClientHandshake(t *testing.T) {
+ tc := newTestConn(t, clientSide)
// On initialization, the client chooses local and remote IDs.
//
// The order in which we allocate the two isn't actually important,
// but test is a lot simpler if we assume.
- var s connIDState
- s.initClient(newConnIDSequence())
- if got, want := string(s.srcConnID()), "local-1"; got != want {
- t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
+ if got, want := tc.conn.connIDState.srcConnID(), testLocalConnID(0); !bytes.Equal(got, want) {
+ t.Errorf("after initialization: srcConnID = %x, want %x", got, want)
}
- dstConnID, _ := s.dstConnID()
- if got, want := string(dstConnID), "local-2"; got != want {
- t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
+ dstConnID, _ := tc.conn.connIDState.dstConnID()
+ if got, want := dstConnID, testLocalConnID(-1); !bytes.Equal(got, want) {
+ t.Errorf("after initialization: dstConnID = %x, want %x", got, want)
}
// The server's first Initial packet provides the client with a
// non-transient remote connection ID.
- s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1"))
- dstConnID, _ = s.dstConnID()
- if got, want := string(dstConnID), "remote-1"; got != want {
- t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want)
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ dstConnID, _ = tc.conn.connIDState.dstConnID()
+ if got, want := dstConnID, testPeerConnID(0); !bytes.Equal(got, want) {
+ t.Errorf("after receiving Initial: dstConnID = %x, want %x", got, want)
}
wantLocal := []connID{{
- cid: []byte("local-1"),
+ cid: testLocalConnID(0),
seq: 0,
}}
- if !reflect.DeepEqual(s.local, wantLocal) {
- t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+ if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
+ t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
wantRemote := []connID{{
- cid: []byte("remote-1"),
+ cid: testPeerConnID(0),
seq: 0,
}}
- if !reflect.DeepEqual(s.remote, wantRemote) {
- t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+ if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
}
}
func TestConnIDServerHandshake(t *testing.T) {
+ tc := newTestConn(t, serverSide)
// On initialization, the server is provided with the client-chosen
// transient connection ID, and allocates an ID of its own.
// The Initial packet sets the remote connection ID.
- var s connIDState
- s.initServer(newConnIDSequence(), []byte("transient"))
- s.handlePacket(serverSide, packetTypeInitial, []byte("remote-1"))
- if got, want := string(s.srcConnID()), "local-1"; got != want {
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial][:1],
+ })
+ if got, want := tc.conn.connIDState.srcConnID(), testLocalConnID(0); !bytes.Equal(got, want) {
t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
}
- dstConnID, _ := s.dstConnID()
- if got, want := string(dstConnID), "remote-1"; got != want {
+ dstConnID, _ := tc.conn.connIDState.dstConnID()
+ if got, want := dstConnID, testPeerConnID(0); !bytes.Equal(got, want) {
t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
}
+ // The Initial flight of CRYPTO data includes transport parameters,
+ // which cause us to allocate another local connection ID.
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ off: 1,
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial][1:],
+ })
wantLocal := []connID{{
- cid: []byte("transient"),
+ cid: testPeerConnID(-1),
seq: -1,
}, {
- cid: []byte("local-1"),
+ cid: testLocalConnID(0),
seq: 0,
+ }, {
+ cid: testLocalConnID(1),
+ seq: 1,
}}
- if !reflect.DeepEqual(s.local, wantLocal) {
- t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+ if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
+ t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
wantRemote := []connID{{
- cid: []byte("remote-1"),
+ cid: testPeerConnID(0),
seq: 0,
}}
- if !reflect.DeepEqual(s.remote, wantRemote) {
- t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+ if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
}
// The client's first Handshake packet permits the server to discard the
// transient connection ID.
- s.handlePacket(serverSide, packetTypeHandshake, []byte("remote-1"))
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
wantLocal = []connID{{
- cid: []byte("local-1"),
+ cid: testLocalConnID(0),
seq: 0,
+ }, {
+ cid: testLocalConnID(1),
+ seq: 1,
}}
- if !reflect.DeepEqual(s.local, wantLocal) {
- t.Errorf("after handshake local ids: %v, want %v", s.local, wantLocal)
+ if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
+ t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
}
-func newConnIDSequence() newConnIDFunc {
- var n uint64
- return func(_ int64) ([]byte, error) {
- n++
- return []byte(fmt.Sprintf("local-%v", n)), nil
+func connIDListEqual(a, b []connID) bool {
+ if len(a) != len(b) {
+ return false
}
+ for i := range a {
+ if a[i].seq != b[i].seq {
+ return false
+ }
+ if !bytes.Equal(a[i].cid, b[i].cid) {
+ return false
+ }
+ }
+ return true
+}
+
+func fmtConnIDList(s []connID) string {
+ var strs []string
+ for _, cid := range s {
+ strs = append(strs, fmt.Sprintf("[seq:%v cid:{%x}]", cid.seq, cid.cid))
+ }
+ return "{" + strings.Join(strs, " ") + "}"
}
func TestNewRandomConnID(t *testing.T) {
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index 07f17e3..b866d8a 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -63,7 +63,7 @@
if logPackets {
logInboundLongPacket(c, p)
}
- c.connIDState.handlePacket(c.side, p.ptype, p.srcConnID)
+ c.connIDState.handlePacket(c, p.ptype, p.srcConnID)
ackEliciting := c.handleFrames(now, ptype, space, p.payload)
c.acks[space].receive(now, space, p.num, ackEliciting)
if p.ptype == packetTypeHandshake && c.side == serverSide {
@@ -377,7 +377,7 @@
if n < 0 {
return -1
}
- if err := c.connIDState.handleRetireConnID(seq, c.newConnIDFunc()); err != nil {
+ if err := c.connIDState.handleRetireConnID(c, seq); err != nil {
c.abort(now, err)
}
return n
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index ea720d5..cdbd466 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -13,7 +13,9 @@
"errors"
"flag"
"fmt"
+ "io"
"math"
+ "net"
"net/netip"
"reflect"
"strings"
@@ -105,6 +107,7 @@
type testConn struct {
t *testing.T
conn *Conn
+ listener *Listener
now time.Time
timer time.Time
timerLastFired time.Time
@@ -142,6 +145,8 @@
sentFrames []debugFrame
lastPacket *testPacket
+ recvDatagram chan *datagram
+
// Transport parameters sent by the conn.
sentTransportParameters *transportParameters
@@ -173,6 +178,7 @@
},
cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
cryptoDataIn: make(map[tls.QUICEncryptionLevel][]byte),
+ recvDatagram: make(chan *datagram),
}
t.Cleanup(tc.cleanup)
@@ -196,12 +202,7 @@
var initialConnID []byte
if side == serverSide {
// The initial connection ID for the server is chosen by the client.
- // When creating a server-side connection, pick a random connection ID here.
- var err error
- initialConnID, err = newRandomConnID(0)
- if err != nil {
- tc.t.Fatal(err)
- }
+ initialConnID = testPeerConnID(-1)
}
peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(side.peer())}
@@ -213,14 +214,12 @@
tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
tc.peerTLSConn.Start(context.Background())
- conn, err := newConn(
+ tc.listener = newListener((*testConnUDPConn)(tc), config, (*testConnHooks)(tc))
+ conn, err := tc.listener.newConn(
tc.now,
side,
initialConnID,
- netip.MustParseAddrPort("127.0.0.1:443"),
- config,
- (*testConnListener)(tc),
- (*testConnHooks)(tc))
+ netip.MustParseAddrPort("127.0.0.1:443"))
if err != nil {
tc.t.Fatal(err)
}
@@ -316,6 +315,7 @@
return
}
tc.conn.exit()
+ tc.listener.Close(context.Background())
}
func (tc *testConn) logDatagram(text string, d *testDatagram) {
@@ -844,6 +844,10 @@
return testLocalConnID(seq), nil
}
+func (tc *testConnHooks) timeNow() time.Time {
+ return tc.now
+}
+
// testLocalConnID returns the connection ID with a given sequence number
// used by a Conn under test.
func testLocalConnID(seq int64) []byte {
@@ -861,14 +865,31 @@
return []byte{0xbe, 0xee, 0xff, byte(seq)}
}
-// testConnListener implements connListener.
-type testConnListener testConn
+// testConnUDPConn implements UDPConn.
+type testConnUDPConn testConn
-func (tc *testConnListener) sendDatagram(p []byte, addr netip.AddrPort) error {
- tc.sentDatagrams = append(tc.sentDatagrams, append([]byte(nil), p...))
+func (tc *testConnUDPConn) Close() error {
+ close(tc.recvDatagram)
return nil
}
+func (tc *testConnUDPConn) LocalAddr() net.Addr {
+ return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443"))
+}
+
+func (tc *testConnUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) {
+ for d := range tc.recvDatagram {
+ n = copy(b, d.b)
+ return n, 0, 0, d.addr, nil
+ }
+ return 0, 0, 0, netip.AddrPort{}, io.EOF
+}
+
+func (tc *testConnUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
+ tc.sentDatagrams = append(tc.sentDatagrams, append([]byte(nil), b...))
+ return len(b), nil
+}
+
// canceledContext returns a canceled Context.
//
// Functions which take a context preference progress over cancelation.
diff --git a/internal/quic/listener.go b/internal/quic/listener.go
new file mode 100644
index 0000000..9869f6e
--- /dev/null
+++ b/internal/quic/listener.go
@@ -0,0 +1,280 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "context"
+ "errors"
+ "net"
+ "net/netip"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// A Listener listens for QUIC traffic on a network address.
+// It can accept inbound connections or create outbound ones.
+//
+// Multiple goroutines may invoke methods on a Listener simultaneously.
+type Listener struct {
+ config *Config
+ udpConn udpConn
+ testHooks connTestHooks
+
+ acceptQueue queue[*Conn] // new inbound connections
+
+ connsMu sync.Mutex
+ conns map[*Conn]struct{}
+ closing bool // set when Close is called
+ closec chan struct{} // closed when the listen loop exits
+
+ // The datagram receive loop keeps a mapping of connection IDs to conns.
+ // When a conn's connection IDs change, we add it to connIDUpdates and set
+ // connIDUpdateNeeded, and the receive loop updates its map.
+ connIDUpdateMu sync.Mutex
+ connIDUpdateNeeded atomic.Bool
+ connIDUpdates []connIDUpdate
+}
+
+// A udpConn is a UDP connection.
+// It is implemented by net.UDPConn.
+type udpConn interface {
+ Close() error
+ LocalAddr() net.Addr
+ ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error)
+ WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error)
+}
+
+type connIDUpdate struct {
+ conn *Conn
+ retired bool
+ cid []byte
+}
+
+// Listen listens on a local network address.
+// The configuration config must be non-nil.
+func Listen(network, address string, config *Config) (*Listener, error) {
+ if config.TLSConfig == nil {
+ return nil, errors.New("TLSConfig is not set")
+ }
+ a, err := net.ResolveUDPAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ udpConn, err := net.ListenUDP(network, a)
+ if err != nil {
+ return nil, err
+ }
+ return newListener(udpConn, config, nil), nil
+}
+
+func newListener(udpConn udpConn, config *Config, hooks connTestHooks) *Listener {
+ l := &Listener{
+ config: config,
+ udpConn: udpConn,
+ testHooks: hooks,
+ conns: make(map[*Conn]struct{}),
+ acceptQueue: newQueue[*Conn](),
+ closec: make(chan struct{}),
+ }
+ go l.listen()
+ return l
+}
+
+// LocalAddr returns the local network address.
+func (l *Listener) LocalAddr() netip.AddrPort {
+ a, _ := l.udpConn.LocalAddr().(*net.UDPAddr)
+ return a.AddrPort()
+}
+
+// Close closes the listener.
+// Any blocked operations on the Listener or associated Conns and Stream will be unblocked
+// and return errors.
+//
+// Close aborts every open connection.
+// Data in stream read and write buffers is discarded.
+// It waits for the peers of any open connection to acknowledge the connection has been closed.
+func (l *Listener) Close(ctx context.Context) error {
+ l.acceptQueue.close(errors.New("listener closed"))
+ l.connsMu.Lock()
+ if !l.closing {
+ l.closing = true
+ for c := range l.conns {
+ c.Close()
+ }
+ if len(l.conns) == 0 {
+ l.udpConn.Close()
+ }
+ }
+ l.connsMu.Unlock()
+ select {
+ case <-l.closec:
+ case <-ctx.Done():
+ l.connsMu.Lock()
+ for c := range l.conns {
+ c.exit()
+ }
+ l.connsMu.Unlock()
+ return ctx.Err()
+ }
+ return nil
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (l *Listener) Accept(ctx context.Context) (*Conn, error) {
+ return l.acceptQueue.get(ctx, nil)
+}
+
+// Dial creates and returns a connection to a network address.
+func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, error) {
+ u, err := net.ResolveUDPAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ addr := u.AddrPort()
+ addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
+ c, err := l.newConn(time.Now(), clientSide, nil, addr)
+ if err != nil {
+ return nil, err
+ }
+ select {
+ case <-c.readyc:
+ case <-ctx.Done():
+ c.Close()
+ return nil, ctx.Err()
+ }
+ return c, nil
+}
+
+func (l *Listener) newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort) (*Conn, error) {
+ l.connsMu.Lock()
+ defer l.connsMu.Unlock()
+ if l.closing {
+ return nil, errors.New("listener closed")
+ }
+ c, err := newConn(now, side, initialConnID, peerAddr, l.config, l, l.testHooks)
+ if err != nil {
+ return nil, err
+ }
+ l.conns[c] = struct{}{}
+ return c, nil
+}
+
+// serverConnEstablished is called by a conn when the handshake completes
+// for an inbound (serverSide) connection.
+func (l *Listener) serverConnEstablished(c *Conn) {
+ l.acceptQueue.put(c)
+}
+
+// connDrained is called by a conn when it leaves the draining state,
+// either when the peer acknowledges connection closure or the drain timeout expires.
+func (l *Listener) connDrained(c *Conn) {
+ l.connsMu.Lock()
+ defer l.connsMu.Unlock()
+ delete(l.conns, c)
+ if l.closing && len(l.conns) == 0 {
+ l.udpConn.Close()
+ }
+}
+
+// connIDsChanged is called by a conn when its connection IDs change.
+func (l *Listener) connIDsChanged(c *Conn, retired bool, cids []connID) {
+ l.connIDUpdateMu.Lock()
+ defer l.connIDUpdateMu.Unlock()
+ for _, cid := range cids {
+ l.connIDUpdates = append(l.connIDUpdates, connIDUpdate{
+ conn: c,
+ retired: retired,
+ cid: cid.cid,
+ })
+ }
+ l.connIDUpdateNeeded.Store(true)
+}
+
+// updateConnIDs is called by the datagram receive loop to update its connection ID map.
+func (l *Listener) updateConnIDs(conns map[string]*Conn) {
+ l.connIDUpdateMu.Lock()
+ defer l.connIDUpdateMu.Unlock()
+ for i, u := range l.connIDUpdates {
+ if u.retired {
+ delete(conns, string(u.cid))
+ } else {
+ conns[string(u.cid)] = u.conn
+ }
+ l.connIDUpdates[i] = connIDUpdate{} // drop refs
+ }
+ l.connIDUpdates = l.connIDUpdates[:0]
+ l.connIDUpdateNeeded.Store(false)
+}
+
+func (l *Listener) listen() {
+ defer close(l.closec)
+ conns := map[string]*Conn{}
+ for {
+ m := newDatagram()
+ // TODO: Read and process the ECN (explicit congestion notification) field.
+ // https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4
+ n, _, _, addr, err := l.udpConn.ReadMsgUDPAddrPort(m.b, nil)
+ if err != nil {
+ // The user has probably closed the listener.
+ // We currently don't surface errors from other causes;
+ // we could check to see if the listener has been closed and
+ // record the unexpected error if it has not.
+ return
+ }
+ if n == 0 {
+ continue
+ }
+ if l.connIDUpdateNeeded.Load() {
+ l.updateConnIDs(conns)
+ }
+ m.addr = addr
+ m.b = m.b[:n]
+ l.handleDatagram(m, conns)
+ }
+}
+
+func (l *Listener) handleDatagram(m *datagram, conns map[string]*Conn) {
+ dstConnID, ok := dstConnIDForDatagram(m.b)
+ if !ok {
+ return
+ }
+ c := conns[string(dstConnID)]
+ if c == nil {
+ if getPacketType(m.b) != packetTypeInitial {
+ // This packet isn't trying to create a new connection.
+ // It might be associated with some connection we've lost state for.
+ // TODO: Send a stateless reset when appropriate.
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.3
+ return
+ }
+ var now time.Time
+ if l.testHooks != nil {
+ now = l.testHooks.timeNow()
+ } else {
+ now = time.Now()
+ }
+ var err error
+ c, err = l.newConn(now, serverSide, dstConnID, m.addr)
+ if err != nil {
+ // The accept queue is probably full.
+ // We could send a CONNECTION_CLOSE to the peer to reject the connection.
+ // Currently, we just drop the datagram.
+ // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5
+ return
+ }
+ }
+
+ // TODO: This can block the listener while waiting for the conn to accept the dgram.
+ // Think about buffering between the receive loop and the conn.
+ c.sendMsg(m)
+}
+
+func (l *Listener) sendDatagram(p []byte, addr netip.AddrPort) error {
+ _, err := l.udpConn.WriteToUDPAddrPort(p, addr)
+ return err
+}
diff --git a/internal/quic/listener_test.go b/internal/quic/listener_test.go
new file mode 100644
index 0000000..a6e0b34
--- /dev/null
+++ b/internal/quic/listener_test.go
@@ -0,0 +1,88 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "testing"
+)
+
+func TestConnect(t *testing.T) {
+ newLocalConnPair(t, &Config{}, &Config{})
+}
+
+func TestStreamTransfer(t *testing.T) {
+ ctx := context.Background()
+ cli, srv := newLocalConnPair(t, &Config{}, &Config{})
+ data := makeTestData(1 << 20)
+
+ srvdone := make(chan struct{})
+ go func() {
+ defer close(srvdone)
+ s, err := srv.AcceptStream(ctx)
+ if err != nil {
+ t.Errorf("AcceptStream: %v", err)
+ return
+ }
+ b, err := io.ReadAll(s)
+ if err != nil {
+ t.Errorf("io.ReadAll(s): %v", err)
+ return
+ }
+ if !bytes.Equal(b, data) {
+ t.Errorf("read data mismatch (got %v bytes, want %v", len(b), len(data))
+ }
+ if err := s.Close(); err != nil {
+ t.Errorf("s.Close() = %v", err)
+ }
+ }()
+
+ s, err := cli.NewStream(ctx)
+ if err != nil {
+ t.Fatalf("NewStream: %v", err)
+ }
+ n, err := io.Copy(s, bytes.NewBuffer(data))
+ if n != int64(len(data)) || err != nil {
+ t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
+ }
+ if err := s.Close(); err != nil {
+ t.Fatalf("s.Close() = %v", err)
+ }
+}
+
+func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
+ t.Helper()
+ ctx := context.Background()
+ l1 := newLocalListener(t, serverSide, conf1)
+ l2 := newLocalListener(t, clientSide, conf2)
+ c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c1, err := l1.Accept(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return c2, c1
+}
+
+func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener {
+ t.Helper()
+ if conf.TLSConfig == nil {
+ conf.TLSConfig = newTestTLSConfig(side)
+ }
+ l, err := Listen("udp", "127.0.0.1:0", conf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ l.Close(context.Background())
+ })
+ return l
+}
diff --git a/internal/quic/tls.go b/internal/quic/tls.go
index 584316f..1d07f17 100644
--- a/internal/quic/tls.go
+++ b/internal/quic/tls.go
@@ -73,6 +73,7 @@
// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2-1
c.confirmHandshake(now)
}
+ close(c.readyc)
case tls.QUICTransportParameters:
params, err := unmarshalTransportParams(e.Data)
if err != nil {