quic: send and receive UDP datagrams

Add the Listener type, which manages a UDP socket.

For golang/go#58547

Change-Id: Ia23a8b726ef46f8f84c9e052aa4dfc10eab034d6
Reviewed-on: https://go-review.googlesource.com/c/net/+/527758
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index c24e790..0063965 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -20,13 +20,14 @@
 // Multiple goroutines may invoke methods on a Conn simultaneously.
 type Conn struct {
 	side      connSide
-	listener  connListener
+	listener  *Listener
 	config    *Config
 	testHooks connTestHooks
 	peerAddr  netip.AddrPort
 
 	msgc   chan any
 	donec  chan struct{} // closed when conn loop exits
+	readyc chan struct{} // closed when TLS handshake completes
 	exited bool          // set to make the conn loop exit immediately
 
 	w           packetWriter
@@ -61,21 +62,16 @@
 	testSendPing      sentVal
 }
 
-// The connListener is the Conn's Listener.
-// Defined as an interface so we can swap it out in tests.
-type connListener interface {
-	sendDatagram(p []byte, addr netip.AddrPort) error
-}
-
 // connTestHooks override conn behavior in tests.
 type connTestHooks interface {
 	nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any)
 	handleTLSEvent(tls.QUICEvent)
 	newConnID(seq int64) ([]byte, error)
 	waitUntil(ctx context.Context, until func() bool) error
+	timeNow() time.Time
 }
 
