blob: 31e789b1dabee4f798872bad03716c3f8057b815 [file] [log] [blame] [edit]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21
package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net/netip"
"time"
)
// A Conn is a QUIC connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn struct {
side connSide
endpoint *Endpoint
config *Config
testHooks connTestHooks
peerAddr netip.AddrPort
msgc chan any
donec chan struct{} // closed when conn loop exits
w packetWriter
acks [numberSpaceCount]ackState // indexed by number space
lifetime lifetimeState
idle idleState
connIDState connIDState
loss lossState
streams streamsState
// Packet protection keys, CRYPTO streams, and TLS state.
keysInitial fixedKeyPair
keysHandshake fixedKeyPair
keysAppData updatingKeyPair
crypto [numberSpaceCount]cryptoStream
tls *tls.QUICConn
// retryToken is the token provided by the peer in a Retry packet.
retryToken []byte
// handshakeConfirmed is set when the handshake is confirmed.
// For server connections, it tracks sending HANDSHAKE_DONE.
handshakeConfirmed sentVal
peerAckDelayExponent int8 // -1 when unknown
// Tests only: Send a PING in a specific number space.
testSendPingSpace numberSpace
testSendPing sentVal
log *slog.Logger
}
// connTestHooks override conn behavior in tests.
type connTestHooks interface {
// init is called after a conn is created.
init()
// nextMessage is called to request the next event from msgc.
// Used to give tests control of the connection event loop.
nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any)
// handleTLSEvent is called with each TLS event.
handleTLSEvent(tls.QUICEvent)
// newConnID is called to generate a new connection ID.
// Permits tests to generate consistent connection IDs rather than random ones.
newConnID(seq int64) ([]byte, error)
// waitUntil blocks until the until func returns true or the context is done.
// Used to synchronize asynchronous blocking operations in tests.
waitUntil(ctx context.Context, until func() bool) error
// timeNow returns the current time.
timeNow() time.Time
}
// newServerConnIDs is connection IDs associated with a new server connection.
type newServerConnIDs struct {
srcConnID []byte // source from client's current Initial
dstConnID []byte // destination from client's current Initial
originalDstConnID []byte // destination from client's first Initial
retrySrcConnID []byte // source from server's Retry
}
func newConn(now time.Time, side connSide, cids newServerConnIDs, peerAddr netip.AddrPort, config *Config, e *Endpoint) (conn *Conn, _ error) {
c := &Conn{
side: side,
endpoint: e,
config: config,
peerAddr: peerAddr,
msgc: make(chan any, 1),
donec: make(chan struct{}),
peerAckDelayExponent: -1,
}
defer func() {
// If we hit an error in newConn, close donec so tests don't get stuck waiting for it.
// This is only relevant if we've got a bug, but it makes tracking that bug down
// much easier.
if conn == nil {
close(c.donec)
}
}()
// A one-element buffer allows us to wake a Conn's event loop as a
// non-blocking operation.
c.msgc = make(chan any, 1)
if e.testHooks != nil {
e.testHooks.newConn(c)
}
// initialConnID is the connection ID used to generate Initial packet protection keys.
var initialConnID []byte
if c.side == clientSide {
if err := c.connIDState.initClient(c); err != nil {
return nil, err
}
initialConnID, _ = c.connIDState.dstConnID()
} else {
initialConnID = cids.originalDstConnID
if cids.retrySrcConnID != nil {
initialConnID = cids.retrySrcConnID
}
if err := c.connIDState.initServer(c, cids); err != nil {
return nil, err
}
}
// TODO: PMTU discovery.
c.logConnectionStarted(cids.originalDstConnID, peerAddr)
c.keysAppData.init()
c.loss.init(c.side, smallestMaxDatagramSize, now)
c.streamsInit()
c.lifetimeInit()
c.restartIdleTimer(now)
if err := c.startTLS(now, initialConnID, transportParameters{
initialSrcConnID: c.connIDState.srcConnID(),
originalDstConnID: cids.originalDstConnID,
retrySrcConnID: cids.retrySrcConnID,
ackDelayExponent: ackDelayExponent,
maxUDPPayloadSize: maxUDPPayloadSize,
maxAckDelay: maxAckDelay,
disableActiveMigration: true,
initialMaxData: config.maxConnReadBufferSize(),
initialMaxStreamDataBidiLocal: config.maxStreamReadBufferSize(),
initialMaxStreamDataBidiRemote: config.maxStreamReadBufferSize(),
initialMaxStreamDataUni: config.maxStreamReadBufferSize(),
initialMaxStreamsBidi: c.streams.remoteLimit[bidiStream].max,
initialMaxStreamsUni: c.streams.remoteLimit[uniStream].max,
activeConnIDLimit: activeConnIDLimit,
}); err != nil {
return nil, err
}
if c.testHooks != nil {
c.testHooks.init()
}
go c.loop(now)
return c, nil
}
func (c *Conn) String() string {
return fmt.Sprintf("quic.Conn(%v,->%v)", c.side, c.peerAddr)
}
// confirmHandshake is called when the handshake is confirmed.
// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2
func (c *Conn) confirmHandshake(now time.Time) {
// If handshakeConfirmed is unset, the handshake is not confirmed.
// If it is unsent, the handshake is confirmed and we need to send a HANDSHAKE_DONE.
// If it is sent, we have sent a HANDSHAKE_DONE.
// If it is received, the handshake is confirmed and we do not need to send anything.
if c.handshakeConfirmed.isSet() {
return // already confirmed
}
if c.side == serverSide {
// When the server confirms the handshake, it sends a HANDSHAKE_DONE.
c.handshakeConfirmed.setUnsent()
c.endpoint.serverConnEstablished(c)
} else {
// The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed
// to the received state, indicating that the handshake is confirmed and we
// don't need to send anything.
c.handshakeConfirmed.setReceived()
}
c.restartIdleTimer(now)
c.loss.confirmHandshake()
// "An endpoint MUST discard its Handshake keys when the TLS handshake is confirmed"
// https://www.rfc-editor.org/rfc/rfc9001#section-4.9.2-1
c.discardKeys(now, handshakeSpace)
}
// discardKeys discards unused packet protection keys.
// https://www.rfc-editor.org/rfc/rfc9001#section-4.9
func (c *Conn) discardKeys(now time.Time, space numberSpace) {
switch space {
case initialSpace:
c.keysInitial.discard()
case handshakeSpace:
c.keysHandshake.discard()
}
c.loss.discardKeys(now, space)
}
// receiveTransportParameters applies transport parameters sent by the peer.
func (c *Conn) receiveTransportParameters(p transportParameters) error {
isRetry := c.retryToken != nil
if err := c.connIDState.validateTransportParameters(c, isRetry, p); err != nil {
return err
}
c.streams.outflow.setMaxData(p.initialMaxData)
c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi)
c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni)
c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal
c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote
c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni
c.receivePeerMaxIdleTimeout(p.maxIdleTimeout)
c.peerAckDelayExponent = p.ackDelayExponent
c.loss.setMaxAckDelay(p.maxAckDelay)
if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil {
return err
}
if p.preferredAddrConnID != nil {
var (
seq int64 = 1 // sequence number of this conn id is 1
retirePriorTo int64 = 0 // retire nothing
resetToken [16]byte
)
copy(resetToken[:], p.preferredAddrResetToken)
if err := c.connIDState.handleNewConnID(c, seq, retirePriorTo, p.preferredAddrConnID, resetToken); err != nil {
return err
}
}
// TODO: stateless_reset_token
// TODO: max_udp_payload_size
// TODO: disable_active_migration
// TODO: preferred_address
return nil
}
type (
timerEvent struct{}
wakeEvent struct{}
)
var errIdleTimeout = errors.New("idle timeout")
// loop is the connection main loop.
//
// Except where otherwise noted, all connection state is owned by the loop goroutine.
//
// The loop processes messages from c.msgc and timer events.
// Other goroutines may examine or modify conn state by sending the loop funcs to execute.
func (c *Conn) loop(now time.Time) {
defer close(c.donec)
defer c.tls.Close()
defer c.endpoint.connDrained(c)
defer c.logConnectionClosed()
// The connection timer sends a message to the connection loop on expiry.
// We need to give it an expiry when creating it, so set the initial timeout to
// an arbitrary large value. The timer will be reset before this expires (and it
// isn't a problem if it does anyway). Skip creating the timer in tests which
// take control of the connection message loop.
var timer *time.Timer
var lastTimeout time.Time
hooks := c.testHooks
if hooks == nil {
timer = time.AfterFunc(1*time.Hour, func() {
c.sendMsg(timerEvent{})
})
defer timer.Stop()
}
for c.lifetime.state != connStateDone {
sendTimeout := c.maybeSend(now) // try sending
// Note that we only need to consider the ack timer for the App Data space,
// since the Initial and Handshake spaces always ack immediately.
nextTimeout := sendTimeout
nextTimeout = firstTime(nextTimeout, c.idle.nextTimeout)
if c.isAlive() {
nextTimeout = firstTime(nextTimeout, c.loss.timer)
nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck)
} else {
nextTimeout = firstTime(nextTimeout, c.lifetime.drainEndTime)
}
var m any
if hooks != nil {
// Tests only: Wait for the test to tell us to continue.
now, m = hooks.nextMessage(c.msgc, nextTimeout)
} else if !nextTimeout.IsZero() && nextTimeout.Before(now) {
// A connection timer has expired.
now = time.Now()
m = timerEvent{}
} else {
// Reschedule the connection timer if necessary
// and wait for the next event.
if !nextTimeout.Equal(lastTimeout) && !nextTimeout.IsZero() {
// Resetting a timer created with time.AfterFunc guarantees
// that the timer will run again. We might generate a spurious
// timer event under some circumstances, but that's okay.
timer.Reset(nextTimeout.Sub(now))
lastTimeout = nextTimeout
}
m = <-c.msgc
now = time.Now()
}
switch m := m.(type) {
case *datagram:
c.handleDatagram(now, m)
m.recycle()
case timerEvent:
// A connection timer has expired.
if c.idleAdvance(now) {
// The connection idle timer has expired.
c.abortImmediately(now, errIdleTimeout)
return
}
c.loss.advance(now, c.handleAckOrLoss)
if c.lifetimeAdvance(now) {
// The connection has completed the draining period,
// and may be shut down.
return
}
case wakeEvent:
// We're being woken up to try sending some frames.
case func(time.Time, *Conn):
// Send a func to msgc to run it on the main Conn goroutine
m(now, c)
default:
panic(fmt.Sprintf("quic: unrecognized conn message %T", m))
}
}
}
// sendMsg sends a message to the conn's loop.
// It does not wait for the message to be processed.
// The conn may close before processing the message, in which case it is lost.
func (c *Conn) sendMsg(m any) {
select {
case c.msgc <- m:
case <-c.donec:
}
}
// wake wakes up the conn's loop.
func (c *Conn) wake() {
select {
case c.msgc <- wakeEvent{}:
default:
}
}
// runOnLoop executes a function within the conn's loop goroutine.
func (c *Conn) runOnLoop(f func(now time.Time, c *Conn)) error {
donec := make(chan struct{})
c.sendMsg(func(now time.Time, c *Conn) {
defer close(donec)
f(now, c)
})
select {
case <-donec:
case <-c.donec:
return errors.New("quic: connection closed")
}
return nil
}
func (c *Conn) waitOnDone(ctx context.Context, ch <-chan struct{}) error {
if c.testHooks != nil {
return c.testHooks.waitUntil(ctx, func() bool {
select {
case <-ch:
return true
default:
}
return false
})
}
// Check the channel before the context.
// We always prefer to return results when available,
// even when provided with an already-canceled context.
select {
case <-ch:
return nil
default:
}
select {
case <-ch:
case <-ctx.Done():
return ctx.Err()
}
return nil
}
// firstTime returns the earliest non-zero time, or zero if both times are zero.
func firstTime(a, b time.Time) time.Time {
switch {
case a.IsZero():
return b
case b.IsZero():
return a
case a.Before(b):
return a
default:
return b
}
}