quic: CRYPTO stream handling

CRYPTO frames carry TLS handshake messages.

Add a cryptoStream type which manages the TLS handshake stream,
including retransmission of lost data, processing out-of-order
received data, etc.

For golang/go#58547

Change-Id: I8defa38e22d9c1bb8753f3a44d5ae0853fa56de8
Reviewed-on: https://go-review.googlesource.com/c/net/+/510616
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/internal/quic/crypto_stream.go b/internal/quic/crypto_stream.go
new file mode 100644
index 0000000..6cda657
--- /dev/null
+++ b/internal/quic/crypto_stream.go
@@ -0,0 +1,159 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+// "Implementations MUST support buffering at least 4096 bytes of data
+// received in out-of-order CRYPTO frames."
+// https://www.rfc-editor.org/rfc/rfc9000.html#section-7.5-2
+//
+// 4096 is too small for real-world cases, however, so we allow more.
+const cryptoBufferSize = 1 << 20
+
+// A cryptoStream is the stream of data passed in CRYPTO frames.
+// There is one cryptoStream per packet number space.
+type cryptoStream struct {
+	// CRYPTO data received from the peer.
+	in    pipe
+	inset rangeset[int64] // bytes received
+
+	// CRYPTO data queued for transmission to the peer.
+	out       pipe
+	outunsent rangeset[int64] // bytes in need of sending
+	outacked  rangeset[int64] // bytes acked by peer
+}
+
+// handleCrypto processes data received in a CRYPTO frame.
+func (s *cryptoStream) handleCrypto(off int64, b []byte, f func([]byte) error) error {
+	end := off + int64(len(b))
+	if end-s.inset.min() > cryptoBufferSize {
+		return localTransportError(errCryptoBufferExceeded)
+	}
+	s.inset.add(off, end)
+	if off == s.in.start {
+		// Fast path: This is the next chunk of data in the stream,
+		// so just handle it immediately.
+		if err := f(b); err != nil {
+			return err
+		}
+		s.in.discardBefore(end)
+	} else {
+		// This is either data we've already processed,
+		// data we can't process yet, or a mix of both.
+		s.in.writeAt(b, off)
+	}
+	// s.in.start is the next byte in sequence.
+	// If it's in s.inset, we have bytes to provide.
+	// If it isn't, we don't--we're either out of data,
+	// or only have data that comes after the next byte.
+	if !s.inset.contains(s.in.start) {
+		return nil
+	}
+	// size is the size of the first contiguous chunk of bytes
+	// that have not been processed yet.
+	size := int(s.inset[0].end - s.in.start)
+	if size <= 0 {
+		return nil
+	}
+	err := s.in.read(s.in.start, size, f)
+	s.in.discardBefore(s.inset[0].end)
+	return err
+}
+
+// write queues data for sending to the peer.
+// It does not block or limit the amount of buffered data.
+// QUIC connections don't communicate the amount of CRYPTO data they are willing to buffer,
+// so we send what we have and the peer can close the connection if it is too much.
+func (s *cryptoStream) write(b []byte) {
+	start := s.out.end
+	s.out.writeAt(b, start)
+	s.outunsent.add(start, s.out.end)
+}
+
+// ackOrLoss reports that an CRYPTO frame sent by us has been acknowledged by the peer, or lost.
+func (s *cryptoStream) ackOrLoss(start, end int64, fate packetFate) {
+	switch fate {
+	case packetAcked:
+		s.outacked.add(start, end)
+		s.outunsent.sub(start, end)
+		// If this ack is for data at the start of the send buffer, we can now discard it.
+		if s.outacked.contains(s.out.start) {
+			s.out.discardBefore(s.outacked[0].end)
+		}
+	case packetLost:
+		// Mark everything lost, but not previously acked, as needing retransmission.
+		// We do this by adding all the lost bytes to outunsent, and then
+		// removing everything already acked.
+		s.outunsent.add(start, end)
+		for _, a := range s.outacked {
+			s.outunsent.sub(a.start, a.end)
+		}
+	}
+}
+
+// dataToSend reports what data should be sent in CRYPTO frames to the peer.
+// It calls f with each range of data to send.
+// f uses sendData to get the bytes to send, and returns the number of bytes sent.
+// dataToSend calls f until no data is left, or f returns 0.
+//
+// This function is unusually indirect (why not just return a []byte,
+// or implement io.Reader?).
+//
+// Returning a []byte to the caller either requires that we store the
+// data to send contiguously (which we don't), allocate a temporary buffer
+// and copy into it (inefficient), or return less data than we have available
+// (requires complexity to avoid unnecessarily breaking data across frames).
+//
+// Accepting a []byte from the caller (io.Reader) makes packet construction
+// difficult. Since CRYPTO data is encoded with a varint length prefix, the
+// location of the data depends on the length of the data. (We could hardcode
+// a 2-byte length, of course.)
+//
+// Instead, we tell the caller how much data is, the caller figures out where
+// to put it (and possibly decides that it doesn't have space for this data
+// in the packet after all), and the caller then makes a separate call to
+// copy the data it wants into position.
+func (s *cryptoStream) dataToSend(pto bool, f func(off, size int64) (sent int64)) {
+	for {
+		var off, size int64
+		if pto {
+			// On PTO, resend unacked data that fits in the probe packet.
+			// For simplicity, we send the range starting at s.out.start
+			// (which is definitely unacked, or else we would have discarded it)
+			// up to the next acked byte (if any).
+			//
+			// This may miss unacked data starting after that acked byte,
+			// but avoids resending data the peer has acked.
+			off = s.out.start
+			end := s.out.end
+			for _, r := range s.outacked {
+				if r.start > off {
+					end = r.start
+					break
+				}
+			}
+			size = end - s.out.start
+		} else if s.outunsent.numRanges() > 0 {
+			off = s.outunsent.min()
+			size = s.outunsent[0].size()
+		}
+		if size == 0 {
+			return
+		}
+		n := f(off, size)
+		if n == 0 || pto {
+			return
+		}
+	}
+}
+
+// sendData fills b with data to send to the peer, starting at off,
+// and marks the data as sent. The caller must have already ascertained
+// that there is data to send in this region using dataToSend.
+func (s *cryptoStream) sendData(off int64, b []byte) {
+	s.out.copy(off, b)
+	s.outunsent.sub(off, off+int64(len(b)))
+}
diff --git a/internal/quic/crypto_stream_test.go b/internal/quic/crypto_stream_test.go
new file mode 100644
index 0000000..a6c1e1b
--- /dev/null
+++ b/internal/quic/crypto_stream_test.go
@@ -0,0 +1,265 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+	"crypto/rand"
+	"reflect"
+	"testing"
+)
+
+func TestCryptoStreamReceive(t *testing.T) {
+	data := make([]byte, 1<<20)
+	rand.Read(data) // doesn't need to be crypto/rand, but non-deprecated and harmless
+	type frame struct {
+		start int64
+		end   int64
+		want  int
+	}
+	for _, test := range []struct {
+		name   string
+		frames []frame
+	}{{
+		name: "linear",
+		frames: []frame{{
+			start: 0,
+			end:   1000,
+			want:  1000,
+		}, {
+			start: 1000,
+			end:   2000,
+			want:  2000,
+		}, {
+			// larger than any realistic packet can hold
+			start: 2000,
+			end:   1 << 20,
+			want:  1 << 20,
+		}},
+	}, {
+		name: "out of order",
+		frames: []frame{{
+			start: 1000,
+			end:   2000,
+		}, {
+			start: 2000,
+			end:   3000,
+		}, {
+			start: 0,
+			end:   1000,
+			want:  3000,
+		}},
+	}, {
+		name: "resent",
+		frames: []frame{{
+			start: 0,
+			end:   1000,
+			want:  1000,
+		}, {
+			start: 0,
+			end:   1000,
+			want:  1000,
+		}, {
+			start: 1000,
+			end:   2000,
+			want:  2000,
+		}, {
+			start: 0,
+			end:   1000,
+			want:  2000,
+		}, {
+			start: 1000,
+			end:   2000,
+			want:  2000,
+		}},
+	}, {
+		name: "overlapping",
+		frames: []frame{{
+			start: 0,
+			end:   1000,
+			want:  1000,
+		}, {
+			start: 3000,
+			end:   4000,
+			want:  1000,
+		}, {
+			start: 2000,
+			end:   3000,
+			want:  1000,
+		}, {
+			start: 1000,
+			end:   3000,
+			want:  4000,
+		}},
+	}} {
+		t.Run(test.name, func(t *testing.T) {
+			var s cryptoStream
+			var got []byte
+			for _, f := range test.frames {
+				t.Logf("receive [%v,%v)", f.start, f.end)
+				s.handleCrypto(
+					f.start,
+					data[f.start:f.end],
+					func(b []byte) error {
+						t.Logf("got new bytes [%v,%v)", len(got), len(got)+len(b))
+						got = append(got, b...)
+						return nil
+					},
+				)
+				if len(got) != f.want {
+					t.Fatalf("have bytes [0,%v), want [0,%v)", len(got), f.want)
+				}
+				for i := range got {
+					if got[i] != data[i] {
+						t.Fatalf("byte %v of received data = %v, want %v", i, got[i], data[i])
+					}
+				}
+			}
+		})
+	}
+}
+
+func TestCryptoStreamSends(t *testing.T) {
+	data := make([]byte, 1<<20)
+	rand.Read(data) // doesn't need to be crypto/rand, but non-deprecated and harmless
+	type (
+		sendOp i64range[int64]
+		ackOp  i64range[int64]
+		lossOp i64range[int64]
+	)
+	for _, test := range []struct {
+		name        string
+		size        int64
+		ops         []any
+		wantSend    []i64range[int64]
+		wantPTOSend []i64range[int64]
+	}{{
+		name: "writes with data remaining",
+		size: 4000,
+		ops: []any{
+			sendOp{0, 1000},
+			sendOp{1000, 2000},
+			sendOp{2000, 3000},
+		},
+		wantSend: []i64range[int64]{
+			{3000, 4000},
+		},
+		wantPTOSend: []i64range[int64]{
+			{0, 4000},
+		},
+	}, {
+		name: "lost data is resent",
+		size: 4000,
+		ops: []any{
+			sendOp{0, 1000},
+			sendOp{1000, 2000},
+			sendOp{2000, 3000},
+			sendOp{3000, 4000},
+			lossOp{1000, 2000},
+			lossOp{3000, 4000},
+		},
+		wantSend: []i64range[int64]{
+			{1000, 2000},
+			{3000, 4000},
+		},
+		wantPTOSend: []i64range[int64]{
+			{0, 4000},
+		},
+	}, {
+		name: "acked data at start of range",
+		size: 4000,
+		ops: []any{
+			sendOp{0, 4000},
+			ackOp{0, 1000},
+			ackOp{1000, 2000},
+			ackOp{2000, 3000},
+		},
+		wantSend: nil,
+		wantPTOSend: []i64range[int64]{
+			{3000, 4000},
+		},
+	}, {
+		name: "acked data is not resent on pto",
+		size: 4000,
+		ops: []any{
+			sendOp{0, 4000},
+			ackOp{1000, 2000},
+		},
+		wantSend: nil,
+		wantPTOSend: []i64range[int64]{
+			{0, 1000},
+		},
+	}, {
+		// This is an unusual, but possible scenario:
+		// Data is sent, resent, one of the two sends is acked, and the other is lost.
+		name: "acked and then lost data is not resent",
+		size: 4000,
+		ops: []any{
+			sendOp{0, 4000},
+			sendOp{1000, 2000}, // resent, no-op
+			ackOp{1000, 2000},
+			lossOp{1000, 2000},
+		},
+		wantSend: nil,
+		wantPTOSend: []i64range[int64]{
+			{0, 1000},
+		},
+	}, {
+		// The opposite of the above scenario: data is marked lost, and then acked
+		// before being resent.
+		name: "lost and then acked data is not resent",
+		size: 4000,
+		ops: []any{
+			sendOp{0, 4000},
+			sendOp{1000, 2000}, // resent, no-op
+			lossOp{1000, 2000},
+			ackOp{1000, 2000},
+		},
+		wantSend: nil,
+		wantPTOSend: []i64range[int64]{
+			{0, 1000},
+		},
+	}} {
+		t.Run(test.name, func(t *testing.T) {
+			var s cryptoStream
+			s.write(data[:test.size])
+			for _, op := range test.ops {
+				switch op := op.(type) {
+				case sendOp:
+					t.Logf("send [%v,%v)", op.start, op.end)
+					b := make([]byte, op.end-op.start)
+					s.sendData(op.start, b)
+				case ackOp:
+					t.Logf("ack  [%v,%v)", op.start, op.end)
+					s.ackOrLoss(op.start, op.end, packetAcked)
+				case lossOp:
+					t.Logf("loss [%v,%v)", op.start, op.end)
+					s.ackOrLoss(op.start, op.end, packetLost)
+				default:
+					t.Fatalf("unhandled type %T", op)
+				}
+			}
+			var gotSend []i64range[int64]
+			s.dataToSend(true, func(off, size int64) (wrote int64) {
+				gotSend = append(gotSend, i64range[int64]{off, off + size})
+				return 0
+			})
+			if !reflect.DeepEqual(gotSend, test.wantPTOSend) {
+				t.Fatalf("got data to send on PTO: %v, want %v", gotSend, test.wantPTOSend)
+			}
+			gotSend = nil
+			s.dataToSend(false, func(off, size int64) (wrote int64) {
+				gotSend = append(gotSend, i64range[int64]{off, off + size})
+				b := make([]byte, size)
+				s.sendData(off, b)
+				return int64(len(b))
+			})
+			if !reflect.DeepEqual(gotSend, test.wantSend) {
+				t.Fatalf("got data to send: %v, want %v", gotSend, test.wantSend)
+			}
+		})
+	}
+}