quic: move ack_delay_exponent handling out of frame parsing

The ACK Delay field of ACK frames contains a duration.
The field contains an integer which is multiplied by two to the
power of the sender's ack_delay_exponent transport parameter to
arrive at the delay in microseconds.

Change the frame parsing and encoding layer to operate on the
unscaled field value, rather than passing the ack_delay_exponent
and a duration. This better expresses the fact that we may
parse an ACK frame without knowing the ack_delay_exponent, if
the ACK is received before transport parameters.

For golang/go#58547

Change-Id: Ic26256761961ce89aea0618b849e5661b0502b12
Reviewed-on: https://go-review.googlesource.com/c/net/+/504855
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/quic/ack_delay.go b/internal/quic/ack_delay.go
new file mode 100644
index 0000000..66bdf3c
--- /dev/null
+++ b/internal/quic/ack_delay.go
@@ -0,0 +1,28 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"math"
+	"time"
+)
+
+// An unscaledAckDelay is an ACK Delay field value from an ACK packet,
+// without the ack_delay_exponent scaling applied.
+type unscaledAckDelay int64
+
+func unscaledAckDelayFromDuration(d time.Duration, ackDelayExponent uint8) unscaledAckDelay {
+	return unscaledAckDelay(d.Microseconds() >> ackDelayExponent)
+}
+
+func (d unscaledAckDelay) Duration(ackDelayExponent uint8) time.Duration {
+	if int64(d) > (math.MaxInt64>>ackDelayExponent)/int64(time.Microsecond) {
+		// If scaling the delay would overflow, ignore the delay.
+		return 0
+	}
+	return time.Duration(d<<ackDelayExponent) * time.Microsecond
+}
diff --git a/internal/quic/ack_delay_test.go b/internal/quic/ack_delay_test.go
new file mode 100644
index 0000000..038964a
--- /dev/null
+++ b/internal/quic/ack_delay_test.go
@@ -0,0 +1,81 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"math"
+	"testing"
+	"time"
+)
+
+func TestAckDelayFromDuration(t *testing.T) {
+	for _, test := range []struct {
+		d                time.Duration
+		ackDelayExponent uint8
+		want             unscaledAckDelay
+	}{{
+		d:                8 * time.Microsecond,
+		ackDelayExponent: 3,
+		want:             1,
+	}, {
+		d:                1 * time.Nanosecond,
+		ackDelayExponent: 3,
+		want:             0, // rounds to zero
+	}, {
+		d:                3 * (1 << 20) * time.Microsecond,
+		ackDelayExponent: 20,
+		want:             3,
+	}} {
+		got := unscaledAckDelayFromDuration(test.d, test.ackDelayExponent)
+		if got != test.want {
+			t.Errorf("unscaledAckDelayFromDuration(%v, %v) = %v, want %v",
+				test.d, test.ackDelayExponent, got, test.want)
+		}
+	}
+}
+
+func TestAckDelayToDuration(t *testing.T) {
+	for _, test := range []struct {
+		d                unscaledAckDelay
+		ackDelayExponent uint8
+		want             time.Duration
+	}{{
+		d:                1,
+		ackDelayExponent: 3,
+		want:             8 * time.Microsecond,
+	}, {
+		d:                0,
+		ackDelayExponent: 3,
+		want:             0,
+	}, {
+		d:                3,
+		ackDelayExponent: 20,
+		want:             3 * (1 << 20) * time.Microsecond,
+	}, {
+		d:                math.MaxInt64 / 1000,
+		ackDelayExponent: 0,
+		want:             (math.MaxInt64 / 1000) * time.Microsecond,
+	}, {
+		d:                (math.MaxInt64 / 1000) + 1,
+		ackDelayExponent: 0,
+		want:             0, // return 0 on overflow
+	}, {
+		d:                math.MaxInt64 / 1000 / 8,
+		ackDelayExponent: 3,
+		want:             (math.MaxInt64 / 1000 / 8) * 8 * time.Microsecond,
+	}, {
+		d:                (math.MaxInt64 / 1000 / 8) + 1,
+		ackDelayExponent: 3,
+		want:             0, // return 0 on overflow
+	}} {
+		got := test.d.Duration(test.ackDelayExponent)
+		if got != test.want {
+			t.Errorf("unscaledAckDelay(%v).Duration(%v) = %v, want %v",
+				test.d, test.ackDelayExponent, int64(got), int64(test.want))
+		}
+	}
+}
diff --git a/internal/quic/frame_debug.go b/internal/quic/frame_debug.go
index 93ddf55..945bb9d 100644
--- a/internal/quic/frame_debug.go
+++ b/internal/quic/frame_debug.go
@@ -8,7 +8,6 @@
 
 import (
 	"fmt"
-	"time"
 )
 
 // A debugFrame is a representation of the contents of a QUIC frame,