-func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) {
+func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l *Listener, hooks connTestHooks) (*Conn, error) {
 	c := &Conn{
 		side:                 side,
 		listener:             l,
@@ -83,6 +79,7 @@
 		peerAddr:             peerAddr,
 		msgc:                 make(chan any, 1),
 		donec:                make(chan struct{}),
+		readyc:               make(chan struct{}),
 		testHooks:            hooks,
 		maxIdleTimeout:       defaultMaxIdleTimeout,
 		idleTimeout:          now.Add(defaultMaxIdleTimeout),
@@ -94,12 +91,12 @@
 	c.msgc = make(chan any, 1)
 
 	if c.side == clientSide {
-		if err := c.connIDState.initClient(c.newConnIDFunc()); err != nil {
+		if err := c.connIDState.initClient(c); err != nil {
 			return nil, err
 		}
 		initialConnID, _ = c.connIDState.dstConnID()
 	} else {
-		if err := c.connIDState.initServer(c.newConnIDFunc(), initialConnID); err != nil {
+		if err := c.connIDState.initServer(c, initialConnID); err != nil {
 			return nil, err
 		}
 	}
@@ -134,6 +131,14 @@
 	return fmt.Sprintf("quic.Conn(%v,->%v)", c.side, c.peerAddr)
 }
 
+func (c *Conn) Close() error {
+	// TODO: Implement shutdown for real.
+	c.runOnLoop(func(now time.Time, c *Conn) {
+		c.exited = true
+	})
+	return nil
+}
+
 // confirmHandshake is called when the handshake is confirmed.
 // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2
 func (c *Conn) confirmHandshake(now time.Time) {
@@ -147,6 +152,7 @@
 	if c.side == serverSide {
 		// When the server confirms the handshake, it sends a HANDSHAKE_DONE.
 		c.handshakeConfirmed.setUnsent()
+		c.listener.serverConnEstablished(c)
 	} else {
 		// The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed
 		// to the received state, indicating that the handshake is confirmed and we
@@ -177,7 +183,7 @@
 	c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni
 	c.peerAckDelayExponent = p.ackDelayExponent
 	c.loss.setMaxAckDelay(p.maxAckDelay)
-	if err := c.connIDState.setPeerActiveConnIDLimit(p.activeConnIDLimit, c.newConnIDFunc()); err != nil {
+	if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil {
 		return err
 	}
 	if p.preferredAddrConnID != nil {
@@ -211,6 +217,7 @@
 func (c *Conn) loop(now time.Time) {
 	defer close(c.donec)
 	defer c.tls.Close()
+	defer c.listener.connDrained(c)
 
 	// The connection timer sends a message to the connection loop on expiry.
 	// We need to give it an expiry when creating it, so set the initial timeout to
@@ -371,10 +378,3 @@
 		return b
 	}
 }
-
-func (c *Conn) newConnIDFunc() newConnIDFunc {
-	if c.testHooks != nil {
-		return c.testHooks.newConnID
-	}
-	return newRandomConnID
-}
diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go
index 561dea2..eb2f3ec 100644
--- a/internal/quic/conn_id.go
+++ b/internal/quic/conn_id.go
@@ -55,10 +55,10 @@
 	send sentVal
 }
 
-func (s *connIDState) initClient(newID newConnIDFunc) error {
+func (s *connIDState) initClient(c *Conn) error {
 	// Client chooses its initial connection ID, and sends it
 	// in the Source Connection ID field of the first Initial packet.
-	locid, err := newID(0)
+	locid, err := c.newConnID(0)
 	if err != nil {
 		return err
 	}
@@ -70,7 +70,7 @@
 
 	// Client chooses an initial, transient connection ID for the server,
 	// and sends it in the Destination Connection ID field of the first Initial packet.
-	remid, err := newID(-1)
+	remid, err := c.newConnID(-1)
 	if err != nil {
 		return err
 	}
@@ -78,10 +78,12 @@
 		seq: -1,
 		cid: remid,
 	})
+	const retired = false
+	c.listener.connIDsChanged(c, retired, s.local[:])
 	return nil
 }
 
-func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error {
+func (s *connIDState) initServer(c *Conn, dstConnID []byte) error {
 	// Client-chosen, transient connection ID received in the first Initial packet.
 	// The server will not use this as the Source Connection ID of packets it sends,
 	// but remembers it because it may receive packets sent to this destination.
@@ -92,7 +94,7 @@
 
 	// Server chooses a connection ID, and sends it in the Source Connection ID of
 	// the response to the clent.
-	locid, err := newID(0)
+	locid, err := c.newConnID(0)
 	if err != nil {
 		return err
 	}
@@ -101,6 +103,8 @@
 		cid: locid,
 	})
 	s.nextLocalSeq = 1
+	const retired = false
+	c.listener.connIDsChanged(c, retired, s.local[:])
 	return nil
 }
 
@@ -125,20 +129,21 @@
 
 // setPeerActiveConnIDLimit sets the active_connection_id_limit
 // transport parameter received from the peer.
-func (s *connIDState) setPeerActiveConnIDLimit(lim int64, newID newConnIDFunc) error {
+func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
 	s.peerActiveConnIDLimit = lim
-	return s.issueLocalIDs(newID)
+	return s.issueLocalIDs(c)
 }
 
-func (s *connIDState) issueLocalIDs(newID newConnIDFunc) error {
+func (s *connIDState) issueLocalIDs(c *Conn) error {
 	toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
 	for i := range s.local {
 		if s.local[i].seq != -1 && !s.local[i].retired {
 			toIssue--
 		}
 	}
+	prev := len(s.local)
 	for toIssue > 0 {
-		cid, err := newID(s.nextLocalSeq)
+		cid, err := c.newConnID(s.nextLocalSeq)
 		if err != nil {
 			return err
 		}
@@ -151,14 +156,16 @@
 		s.needSend = true
 		toIssue--
 	}
+	const retired = false
+	c.listener.connIDsChanged(c, retired, s.local[prev:])
 	return nil
 }
 
 // handlePacket updates the connection ID state during the handshake
 // (Initial and Handshake packets).
-func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID []byte) {
+func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
 	switch {
-	case ptype == packetTypeInitial && side == clientSide:
+	case ptype == packetTypeInitial && c.side == clientSide:
 		if len(s.remote) == 1 && s.remote[0].seq == -1 {
 			// We're a client connection processing the first Initial packet
 			// from the server. Replace the transient remote connection ID
@@ -168,7 +175,7 @@
 				cid: cloneBytes(srcConnID),
 			}
 		}
-	case ptype == packetTypeInitial && side == serverSide:
+	case ptype == packetTypeInitial && c.side == serverSide:
 		if len(s.remote) == 0 {
 			// We're a server connection processing the first Initial packet
 			// from the client. Set the client's connection ID.
@@ -177,11 +184,13 @@
 				cid: cloneBytes(srcConnID),
 			})
 		}
-	case ptype == packetTypeHandshake && side == serverSide:
+	case ptype == packetTypeHandshake && c.side == serverSide:
 		if len(s.local) > 0 && s.local[0].seq == -1 {
 			// We're a server connection processing the first Handshake packet from
 			// the client. Discard the transient, client-chosen connection ID used
 			// for Initial packets; the client will never send it again.
+			const retired = true
+			c.listener.connIDsChanged(c, retired, s.local[0:1])
 			s.local = append(s.local[:0], s.local[1:]...)
 		}
 	}
@@ -263,17 +272,19 @@
 	s.needSend = true
 }
 
