quic: add RTT estimator

Implement the round-trip time estimation algorithm from
RFC 9002, Section 5.

For golang/go#58547

Change-Id: I494e692e710f77270c9ad28354366f384feb4ac7
Reviewed-on: https://go-review.googlesource.com/c/net/+/499286
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/quic/math.go b/internal/quic/math.go
new file mode 100644
index 0000000..f9dd754
--- /dev/null
+++ b/internal/quic/math.go
@@ -0,0 +1,14 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+func abs[T ~int | ~int64](a T) T {
+	if a < 0 {
+		return -a
+	}
+	return a
+}
diff --git a/internal/quic/rtt.go b/internal/quic/rtt.go
new file mode 100644
index 0000000..5bd8861
--- /dev/null
+++ b/internal/quic/rtt.go
@@ -0,0 +1,73 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"time"
+)
+
+type rttState struct {
+	minRTT          time.Duration
+	latestRTT       time.Duration
+	smoothedRTT     time.Duration
+	rttvar          time.Duration // RTT variation
+	firstSampleTime time.Time     // time of first RTT sample
+}
+
+func (r *rttState) init() {
+	r.minRTT = -1 // -1 indicates the first sample has not been taken yet
+
+	// "[...] the initial RTT SHOULD be set to 333 milliseconds."
+	// https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2-1
+	const initialRTT = 333 * time.Millisecond
+
+	// https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-12
+	r.smoothedRTT = initialRTT
+	r.rttvar = initialRTT / 2
+}
+
+func (r *rttState) establishPersistentCongestion() {
+	// "Endpoints SHOULD set the min_rtt to the newest RTT sample
+	// after persistent congestion is established."
+	// https://www.rfc-editor.org/rfc/rfc9002#section-5.2-5
+	r.minRTT = r.latestRTT
+}
+
+// updateRTTSample is called when we generate a new RTT sample.
+// https://www.rfc-editor.org/rfc/rfc9002.html#section-5
+func (r *rttState) updateSample(now time.Time, handshakeConfirmed bool, spaceID numberSpace, latestRTT, ackDelay, maxAckDelay time.Duration) {
+	r.latestRTT = latestRTT
+
+	if r.minRTT < 0 {
+		// First RTT sample.
+		// "min_rtt MUST be set to the latest_rtt on the first RTT sample."
+		// https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2
+		r.minRTT = latestRTT
+		// https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-14
+		r.smoothedRTT = latestRTT
+		r.rttvar = latestRTT / 2
+		r.firstSampleTime = now
+		return
+	}
+
+	// "min_rtt MUST be set to the lesser of min_rtt and latest_rtt [...]
+	// on all other samples."
+	// https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2
+	r.minRTT = min(r.minRTT, latestRTT)
+
+	// https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-16
+	if handshakeConfirmed {
+		ackDelay = min(ackDelay, maxAckDelay)
+	}
+	adjustedRTT := latestRTT - ackDelay
+	if adjustedRTT < r.minRTT {
+		adjustedRTT = latestRTT
+	}
+	r.smoothedRTT = ((7 * r.smoothedRTT) + adjustedRTT) / 8
+	rttvarSample := abs(r.smoothedRTT - adjustedRTT)
+	r.rttvar = (3*r.rttvar + rttvarSample) / 4
+}
diff --git a/internal/quic/rtt_test.go b/internal/quic/rtt_test.go
new file mode 100644
index 0000000..63789c2
--- /dev/null
+++ b/internal/quic/rtt_test.go
@@ -0,0 +1,168 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"testing"
+	"time"
+)
+
+func TestRTTMinRTT(t *testing.T) {
+	var (
+		handshakeConfirmed = false
+		ackDelay           = 0 * time.Millisecond
+		maxAckDelay        = 25 * time.Millisecond
+		now                = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
+	)
+	rtt := &rttState{}
+	rtt.init()
+
+	// "min_rtt MUST be set to the latest_rtt on the first RTT sample."
+	// https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2
+	rtt.updateSample(now, handshakeConfirmed, initialSpace, 10*time.Millisecond, ackDelay, maxAckDelay)
+	if got, want := rtt.latestRTT, 10*time.Millisecond; got != want {
+		t.Errorf("on first sample: latest_rtt = %v, want %v", got, want)
+	}
+	if got, want := rtt.minRTT, 10*time.Millisecond; got != want {
+		t.Errorf("on first sample: min_rtt = %v, want %v", got, want)
+	}
+
+	// "min_rtt MUST be set to the lesser of min_rtt and latest_rtt [...]
+	// on all other samples."
+	rtt.updateSample(now, handshakeConfirmed, initialSpace, 20*time.Millisecond, ackDelay, maxAckDelay)
+	if got, want := rtt.latestRTT, 20*time.Millisecond; got != want {
+		t.Errorf("on increasing sample: latest_rtt = %v, want %v", got, want)
+	}
+	if got, want := rtt.minRTT, 10*time.Millisecond; got != want {
+		t.Errorf("on increasing sample: min_rtt = %v, want %v (no change)", got, want)
+	}
+
+	rtt.updateSample(now, handshakeConfirmed, initialSpace, 5*time.Millisecond, ackDelay, maxAckDelay)
+	if got, want := rtt.latestRTT, 5*time.Millisecond; got != want {
+		t.Errorf("on new minimum: latest_rtt = %v, want %v", got, want)
+	}
+	if got, want := rtt.minRTT, 5*time.Millisecond; got != want {
+		t.Errorf("on new minimum: min_rtt = %v, want %v", got, want)
+	}
+
+	// "Endpoints SHOULD set the min_rtt to the newest RTT sample
+	// after persistent congestion is established."
+	// https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-5
+	rtt.updateSample(now, handshakeConfirmed, initialSpace, 15*time.Millisecond, ackDelay, maxAckDelay)
+	if got, want := rtt.latestRTT, 15*time.Millisecond; got != want {
+		t.Errorf("on increasing sample: latest_rtt = %v, want %v", got, want)
+	}
+	if got, want := rtt.minRTT, 5*time.Millisecond; got != want {
+		t.Errorf("on increasing sample: min_rtt = %v, want %v (no change)", got, want)
+	}
+	rtt.establishPersistentCongestion()
+	if got, want := rtt.minRTT, 15*time.Millisecond; got != want {
+		t.Errorf("after persistent congestion: min_rtt = %v, want %v", got, want)
+	}
+}
+
+func TestRTTInitialRTT(t *testing.T) {
+	var (
+		handshakeConfirmed = false
+		ackDelay           = 0 * time.Millisecond
+		maxAckDelay        = 25 * time.Millisecond
+		now                = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
+	)
+	rtt := &rttState{}
+	rtt.init()
+
+	// "When no previous RTT is available,
+	// the initial RTT SHOULD be set to 333 milliseconds."
+	// https://www.rfc-editor.org/rfc/rfc9002#section-6.2.2-1
+	if got, want := rtt.smoothedRTT, 333*time.Millisecond; got != want {
+		t.Errorf("initial smoothed_rtt = %v, want %v", got, want)
+	}
+	if got, want := rtt.rttvar, 333*time.Millisecond/2; got != want {
+		t.Errorf("initial rttvar = %v, want %v", got, want)
+	}
+
+	rtt.updateSample(now, handshakeConfirmed, initialSpace, 10*time.Millisecond, ackDelay, maxAckDelay)
+	smoothedRTT := 10 * time.Millisecond
+	if got, want := rtt.smoothedRTT, smoothedRTT; got != want {
+		t.Errorf("after first rtt sample of 10ms, smoothed_rtt = %v, want %v", got, want)
+	}
+	rttvar := 5 * time.Millisecond
+	if got, want := rtt.rttvar, rttvar; got != want {
+		t.Errorf("after first rtt sample of 10ms, rttvar = %v, want %v", got, want)
+	}
+
+	// "[...] MAY ignore the acknowledgment delay for Initial packets [...]"
+	// https://www.rfc-editor.org/rfc/rfc9002#section-5.3-7.1
+	ackDelay = 1 * time.Millisecond
+	rtt.updateSample(now, handshakeConfirmed, initialSpace, 10*time.Millisecond, ackDelay, maxAckDelay)
+	adjustedRTT := 10 * time.Millisecond
+	smoothedRTT = (7*smoothedRTT + adjustedRTT) / 8
+	if got, want := rtt.smoothedRTT, smoothedRTT; got != want {
+		t.Errorf("smoothed_rtt = %v, want %v", got, want)
+	}
+	rttvarSample := abs(smoothedRTT - adjustedRTT)
+	rttvar = (3*rttvar + rttvarSample) / 4
+	if got, want := rtt.rttvar, rttvar; got != want {
+		t.Errorf("rttvar = %v, want %v", got, want)
+	}
+
+	// "[...] SHOULD ignore the peer's max_ack_delay until the handshake is confirmed [...]"
+	// https://www.rfc-editor.org/rfc/rfc9002#section-5.3-7.2
+	ackDelay = 30 * time.Millisecond
+	maxAckDelay = 25 * time.Millisecond
+	rtt.updateSample(now, handshakeConfirmed, handshakeSpace, 40*time.Millisecond, ackDelay, maxAckDelay)
+	adjustedRTT = 10 * time.Millisecond // latest_rtt (40ms) - ack_delay (30ms)
+	smoothedRTT = (7*smoothedRTT + adjustedRTT) / 8
+	if got, want := rtt.smoothedRTT, smoothedRTT; got != want {
+		t.Errorf("smoothed_rtt = %v, want %v", got, want)
+	}
+	rttvarSample = abs(smoothedRTT - adjustedRTT)
+	rttvar = (3*rttvar + rttvarSample) / 4
+	if got, want := rtt.rttvar, rttvar; got != want {
+		t.Errorf("rttvar = %v, want %v", got, want)
+	}
+
+	// "[...] MUST use the lesser of the acknowledgment delay and
+	// the peer's max_ack_delay after the handshake is confirmed [...]"
+	// https://www.rfc-editor.org/rfc/rfc9002#section-5.3-7.3
+	ackDelay = 30 * time.Millisecond
+	maxAckDelay = 25 * time.Millisecond
+	handshakeConfirmed = true
+	rtt.updateSample(now, handshakeConfirmed, handshakeSpace, 40*time.Millisecond, ackDelay, maxAckDelay)
+	adjustedRTT = 15 * time.Millisecond // latest_rtt (40ms) - max_ack_delay (25ms)
+	smoothedRTT = (7*smoothedRTT + adjustedRTT) / 8
+	if got, want := rtt.smoothedRTT, smoothedRTT; got != want {
+		t.Errorf("smoothed_rtt = %v, want %v", got, want)
+	}
+	rttvarSample = abs(smoothedRTT - adjustedRTT)
+	rttvar = (3*rttvar + rttvarSample) / 4
+	if got, want := rtt.rttvar, rttvar; got != want {
+		t.Errorf("rttvar = %v, want %v", got, want)
+	}
+
+	// "[...] MUST NOT subtract the acknowledgment delay from
+	// the RTT sample if the resulting value is smaller than the min_rtt."
+	// https://www.rfc-editor.org/rfc/rfc9002#section-5.3-7.4
+	ackDelay = 25 * time.Millisecond
+	maxAckDelay = 25 * time.Millisecond
+	handshakeConfirmed = true
+	rtt.updateSample(now, handshakeConfirmed, handshakeSpace, 30*time.Millisecond, ackDelay, maxAckDelay)
+	if got, want := rtt.minRTT, 10*time.Millisecond; got != want {
+		t.Errorf("min_rtt = %v, want %v", got, want)
+	}
+	// latest_rtt (30ms) - ack_delay (25ms) = 5ms, which is less than min_rtt (10ms)
+	adjustedRTT = 30 * time.Millisecond // latest_rtt
+	smoothedRTT = (7*smoothedRTT + adjustedRTT) / 8
+	if got, want := rtt.smoothedRTT, smoothedRTT; got != want {
+		t.Errorf("smoothed_rtt = %v, want %v", got, want)
+	}
+	rttvarSample = abs(smoothedRTT - adjustedRTT)
+	rttvar = (3*rttvar + rttvarSample) / 4
+	if got, want := rtt.rttvar, rttvar; got != want {
+		t.Errorf("rttvar = %v, want %v", got, want)
+	}
+}