@@ -115,13 +114,13 @@
 
 // debugFrameAck is an ACK frame.
 type debugFrameAck struct {
-	ackDelay time.Duration
+	ackDelay unscaledAckDelay
 	ranges   []i64range[packetNumber]
 }
 
 func parseDebugFrameAck(b []byte) (f debugFrameAck, n int) {
 	f.ranges = nil
-	_, f.ackDelay, n = consumeAckFrame(b, ackDelayExponent, func(start, end packetNumber) {
+	_, f.ackDelay, n = consumeAckFrame(b, func(start, end packetNumber) {
 		f.ranges = append(f.ranges, i64range[packetNumber]{
 			start: start,
 			end:   end,
@@ -144,7 +143,7 @@
 }
 
 func (f debugFrameAck) write(w *packetWriter) bool {
-	return w.appendAckFrame(rangeset[packetNumber](f.ranges), ackDelayExponent, f.ackDelay)
+	return w.appendAckFrame(rangeset[packetNumber](f.ranges), f.ackDelay)
 }
 
 // debugFrameResetStream is a RESET_STREAM frame.
diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go
index efd519b..499ec4d 100644
--- a/internal/quic/packet_codec_test.go
+++ b/internal/quic/packet_codec_test.go
@@ -11,7 +11,6 @@
 	"crypto/tls"
 	"reflect"
 	"testing"
-	"time"
 )
 
 func TestParseLongHeaderPacket(t *testing.T) {
@@ -219,9 +218,9 @@
 			0x01, // TYPE(i) = 0x01
 		},
 	}, {
-		s: "ACK Delay=80µs [0,16) [17,32) [48,64)",
+		s: "ACK Delay=10 [0,16) [17,32) [48,64)",
 		f: debugFrameAck{
-			ackDelay: (10 << ackDelayExponent) * time.Microsecond,
+			ackDelay: 10,
 			ranges: []i64range[packetNumber]{
 				{0x00, 0x10},
 				{0x11, 0x20},
@@ -594,7 +593,7 @@
 	}, {
 		desc: "ACK frame with ECN counts",
 		want: debugFrameAck{
-			ackDelay: (10 << ackDelayExponent) * time.Microsecond,
+			ackDelay: 10,
 			ranges: []i64range[packetNumber]{
 				{0, 1},
 			},
diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go
index cc025b6..e910e0e 100644
--- a/internal/quic/packet_parser.go
+++ b/internal/quic/packet_parser.go
@@ -6,10 +6,6 @@
 
 package quic
 
-import (
-	"time"
-)
-
 // parseLongHeaderPacket parses a QUIC long header packet.
 //
 // It does not parse Version Negotiation packets.
@@ -166,7 +162,7 @@
 // which includes both general parse failures and specific violations of frame
 // constraints.
 
-func consumeAckFrame(frame []byte, ackDelayExponent uint8, f func(start, end packetNumber)) (largest packetNumber, ackDelay time.Duration, n int) {
+func consumeAckFrame(frame []byte, f func(start, end packetNumber)) (largest packetNumber, ackDelay unscaledAckDelay, n int) {
 	b := frame[1:] // type
 
 	largestAck, n := consumeVarint(b)
@@ -175,12 +171,12 @@
 	}
 	b = b[n:]
 
-	ackDelayScaled, n := consumeVarint(b)
+	v, n := consumeVarintInt64(b)
 	if n < 0 {
 		return 0, 0, -1
 	}
 	b = b[n:]
-	ackDelay = time.Duration(ackDelayScaled*(1<<ackDelayExponent)) * time.Microsecond
+	ackDelay = unscaledAckDelay(v)
 
 	ackRangeCount, n := consumeVarint(b)
 	if n < 0 {
diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go
index bfe9af7..97987e0 100644
--- a/internal/quic/packet_writer.go
+++ b/internal/quic/packet_writer.go
@@ -8,7 +8,6 @@
 
 import (
 	"encoding/binary"
-	"time"
 )
 
 // A packetWriter constructs QUIC datagrams.
@@ -257,21 +256,20 @@
 // to the peer potentially failing to receive an acknowledgement
 // for an older packet during a period of high packet loss or
 // reordering. This may result in unnecessary retransmissions.
-func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], ackDelayExponent uint8, delay time.Duration) (added bool) {
+func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscaledAckDelay) (added bool) {
 	if len(seen) == 0 {
 		return false
 	}
 	var (
 		largest    = uint64(seen.max())
-		mdelay     = uint64(delay.Microseconds() / (1 << ackDelayExponent))
 		firstRange = uint64(seen[len(seen)-1].size() - 1)
 	)
-	if w.avail() < 1+sizeVarint(largest)+sizeVarint(mdelay)+1+sizeVarint(firstRange) {
+	if w.avail() < 1+sizeVarint(largest)+sizeVarint(uint64(delay))+1+sizeVarint(firstRange) {
 		return false
 	}
 	w.b = append(w.b, frameTypeAck)
 	w.b = appendVarint(w.b, largest)
-	w.b = appendVarint(w.b, mdelay)
+	w.b = appendVarint(w.b, uint64(delay))
 	// The range count is technically a varint, but we'll reserve a single byte for it
 	// and never add more than 62 ranges (the maximum varint that fits in a byte).
 	rangeCountOff := len(w.b)