-func (s *connIDState) handleRetireConnID(seq int64, newID newConnIDFunc) error {
+func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
 	if seq >= s.nextLocalSeq {
 		return localTransportError(errProtocolViolation)
 	}
 	for i := range s.local {
 		if s.local[i].seq == seq {
+			const retired = true
+			c.listener.connIDsChanged(c, retired, s.local[i:i+1])
 			s.local = append(s.local[:i], s.local[i+1:]...)
 			break
 		}
 	}
-	s.issueLocalIDs(newID)
+	s.issueLocalIDs(c)
 	return nil
 }
 
@@ -355,7 +366,12 @@
 	return n
 }
 
-type newConnIDFunc func(seq int64) ([]byte, error)
+func (c *Conn) newConnID(seq int64) ([]byte, error) {
+	if c.testHooks != nil {
+		return c.testHooks.newConnID(seq)
+	}
+	return newRandomConnID(seq)
+}
 
 func newRandomConnID(_ int64) ([]byte, error) {
 	// It is not necessary for connection IDs to be cryptographically secure,
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
index d479cd4..c528958 100644
--- a/internal/quic/conn_id_test.go
+++ b/internal/quic/conn_id_test.go
@@ -11,100 +11,135 @@
 	"crypto/tls"
 	"fmt"
 	"net/netip"
-	"reflect"
+	"strings"
 	"testing"
 )
 
 func TestConnIDClientHandshake(t *testing.T) {
+	tc := newTestConn(t, clientSide)
 	// On initialization, the client chooses local and remote IDs.
 	//
 	// The order in which we allocate the two isn't actually important,
 	// but test is a lot simpler if we assume.
-	var s connIDState
-	s.initClient(newConnIDSequence())
-	if got, want := string(s.srcConnID()), "local-1"; got != want {
-		t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
+	if got, want := tc.conn.connIDState.srcConnID(), testLocalConnID(0); !bytes.Equal(got, want) {
+		t.Errorf("after initialization: srcConnID = %x, want %x", got, want)
 	}
-	dstConnID, _ := s.dstConnID()
-	if got, want := string(dstConnID), "local-2"; got != want {
-		t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
+	dstConnID, _ := tc.conn.connIDState.dstConnID()
+	if got, want := dstConnID, testLocalConnID(-1); !bytes.Equal(got, want) {
+		t.Errorf("after initialization: dstConnID = %x, want %x", got, want)
 	}
 
 	// The server's first Initial packet provides the client with a
 	// non-transient remote connection ID.
-	s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1"))
-	dstConnID, _ = s.dstConnID()
-	if got, want := string(dstConnID), "remote-1"; got != want {
-		t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want)
+	tc.writeFrames(packetTypeInitial,
+		debugFrameCrypto{
+			data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+		})
+	dstConnID, _ = tc.conn.connIDState.dstConnID()
+	if got, want := dstConnID, testPeerConnID(0); !bytes.Equal(got, want) {
+		t.Errorf("after receiving Initial: dstConnID = %x, want %x", got, want)
 	}
 
 	wantLocal := []connID{{
-		cid: []byte("local-1"),
+		cid: testLocalConnID(0),
 		seq: 0,
 	}}
-	if !reflect.DeepEqual(s.local, wantLocal) {
-		t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+	if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
+		t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
 	}
 	wantRemote := []connID{{
-		cid: []byte("remote-1"),
+		cid: testPeerConnID(0),
 		seq: 0,
 	}}
-	if !reflect.DeepEqual(s.remote, wantRemote) {
-		t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+	if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
+		t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
 	}
 }
 
 func TestConnIDServerHandshake(t *testing.T) {
+	tc := newTestConn(t, serverSide)
 	// On initialization, the server is provided with the client-chosen
 	// transient connection ID, and allocates an ID of its own.
 	// The Initial packet sets the remote connection ID.
-	var s connIDState
-	s.initServer(newConnIDSequence(), []byte("transient"))
-	s.handlePacket(serverSide, packetTypeInitial, []byte("remote-1"))
-	if got, want := string(s.srcConnID()), "local-1"; got != want {
+	tc.writeFrames(packetTypeInitial,
+		debugFrameCrypto{
+			data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial][:1],
+		})
+	if got, want := tc.conn.connIDState.srcConnID(), testLocalConnID(0); !bytes.Equal(got, want) {
 		t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
 	}
-	dstConnID, _ := s.dstConnID()
-	if got, want := string(dstConnID), "remote-1"; got != want {
+	dstConnID, _ := tc.conn.connIDState.dstConnID()
+	if got, want := dstConnID, testPeerConnID(0); !bytes.Equal(got, want) {
 		t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
 	}
 
+	// The Initial flight of CRYPTO data includes transport parameters,
+	// which cause us to allocate another local connection ID.
+	tc.writeFrames(packetTypeInitial,
+		debugFrameCrypto{
+			off:  1,
+			data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial][1:],
+		})
 	wantLocal := []connID{{
-		cid: []byte("transient"),
+		cid: testPeerConnID(-1),
 		seq: -1,
 	}, {
-		cid: []byte("local-1"),
+		cid: testLocalConnID(0),
 		seq: 0,
+	}, {
+		cid: testLocalConnID(1),
+		seq: 1,
 	}}
-	if !reflect.DeepEqual(s.local, wantLocal) {
-		t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+	if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
+		t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
 	}
 	wantRemote := []connID{{
-		cid: []byte("remote-1"),
+		cid: testPeerConnID(0),
 		seq: 0,
 	}}
-	if !reflect.DeepEqual(s.remote, wantRemote) {
-		t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+	if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
+		t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
 	}
 
 	// The client's first Handshake packet permits the server to discard the
 	// transient connection ID.
-	s.handlePacket(serverSide, packetTypeHandshake, []byte("remote-1"))
+	tc.writeFrames(packetTypeHandshake,
+		debugFrameCrypto{
+			data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+		})
 	wantLocal = []connID{{
-		cid: []byte("local-1"),
+		cid: testLocalConnID(0),
 		seq: 0,
+	}, {
+		cid: testLocalConnID(1),
+		seq: 1,
 	}}
-	if !reflect.DeepEqual(s.local, wantLocal) {
-		t.Errorf("after handshake local ids: %v, want %v", s.local, wantLocal)
+	if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
+		t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
 	}
 }
 
-func newConnIDSequence() newConnIDFunc {
-	var n uint64
-	return func(_ int64) ([]byte, error) {
-		n++
-		return []byte(fmt.Sprintf("local-%v", n)), nil
+func connIDListEqual(a, b []connID) bool {
+	if len(a) != len(b) {
+		return false
 	}
+	for i := range a {
+		if a[i].seq != b[i].seq {
+			return false
+		}
+		if !bytes.Equal(a[i].cid, b[i].cid) {
+			return false
+		}
+	}
+	return true
+}
+
+func fmtConnIDList(s []connID) string {
+	var strs []string
+	for _, cid := range s {
+		strs = append(strs, fmt.Sprintf("[seq:%v cid:{%x}]", cid.seq, cid.cid))
+	}
+	return "{" + strings.Join(strs, " ") + "}"
 }
 
 func TestNewRandomConnID(t *testing.T) {
diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go
index 07f17e3..b866d8a 100644
--- a/internal/quic/conn_recv.go
+++ b/internal/quic/conn_recv.go
@@ -63,7 +63,7 @@
 	if logPackets {
 		logInboundLongPacket(c, p)
 	}
-	c.connIDState.handlePacket(c.side, p.ptype, p.srcConnID)
+	c.connIDState.handlePacket(c, p.ptype, p.srcConnID)
 	ackEliciting := c.handleFrames(now, ptype, space, p.payload)
 	c.acks[space].receive(now, space, p.num, ackEliciting)
 	if p.ptype == packetTypeHandshake && c.side == serverSide {
@@ -377,7 +377,7 @@
 	if n < 0 {
 		return -1
 	}
-	if err := c.connIDState.handleRetireConnID(seq, c.newConnIDFunc()); err != nil {
+	if err := c.connIDState.handleRetireConnID(c, seq); err != nil {
 		c.abort(now, err)
 	}
 	return n
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index ea720d5..cdbd466 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -13,7 +13,9 @@
 	"errors"
 	"flag"
 	"fmt"
+	"io"
 	"math"
+	"net"
 	"net/netip"
 	"reflect"
 	"strings"
@@ -105,6 +107,7 @@
 type testConn struct {
 	t              *testing.T
 	conn           *Conn
+	listener       *Listener
 	now            time.Time
 	timer          time.Time
 	timerLastFired time.Time
@@ -142,6 +145,8 @@
 	sentFrames    []debugFrame
 	lastPacket    *testPacket
 
+	recvDatagram chan *datagram
+
 	// Transport parameters sent by the conn.
 	sentTransportParameters *transportParameters
 
@@ -173,6 +178,7 @@
 		},
 		cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte),
 		cryptoDataIn:  make(map[tls.QUICEncryptionLevel][]byte),
+		recvDatagram:  make(chan *datagram),
 	}
 	t.Cleanup(tc.cleanup)
 
@@ -196,12 +202,7 @@
 	var initialConnID []byte
 	if side == serverSide {
 		// The initial connection ID for the server is chosen by the client.
-		// When creating a server-side connection, pick a random connection ID here.
-		var err error
-		initialConnID, err = newRandomConnID(0)
-		if err != nil {
-			tc.t.Fatal(err)
-		}
+		initialConnID = testPeerConnID(-1)
 	}
 
 	peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(side.peer())}
@@ -213,14 +214,12 @@
 	tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams))
 	tc.peerTLSConn.Start(context.Background())
 
-	conn, err := newConn(
+	tc.listener = newListener((*testConnUDPConn)(tc), config, (*testConnHooks)(tc))
+	conn, err := tc.listener.newConn(
 		tc.now,
 		side,
 		initialConnID,
-		netip.MustParseAddrPort("127.0.0.1:443"),
-		config,
-		(*testConnListener)(tc),
-		(*testConnHooks)(tc))
+		netip.MustParseAddrPort("127.0.0.1:443"))
 	if err != nil {
 		tc.t.Fatal(err)
 	}
@@ -316,6 +315,7 @@
 		return
 	}
 	tc.conn.exit()
+	tc.listener.Close(context.Background())
 }
 
 func (tc *testConn) logDatagram(text string, d *testDatagram) {
@@ -844,6 +844,10 @@
 	return testLocalConnID(seq), nil
 }
 
