quic: skip packet numbers for optimistic ack defense
An "optimistic ACK attack" involves an attacker sending ACKs
for packets it hasn't received, causing the victim's
congestion controller to improperly send at a higher rate.
The standard defense against this attack is to skip the occasional
packet number, and to close the connection with an error if the
peer ACKs an unsent packet.
Implement this defense, increasing the gap between skipped
packet numbers as a connection's lifetime grows and correspondingly
the amount of work required on the part of the attacker.
Change-Id: I01f44f13367821b86af6535ffb69d380e2b4d7b7
Reviewed-on: https://go-review.googlesource.com/c/net/+/664298
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
diff --git a/quic/conn.go b/quic/conn.go
index ba2fd7e..b9ec0e4 100644
--- a/quic/conn.go
+++ b/quic/conn.go
@@ -6,10 +6,12 @@
import (
"context"
+ cryptorand "crypto/rand"
"crypto/tls"
"errors"
"fmt"
"log/slog"
+ "math/rand/v2"
"net/netip"
"time"
)
@@ -24,6 +26,7 @@
testHooks connTestHooks
peerAddr netip.AddrPort
localAddr netip.AddrPort
+ prng *rand.Rand
msgc chan any
donec chan struct{} // closed when conn loop exits
@@ -36,6 +39,7 @@
loss lossState
streams streamsState
path pathState
+ skip skipState
// Packet protection keys, CRYPTO streams, and TLS state.
keysInitial fixedKeyPair
@@ -136,6 +140,14 @@
}
}
+ // A per-conn ChaCha8 PRNG is probably more than we need,
+ // but at least it's fairly small.
+ var seed [32]byte
+ if _, err := cryptorand.Read(seed[:]); err != nil {
+ panic(err)
+ }
+ c.prng = rand.New(rand.NewChaCha8(seed))
+
// TODO: PMTU discovery.
c.logConnectionStarted(cids.originalDstConnID, peerAddr)
c.keysAppData.init()
@@ -143,6 +155,7 @@
c.streamsInit()
c.lifetimeInit()
c.restartIdleTimer(now)
+ c.skip.init(c)
if err := c.startTLS(now, initialConnID, peerHostname, transportParameters{
initialSrcConnID: c.connIDState.srcConnID(),
diff --git a/quic/conn_recv.go b/quic/conn_recv.go
index 753fbb9..a24fc36 100644
--- a/quic/conn_recv.go
+++ b/quic/conn_recv.go
@@ -421,15 +421,10 @@
func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) int {
c.loss.receiveAckStart()
largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) {
- if end > c.loss.nextNumber(space) {
- // Acknowledgement of a packet we never sent.
- c.abort(now, localTransportError{
- code: errProtocolViolation,
- reason: "acknowledgement for unsent packet",
- })
+ if err := c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss); err != nil {
+ c.abort(now, err)
return
}
- c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss)
})
// Prior to receiving the peer's transport parameters, we cannot
// interpret the ACK Delay field because we don't know the ack_delay_exponent
diff --git a/quic/conn_send.go b/quic/conn_send.go
index 2b6d6f0..d6fb149 100644
--- a/quic/conn_send.go
+++ b/quic/conn_send.go
@@ -142,6 +142,10 @@
}
if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil {
c.packetSent(now, appDataSpace, sent)
+ if c.skip.shouldSkip(pnum + 1) {
+ c.loss.skipNumber(now, appDataSpace)
+ c.skip.updateNumberSkip(c)
+ }
}
}
diff --git a/quic/conn_send_test.go b/quic/conn_send_test.go
index d16b093..c5cf936 100644
--- a/quic/conn_send_test.go
+++ b/quic/conn_send_test.go
@@ -66,7 +66,11 @@
// current packet and the max acked one is sufficiently large.
for want := maxAcked + 1; want < maxAcked+0x100; want++ {
p := recvPing()
- if p.num != want {
+ if p.num == want+1 {
+ // The conn skipped a packet number
+ // (defense against optimistic ACK attacks).
+ want++
+ } else if p.num != want {
t.Fatalf("received packet number %v, want %v", p.num, want)
}
gotPnumLen := int(p.header&0x03) + 1
diff --git a/quic/conn_streams_test.go b/quic/conn_streams_test.go
index c292f69..af3c1de 100644
--- a/quic/conn_streams_test.go
+++ b/quic/conn_streams_test.go
@@ -242,9 +242,7 @@
if p == nil {
break
}
- tc.writeFrames(packetType1RTT, debugFrameAck{
- ranges: []i64range[packetNumber]{{0, p.num}},
- })
+ tc.writeAckForLatest()
for _, f := range p.frames {
sf, ok := f.(debugFrameStream)
if !ok {
diff --git a/quic/loss.go b/quic/loss.go
index b89aabd..ffbf69d 100644
--- a/quic/loss.go
+++ b/quic/loss.go
@@ -178,6 +178,15 @@
return c.spaces[space].nextNum
}
+// skipPacketNumber skips a packet number as a defense against optimistic ACK attacks.
+func (c *lossState) skipNumber(now time.Time, space numberSpace) {
+ sent := newSentPacket()
+ sent.num = c.spaces[space].nextNum
+ sent.time = now
+ sent.state = sentPacketUnsent
+ c.spaces[space].add(sent)
+}
+
// packetSent records a sent packet.
func (c *lossState) packetSent(now time.Time, log *slog.Logger, space numberSpace, sent *sentPacket) {
sent.time = now
@@ -230,17 +239,20 @@
// receiveAckRange processes a range within an ACK frame.
// The ackf function is called for each newly-acknowledged packet.
-func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex int, start, end packetNumber, ackf func(numberSpace, *sentPacket, packetFate)) {
+func (c *lossState) receiveAckRange(now time.Time, space numberSpace, rangeIndex int, start, end packetNumber, ackf func(numberSpace, *sentPacket, packetFate)) error {
// Limit our range to the intersection of the ACK range and
// the in-flight packets we have state for.
if s := c.spaces[space].start(); start < s {
start = s
}
if e := c.spaces[space].end(); end > e {
- end = e
+ return localTransportError{
+ code: errProtocolViolation,
+ reason: "acknowledgement for unsent packet",
+ }
}
if start >= end {
- return
+ return nil
}
if rangeIndex == 0 {
// If the latest packet in the ACK frame is newly-acked,
@@ -252,6 +264,12 @@
}
for pnum := start; pnum < end; pnum++ {
sent := c.spaces[space].num(pnum)
+ if sent.state == sentPacketUnsent {
+ return localTransportError{
+ code: errProtocolViolation,
+ reason: "acknowledgement for unsent packet",
+ }
+ }
if sent.state != sentPacketSent {
continue
}
@@ -266,6 +284,7 @@
c.ackFrameContainsAckEliciting = true
}
}
+ return nil
}
// receiveAckEnd finishes processing an ack frame.
diff --git a/quic/sent_packet.go b/quic/sent_packet.go
index 457c50e..f67606b 100644
--- a/quic/sent_packet.go
+++ b/quic/sent_packet.go
@@ -38,9 +38,10 @@
type sentPacketState uint8
const (
- sentPacketSent = sentPacketState(iota) // sent but neither acked nor lost
- sentPacketAcked // acked
- sentPacketLost // declared lost
+ sentPacketSent = sentPacketState(iota) // sent but neither acked nor lost
+ sentPacketAcked // acked
+ sentPacketLost // declared lost
+ sentPacketUnsent // never sent
)
var sentPool = sync.Pool{
diff --git a/quic/skip.go b/quic/skip.go
new file mode 100644
index 0000000..f5ba764
--- /dev/null
+++ b/quic/skip.go
@@ -0,0 +1,62 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+// skipState is state for optimistic ACK defenses.
+//
+// An endpoint performs an optimistic ACK attack by sending acknowledgements for packets
+// which it has not received, potentially convincing the sender's congestion controller to
+// send at rates beyond what the network supports.
+//
+// We defend against this by periodically skipping packet numbers.
+// Receiving an ACK for an unsent packet number is a PROTOCOL_VIOLATION error.
+//
+// We only skip packet numbers in the Application Data number space.
+// The total data sent in the Initial/Handshake spaces should generally fit into
+// the initial congestion window.
+//
+// https://www.rfc-editor.org/rfc/rfc9000.html#section-21.4
+type skipState struct {
+ // skip is the next packet number (in the Application Data space) we should skip.
+ skip packetNumber
+
+ // maxSkip is the maximum number of packets to send before skipping another number.
+ // Increases over time.
+ maxSkip int64
+}
+
+func (ss *skipState) init(c *Conn) {
+ ss.maxSkip = 256 // skip our first packet number within this range
+ ss.updateNumberSkip(c)
+}
+
+// shouldSkipAfter returns whether we should skip the given packet number.
+func (ss *skipState) shouldSkip(num packetNumber) bool {
+ return ss.skip == num
+}
+
+// updateNumberSkip schedules a packet to be skipped after skipping lastSkipped.
+func (ss *skipState) updateNumberSkip(c *Conn) {
+ // Send at least this many packets before skipping.
+ // Limits the impact of skipping a little,
+ // plus allows most tests to ignore skipping.
+ const minSkip = 64
+
+ skip := minSkip + c.prng.Int64N(ss.maxSkip-minSkip)
+ ss.skip += packetNumber(skip)
+
+ // Double the size of the skip each time until we reach 128k.
+ // The idea here is that an attacker needs to correctly ack ~N packets in order
+ // to send an optimistic ack for another ~N packets.
+ // Skipping packet numbers comes with a small cost (it causes the receiver to
+ // send an immediate ACK rather than the usual delayed ACK), so we increase the
+ // time between skips as a connection's lifetime grows.
+ //
+ // The 128k cap is arbitrary, chosen so that we skip a packet number
+ // about once a second when sending full-size datagrams at 1Gbps.
+ if ss.maxSkip < 128*1024 {
+ ss.maxSkip *= 2
+ }
+}
diff --git a/quic/skip_test.go b/quic/skip_test.go
new file mode 100644
index 0000000..1fcb735
--- /dev/null
+++ b/quic/skip_test.go
@@ -0,0 +1,81 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package quic
+
+import "testing"
+
+func TestSkipPackets(t *testing.T) {
+ tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters)
+ connWritesPacket := func() {
+ s.WriteByte(0)
+ s.Flush()
+ tc.wantFrameType("conn sends STREAM data",
+ packetType1RTT, debugFrameStream{})
+ tc.writeAckForLatest()
+ tc.wantIdle("conn is idle")
+ }
+ connWritesPacket()
+
+expectSkip:
+ for maxUntilSkip := 256; maxUntilSkip <= 1024; maxUntilSkip *= 2 {
+ for range maxUntilSkip + 1 {
+ nextNum := tc.lastPacket.num + 1
+
+ connWritesPacket()
+
+ if tc.lastPacket.num == nextNum+1 {
+ // A packet number was skipped, as expected.
+ continue expectSkip
+ }
+ if tc.lastPacket.num != nextNum {
+ t.Fatalf("got packet number %v, want %v or %v+1", tc.lastPacket.num, nextNum, nextNum)
+ }
+
+ }
+ t.Fatalf("no numbers skipped after %v packets", maxUntilSkip)
+ }
+}
+
+func TestSkipAckForSkippedPacket(t *testing.T) {
+ tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters)
+
+ // Cause the connection to send packets until it skips a packet number.
+ for {
+ // Cause the connection to send a packet.
+ last := tc.lastPacket
+ s.WriteByte(0)
+ s.Flush()
+ tc.wantFrameType("conn sends STREAM data",
+ packetType1RTT, debugFrameStream{})
+
+ if tc.lastPacket.num > 256 {
+ t.Fatalf("no numbers skipped after 256 packets")
+ }
+
+ // Acknowledge everything up to the packet before the one we just received.
+ // We don't acknowledge the most-recently-received packet, because doing
+ // so will cause the connection to drop state for the skipped packet number.
+ // (We only retain state up to the oldest in-flight packet.)
+ //
+ // If the conn has skipped a packet number, then this ack will improperly
+ // acknowledge the unsent packet.
+ t.Log(tc.lastPacket.num)
+ tc.writeFrames(tc.lastPacket.ptype, debugFrameAck{
+ ranges: []i64range[packetNumber]{{0, tc.lastPacket.num}},
+ })
+
+ if last != nil && tc.lastPacket.num == last.num+2 {
+ // The connection has skipped a packet number.
+ break
+ }
+ }
+
+ // We wrote an ACK for a skipped packet number.
+ // The connection should close.
+ tc.wantFrame("ACK for skipped packet causes CONNECTION_CLOSE",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errProtocolViolation,
+ })
+}