+func (tc *testConnHooks) timeNow() time.Time {
+	return tc.now
+}
+
 // testLocalConnID returns the connection ID with a given sequence number
 // used by a Conn under test.
 func testLocalConnID(seq int64) []byte {
@@ -861,14 +865,31 @@
 	return []byte{0xbe, 0xee, 0xff, byte(seq)}
 }
 
-// testConnListener implements connListener.
-type testConnListener testConn
+// testConnUDPConn implements UDPConn.
+type testConnUDPConn testConn
 
-func (tc *testConnListener) sendDatagram(p []byte, addr netip.AddrPort) error {
-	tc.sentDatagrams = append(tc.sentDatagrams, append([]byte(nil), p...))
+func (tc *testConnUDPConn) Close() error {
+	close(tc.recvDatagram)
 	return nil
 }
 
+func (tc *testConnUDPConn) LocalAddr() net.Addr {
+	return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443"))
+}
+
+func (tc *testConnUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) {
+	for d := range tc.recvDatagram {
+		n = copy(b, d.b)
+		return n, 0, 0, d.addr, nil
+	}
+	return 0, 0, 0, netip.AddrPort{}, io.EOF
+}
+
+func (tc *testConnUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
+	tc.sentDatagrams = append(tc.sentDatagrams, append([]byte(nil), b...))
+	return len(b), nil
+}
+
 // canceledContext returns a canceled Context.
 //
 // Functions which take a context preference progress over cancelation.
diff --git a/internal/quic/listener.go b/internal/quic/listener.go
new file mode 100644
index 0000000..9869f6e
--- /dev/null
+++ b/internal/quic/listener.go
@@ -0,0 +1,280 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"context"
+	"errors"
+	"net"
+	"net/netip"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// A Listener listens for QUIC traffic on a network address.
+// It can accept inbound connections or create outbound ones.
+//
+// Multiple goroutines may invoke methods on a Listener simultaneously.
+type Listener struct {
+	config    *Config
+	udpConn   udpConn
+	testHooks connTestHooks
+
+	acceptQueue queue[*Conn] // new inbound connections
+
+	connsMu sync.Mutex
+	conns   map[*Conn]struct{}
+	closing bool          // set when Close is called
+	closec  chan struct{} // closed when the listen loop exits
+
+	// The datagram receive loop keeps a mapping of connection IDs to conns.
+	// When a conn's connection IDs change, we add it to connIDUpdates and set
+	// connIDUpdateNeeded, and the receive loop updates its map.
+	connIDUpdateMu     sync.Mutex
+	connIDUpdateNeeded atomic.Bool
+	connIDUpdates      []connIDUpdate
+}
+
+// A udpConn is a UDP connection.
+// It is implemented by net.UDPConn.
+type udpConn interface {
+	Close() error
+	LocalAddr() net.Addr
+	ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error)
+	WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error)
+}
+
+type connIDUpdate struct {
+	conn    *Conn
+	retired bool
+	cid     []byte
+}
+
+// Listen listens on a local network address.
+// The configuration config must be non-nil.
+func Listen(network, address string, config *Config) (*Listener, error) {
+	if config.TLSConfig == nil {
+		return nil, errors.New("TLSConfig is not set")
+	}
+	a, err := net.ResolveUDPAddr(network, address)
+	if err != nil {
+		return nil, err
+	}
+	udpConn, err := net.ListenUDP(network, a)
+	if err != nil {
+		return nil, err
+	}
+	return newListener(udpConn, config, nil), nil
+}
+
+func newListener(udpConn udpConn, config *Config, hooks connTestHooks) *Listener {
+	l := &Listener{
+		config:      config,
+		udpConn:     udpConn,
+		testHooks:   hooks,
+		conns:       make(map[*Conn]struct{}),
+		acceptQueue: newQueue[*Conn](),
+		closec:      make(chan struct{}),
+	}
+	go l.listen()
+	return l
+}
+
+// LocalAddr returns the local network address.
+func (l *Listener) LocalAddr() netip.AddrPort {
+	a, _ := l.udpConn.LocalAddr().(*net.UDPAddr)
+	return a.AddrPort()
+}
+
+// Close closes the listener.
+// Any blocked operations on the Listener or associated Conns and Stream will be unblocked
+// and return errors.
+//
+// Close aborts every open connection.
+// Data in stream read and write buffers is discarded.
+// It waits for the peers of any open connection to acknowledge the connection has been closed.
+func (l *Listener) Close(ctx context.Context) error {
+	l.acceptQueue.close(errors.New("listener closed"))
+	l.connsMu.Lock()
+	if !l.closing {
+		l.closing = true
+		for c := range l.conns {
+			c.Close()
+		}
+		if len(l.conns) == 0 {
+			l.udpConn.Close()
+		}
+	}
+	l.connsMu.Unlock()
+	select {
+	case <-l.closec:
+	case <-ctx.Done():
+		l.connsMu.Lock()
+		for c := range l.conns {
+			c.exit()
+		}
+		l.connsMu.Unlock()
+		return ctx.Err()
+	}
+	return nil
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (l *Listener) Accept(ctx context.Context) (*Conn, error) {
+	return l.acceptQueue.get(ctx, nil)
+}
+
+// Dial creates and returns a connection to a network address.
+func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, error) {
+	u, err := net.ResolveUDPAddr(network, address)
+	if err != nil {
+		return nil, err
+	}
+	addr := u.AddrPort()
+	addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port())
+	c, err := l.newConn(time.Now(), clientSide, nil, addr)
+	if err != nil {
+		return nil, err
+	}
+	select {
+	case <-c.readyc:
+	case <-ctx.Done():
+		c.Close()
+		return nil, ctx.Err()
+	}
+	return c, nil
+}
+
+func (l *Listener) newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort) (*Conn, error) {
+	l.connsMu.Lock()
+	defer l.connsMu.Unlock()
+	if l.closing {
+		return nil, errors.New("listener closed")
+	}
+	c, err := newConn(now, side, initialConnID, peerAddr, l.config, l, l.testHooks)
+	if err != nil {
+		return nil, err
+	}
+	l.conns[c] = struct{}{}
+	return c, nil
+}
+
+// serverConnEstablished is called by a conn when the handshake completes
+// for an inbound (serverSide) connection.
+func (l *Listener) serverConnEstablished(c *Conn) {
+	l.acceptQueue.put(c)
+}
+
+// connDrained is called by a conn when it leaves the draining state,
+// either when the peer acknowledges connection closure or the drain timeout expires.
+func (l *Listener) connDrained(c *Conn) {
+	l.connsMu.Lock()
+	defer l.connsMu.Unlock()
+	delete(l.conns, c)
+	if l.closing && len(l.conns) == 0 {
+		l.udpConn.Close()
+	}
+}
+
+// connIDsChanged is called by a conn when its connection IDs change.
+func (l *Listener) connIDsChanged(c *Conn, retired bool, cids []connID) {
+	l.connIDUpdateMu.Lock()
+	defer l.connIDUpdateMu.Unlock()
+	for _, cid := range cids {
+		l.connIDUpdates = append(l.connIDUpdates, connIDUpdate{
+			conn:    c,
+			retired: retired,
+			cid:     cid.cid,
+		})
+	}
+	l.connIDUpdateNeeded.Store(true)
+}
+
+// updateConnIDs is called by the datagram receive loop to update its connection ID map.
+func (l *Listener) updateConnIDs(conns map[string]*Conn) {
+	l.connIDUpdateMu.Lock()
+	defer l.connIDUpdateMu.Unlock()
+	for i, u := range l.connIDUpdates {
+		if u.retired {
+			delete(conns, string(u.cid))
+		} else {
+			conns[string(u.cid)] = u.conn
+		}
+		l.connIDUpdates[i] = connIDUpdate{} // drop refs
+	}
+	l.connIDUpdates = l.connIDUpdates[:0]
+	l.connIDUpdateNeeded.Store(false)
+}
+
+func (l *Listener) listen() {
+	defer close(l.closec)
+	conns := map[string]*Conn{}
+	for {
+		m := newDatagram()
+		// TODO: Read and process the ECN (explicit congestion notification) field.
+		// https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4
+		n, _, _, addr, err := l.udpConn.ReadMsgUDPAddrPort(m.b, nil)
+		if err != nil {
+			// The user has probably closed the listener.
+			// We currently don't surface errors from other causes;
+			// we could check to see if the listener has been closed and
+			// record the unexpected error if it has not.
+			return
+		}
+		if n == 0 {
+			continue
+		}
+		if l.connIDUpdateNeeded.Load() {
+			l.updateConnIDs(conns)
+		}
+		m.addr = addr
+		m.b = m.b[:n]
+		l.handleDatagram(m, conns)
+	}
+}
+
+func (l *Listener) handleDatagram(m *datagram, conns map[string]*Conn) {
+	dstConnID, ok := dstConnIDForDatagram(m.b)
+	if !ok {
+		return
+	}
+	c := conns[string(dstConnID)]
+	if c == nil {
+		if getPacketType(m.b) != packetTypeInitial {
+			// This packet isn't trying to create a new connection.
+			// It might be associated with some connection we've lost state for.
+			// TODO: Send a stateless reset when appropriate.
+			// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.3
+			return
+		}
+		var now time.Time
+		if l.testHooks != nil {
+			now = l.testHooks.timeNow()
+		} else {
+			now = time.Now()
+		}
+		var err error
+		c, err = l.newConn(now, serverSide, dstConnID, m.addr)
+		if err != nil {
+			// The accept queue is probably full.
+			// We could send a CONNECTION_CLOSE to the peer to reject the connection.
+			// Currently, we just drop the datagram.
+			// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5
+			return
+		}
+	}
+
+	// TODO: This can block the listener while waiting for the conn to accept the dgram.
+	// Think about buffering between the receive loop and the conn.
+	c.sendMsg(m)
+}
+
+func (l *Listener) sendDatagram(p []byte, addr netip.AddrPort) error {
+	_, err := l.udpConn.WriteToUDPAddrPort(p, addr)
+	return err
+}
diff --git a/internal/quic/listener_test.go b/internal/quic/listener_test.go
new file mode 100644
index 0000000..a6e0b34
--- /dev/null
+++ b/internal/quic/listener_test.go
@@ -0,0 +1,88 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"bytes"
+	"context"
+	"io"
+	"testing"
+)
+
+func TestConnect(t *testing.T) {
+	newLocalConnPair(t, &Config{}, &Config{})
+}
+
+func TestStreamTransfer(t *testing.T) {
+	ctx := context.Background()
+	cli, srv := newLocalConnPair(t, &Config{}, &Config{})
+	data := makeTestData(1 << 20)
+
+	srvdone := make(chan struct{})
+	go func() {
+		defer close(srvdone)
+		s, err := srv.AcceptStream(ctx)
+		if err != nil {
+			t.Errorf("AcceptStream: %v", err)
+			return
+		}
+		b, err := io.ReadAll(s)
+		if err != nil {
+			t.Errorf("io.ReadAll(s): %v", err)
+			return
+		}
+		if !bytes.Equal(b, data) {
+			t.Errorf("read data mismatch (got %v bytes, want %v", len(b), len(data))
+		}
+		if err := s.Close(); err != nil {
+			t.Errorf("s.Close() = %v", err)
+		}
+	}()
+
+	s, err := cli.NewStream(ctx)
+	if err != nil {
+		t.Fatalf("NewStream: %v", err)
+	}
+	n, err := io.Copy(s, bytes.NewBuffer(data))
+	if n != int64(len(data)) || err != nil {
+		t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
+	}
+	if err := s.Close(); err != nil {
+		t.Fatalf("s.Close() = %v", err)
+	}
+}
+
+func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
+	t.Helper()
+	ctx := context.Background()
+	l1 := newLocalListener(t, serverSide, conf1)
+	l2 := newLocalListener(t, clientSide, conf2)
+	c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String())
+	if err != nil {
+		t.Fatal(err)
+	}
+	c1, err := l1.Accept(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return c2, c1
+}
+
+func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener {
+	t.Helper()
+	if conf.TLSConfig == nil {
+		conf.TLSConfig = newTestTLSConfig(side)
+	}
+	l, err := Listen("udp", "127.0.0.1:0", conf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Cleanup(func() {
+		l.Close(context.Background())
+	})
+	return l
+}
diff --git a/internal/quic/tls.go b/internal/quic/tls.go
index 584316f..1d07f17 100644
--- a/internal/quic/tls.go
+++ b/internal/quic/tls.go
@@ -73,6 +73,7 @@
 				// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2-1
 				c.confirmHandshake(now)
 			}
+			close(c.readyc)
 		case tls.QUICTransportParameters:
 			params, err := unmarshalTransportParams(e.Data)
 			if err != nil {