quic: remove streams from the conn when done

When a stream has been fully shut down--the peer has closed
its end and acked every frame we will send for it--remove
it from the Conn's set of active streams.

We do the actual removal on the conn's loop, so stream cleanup
can access conn state without worrying about locking.

For golang/go#58547

Change-Id: Id9715693649929b07d303f0c4b3a782d135f0326
Reviewed-on: https://go-review.googlesource.com/c/net/+/524296
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/quic/atomic_bits.go b/internal/quic/atomic_bits.go
new file mode 100644
index 0000000..e1e2594
--- /dev/null
+++ b/internal/quic/atomic_bits.go
@@ -0,0 +1,33 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import "sync/atomic"
+
+// atomicBits is an atomic uint32 that supports setting individual bits.
+type atomicBits[T ~uint32] struct {
+	bits atomic.Uint32
+}
+
+// set sets the bits in mask to the corresponding bits in v.
+// It returns the new value, and panics if v has bits set outside mask.
+func (a *atomicBits[T]) set(v, mask T) T {
+	if v&^mask != 0 {
+		panic("BUG: bits in v are not in mask")
+	}
+	for {
+		o := a.bits.Load()
+		n := (o &^ uint32(mask)) | uint32(v)
+		if a.bits.CompareAndSwap(o, n) {
+			return T(n)
+		}
+	}
+}
+
+func (a *atomicBits[T]) load() T {
+	return T(a.bits.Load())
+}
diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go
index dd35e34..0ede284 100644
--- a/internal/quic/conn_streams.go
+++ b/internal/quic/conn_streams.go
@@ -185,24 +185,46 @@
 	for {
 		s := c.streams.sendHead
 		const pto = false
-		if !s.appendInFrames(w, pnum, pto) {
-			return false
-		}
-		avail := w.avail()
-		if !s.appendOutFrames(w, pnum, pto) {
-			// We've sent some data for this stream, but it still has more to send.
-			// If the stream got a reasonable chance to put data in a packet,
-			// advance sendHead to the next stream in line, to avoid starvation.
-			// We'll come back to this stream after going through the others.
-			//
-			// If the packet was already mostly out of space, leave sendHead alone
-			// and come back to this stream again on the next packet.
-			if avail > 512 {
-				c.streams.sendHead = s.next
-				c.streams.sendTail = s
+
+		state := s.state.load()
+		if state&streamInSend != 0 {
+			s.ingate.lock()
+			ok := s.appendInFramesLocked(w, pnum, pto)
+			state = s.inUnlockNoQueue()
+			if !ok {
+				return false
 			}
-			return false
 		}
+
+		if state&streamOutSend != 0 {
+			avail := w.avail()
+			s.outgate.lock()
+			ok := s.appendOutFramesLocked(w, pnum, pto)
+			state = s.outUnlockNoQueue()
+			if !ok {
+				// We've sent some data for this stream, but it still has more to send.
+				// If the stream got a reasonable chance to put data in a packet,
+				// advance sendHead to the next stream in line, to avoid starvation.
+				// We'll come back to this stream after going through the others.
+				//
+				// If the packet was already mostly out of space, leave sendHead alone
+				// and come back to this stream again on the next packet.
+				if avail > 512 {
+					c.streams.sendHead = s.next
+					c.streams.sendTail = s
+				}
+				return false
+			}
+		}
+
+		if state == streamInDone|streamOutDone {
+			// Stream is finished, remove it from the conn.
+			s.state.set(streamConnRemoved, streamConnRemoved)
+			delete(c.streams.streams, s.id)
+
+			// TODO: Provide the peer with additional stream quota (MAX_STREAMS).
+		}
+
 		next := s.next
 		s.next = nil
 		if (next == s) != (s == c.streams.sendTail) {
@@ -231,10 +253,16 @@
 	defer c.streams.sendMu.Unlock()
 	for _, s := range c.streams.streams {
 		const pto = true
-		if !s.appendInFrames(w, pnum, pto) {
+		s.ingate.lock()
+		inOK := s.appendInFramesLocked(w, pnum, pto)
+		s.inUnlockNoQueue()
+		if !inOK {
 			return false
 		}
-		if !s.appendOutFrames(w, pnum, pto) {
+		s.outgate.lock()
+		outOK := s.appendOutFramesLocked(w, pnum, pto)
+		s.outUnlockNoQueue()
+		if !outOK {
 			return false
 		}
 	}
diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go
index 877dbb9..9bbc994 100644
--- a/internal/quic/conn_streams_test.go
+++ b/internal/quic/conn_streams_test.go
@@ -8,6 +8,8 @@
 
 import (
 	"context"
+	"fmt"
+	"io"
 	"testing"
 )
 
@@ -253,3 +255,90 @@
 		}
 	}
 }
+
+func TestStreamsShutdown(t *testing.T) {
+	// These tests verify that a stream is removed from the Conn's map of live streams
+	// after it is fully shut down.
+	//
+	// Each case consists of a setup step, after which one stream should exist,
+	// and a shutdown step, after which no streams should remain in the Conn.
+	for _, test := range []struct {
+		name     string
+		side     streamSide
+		styp     streamType
+		setup    func(*testing.T, *testConn, *Stream)
+		shutdown func(*testing.T, *testConn, *Stream)
+	}{{
+		name: "closed",
+		side: localStream,
+		styp: uniStream,
+		setup: func(t *testing.T, tc *testConn, s *Stream) {
+			s.CloseContext(canceledContext())
+		},
+		shutdown: func(t *testing.T, tc *testConn, s *Stream) {
+			tc.writeAckForAll()
+		},
+	}, {
+		name: "local close",
+		side: localStream,
+		styp: bidiStream,
+		setup: func(t *testing.T, tc *testConn, s *Stream) {
+			tc.writeFrames(packetType1RTT, debugFrameResetStream{
+				id: s.id,
+			})
+			s.CloseContext(canceledContext())
+		},
+		shutdown: func(t *testing.T, tc *testConn, s *Stream) {
+			tc.writeAckForAll()
+		},
+	}, {
+		name: "remote reset",
+		side: localStream,
+		styp: bidiStream,
+		setup: func(t *testing.T, tc *testConn, s *Stream) {
+			s.CloseContext(canceledContext())
+			tc.wantIdle("all frames after CloseContext are ignored")
+			tc.writeAckForAll()
+		},
+		shutdown: func(t *testing.T, tc *testConn, s *Stream) {
+			tc.writeFrames(packetType1RTT, debugFrameResetStream{
+				id: s.id,
+			})
+		},
+	}, {
+		name: "local close",
+		side: remoteStream,
+		styp: uniStream,
+		setup: func(t *testing.T, tc *testConn, s *Stream) {
+			ctx := canceledContext()
+			tc.writeFrames(packetType1RTT, debugFrameStream{
+				id:  s.id,
+				fin: true,
+			})
+			if n, err := s.ReadContext(ctx, make([]byte, 16)); n != 0 || err != io.EOF {
+				t.Errorf("ReadContext() = %v, %v; want 0, io.EOF", n, err)
+			}
+		},
+		shutdown: func(t *testing.T, tc *testConn, s *Stream) {
+			s.CloseRead()
+		},
+	}} {
+		name := fmt.Sprintf("%v/%v/%v", test.side, test.styp, test.name)
+		t.Run(name, func(t *testing.T) {
+			tc, s := newTestConnAndStream(t, serverSide, test.side, test.styp,
+				permissiveTransportParameters)
+			tc.ignoreFrame(frameTypeStreamBase)
+			tc.ignoreFrame(frameTypeStopSending)
+			test.setup(t, tc, s)
+			tc.wantIdle("conn should be idle after setup")
+			if got, want := len(tc.conn.streams.streams), 1; got != want {
+				t.Fatalf("after setup: %v streams in Conn's map; want %v", got, want)
+			}
+			test.shutdown(t, tc, s)
+			tc.wantIdle("conn should be idle after shutdown")
+			if got, want := len(tc.conn.streams.streams), 0; got != want {
+				t.Fatalf("after shutdown: %v streams in Conn's map; want %v", got, want)
+			}
+		})
+	}
+}
diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go
index d8c4455..ea720d5 100644
--- a/internal/quic/conn_test.go
+++ b/internal/quic/conn_test.go
@@ -394,6 +394,7 @@
 // writeAckForAll sends the Conn a datagram containing an ack for all packets up to the
 // last one received.
 func (tc *testConn) writeAckForAll() {
+	tc.t.Helper()
 	if tc.lastPacket == nil {
 		return
 	}
@@ -405,6 +406,7 @@
 // writeAckForLatest sends the Conn a datagram containing an ack for the
 // most recent packet received.
 func (tc *testConn) writeAckForLatest() {
+	tc.t.Helper()
 	if tc.lastPacket == nil {
 		return
 	}
diff --git a/internal/quic/stream.go b/internal/quic/stream.go
index 1033cbb..2dbf446 100644
--- a/internal/quic/stream.go
+++ b/internal/quic/stream.go
@@ -49,9 +49,38 @@
 	outresetcode uint64          // reset code to send in RESET_STREAM
 	outdone      chan struct{}   // closed when all data sent
 
+	// Atomic stream state bits.
+	//
+	// These bits provide a fast way to coordinate between the
+	// send and receive sides of the stream, and the conn's loop.
+	//
+	// streamIn* bits must be set with ingate held.
+	// streamOut* bits must be set with outgate held.
+	// streamConn* bits are set by the conn's loop.
+	state atomicBits[streamState]
+
 	prev, next *Stream // guarded by streamsState.sendMu
 }
 
+type streamState uint32
+
+const (
+	// streamInSend and streamOutSend are set when there are
+	// frames to send for the inbound or outbound sides of the stream.
+	// For example, MAX_STREAM_DATA or STREAM_DATA_BLOCKED.
+	streamInSend = streamState(1 << iota)
+	streamOutSend
+
+	// streamInDone and streamOutDone are set when the inbound or outbound
+	// sides of the stream are finished. When both are set, the stream
+	// can be removed from the Conn and forgotten.
+	streamInDone
+	streamOutDone
+
+	// streamConnRemoved is set when the stream has been removed from the conn.
+	streamConnRemoved
+)
+
 // newStream returns a new stream.
 //
 // The stream's ingate and outgate are locked.
@@ -289,15 +318,34 @@
 // that the stream was terminated abruptly.
 // Any blocked writes will be unblocked and return errors.
 //
-// Reset sends the application protocol error code to the peer.
+// Reset sends the application protocol error code, which must be
+// less than 2^62, to the peer.
 // It does not wait for the peer to acknowledge receipt of the error.
 // Use CloseContext to wait for the peer's acknowledgement.
+//
+// Reset does not affect reads.
+// Use CloseRead to abort reads on the stream.
 func (s *Stream) Reset(code uint64) {
+	const userClosed = true
+	s.resetInternal(code, userClosed)
+}
+
+func (s *Stream) resetInternal(code uint64, userClosed bool) {
 	s.outgate.lock()
 	defer s.outUnlock()
+	if s.IsReadOnly() {
+		return
+	}
+	if userClosed {
+		// Mark that the user closed the stream.
+		s.outclosed.set()
+	}
 	if s.outreset.isSet() {
 		return
 	}
+	if code > maxVarint {
+		code = maxVarint
+	}
 	// We could check here to see if the stream is closed and the
 	// peer has acked all the data and the FIN, but sending an
 	// extra RESET_STREAM in this case is harmless.
@@ -310,44 +358,67 @@
 
 // inUnlock unlocks s.ingate.
 // It sets the gate condition if reads from s will not block.
-// If s has receive-related frames to write, it notifies the Conn.
+// If s has receive-related frames to write or if both directions
+// are done and the stream should be removed, it notifies the Conn.
 func (s *Stream) inUnlock() {
-	if s.inUnlockNoQueue() {
+	state := s.inUnlockNoQueue()
+	if state&streamInSend != 0 || state == streamInDone|streamOutDone {
 		s.conn.queueStreamForSend(s)
 	}
 }
 
 // inUnlockNoQueue is inUnlock,
 // but reports whether s has frames to write rather than notifying the Conn.
-func (s *Stream) inUnlockNoQueue() (shouldSend bool) {
+func (s *Stream) inUnlockNoQueue() streamState {
 	canRead := s.inset.contains(s.in.start) || // data available to read
 		s.insize == s.in.start || // at EOF
 		s.inresetcode != -1 || // reset by peer
 		s.inclosed.isSet() // closed locally
 	defer s.ingate.unlock(canRead)
-	return s.insendmax.shouldSend() || // STREAM_MAX_DATA
-		s.inclosed.shouldSend() // STOP_SENDING
+	var state streamState
+	switch {
+	case s.IsWriteOnly():
+		state = streamInDone
+	case s.inresetcode != -1: // reset by peer
+		fallthrough
+	case s.in.start == s.insize: // all data received and read
+		// We don't increase MAX_STREAMS until the user calls CloseRead or Close,
+		// so the receive side is not finished until inclosed is set.
+		if s.inclosed.isSet() {
+			state = streamInDone
+		}
+	case s.insendmax.shouldSend(): // MAX_STREAM_DATA
+		state = streamInSend
+	case s.inclosed.shouldSend(): // STOP_SENDING
+		state = streamInSend
+	}
+	const mask = streamInDone | streamInSend
+	return s.state.set(state, mask)
 }
 
 // outUnlock unlocks s.outgate.
 // It sets the gate condition if writes to s will not block.
-// If s has send-related frames to write, it notifies the Conn.
+// If s has send-related frames to write or if both directions
+// are done and the stream should be removed, it notifies the Conn.
 func (s *Stream) outUnlock() {
-	if s.outUnlockNoQueue() {
+	state := s.outUnlockNoQueue()
+	if state&streamOutSend != 0 || state == streamInDone|streamOutDone {
 		s.conn.queueStreamForSend(s)
 	}
 }
 
 // outUnlockNoQueue is outUnlock,
 // but reports whether s has frames to write rather than notifying the Conn.
-func (s *Stream) outUnlockNoQueue() (shouldSend bool) {
+func (s *Stream) outUnlockNoQueue() streamState {
 	isDone := s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end) || // all data acked
 		s.outreset.isSet() // reset locally
 	if isDone {
 		select {
 		case <-s.outdone:
 		default:
-			close(s.outdone)
+			if !s.IsReadOnly() {
+				close(s.outdone)
+			}
 		}
 	}
 	lim := min(s.out.start+s.outmaxbuf, s.outwin)
@@ -355,14 +426,32 @@
 		s.outclosed.isSet() || // closed locally
 		s.outreset.isSet() // reset locally
 	defer s.outgate.unlock(canWrite)
-	if s.outreset.isSet() {
-		// If the stream is reset locally, the only frame we'll send is RESET_STREAM.
-		return s.outreset.shouldSend()
+	var state streamState
+	switch {
+	case s.IsReadOnly():
+		state = streamOutDone
+	case s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end): // all data sent and acked
+		fallthrough
+	case s.outreset.isReceived(): // RESET_STREAM sent and acked
+		// We don't increase MAX_STREAMS until the user calls CloseWrite or Close,
+		// so the send side is not finished until outclosed is set.
+		if s.outclosed.isSet() {
+			state = streamOutDone
+		}
+	case s.outreset.shouldSend(): // RESET_STREAM
+		state = streamOutSend
+	case s.outreset.isSet(): // RESET_STREAM sent but not acknowledged
+	case len(s.outunsent) > 0: // STREAM frame with data
+		state = streamOutSend
+	case s.outclosed.shouldSend(): // STREAM frame with FIN bit
+		state = streamOutSend
+	case s.outopened.shouldSend(): // STREAM frame with no data
+		state = streamOutSend
+	case s.outblocked.shouldSend(): // STREAM_DATA_BLOCKED
+		state = streamOutSend
 	}
-	return len(s.outunsent) > 0 || // STREAM frame with data
-		s.outclosed.shouldSend() || // STREAM frame with FIN bit
-		s.outopened.shouldSend() || // STREAM frame with no data
-		s.outblocked.shouldSend() // STREAM_DATA_BLOCKED
+	const mask = streamOutDone | streamOutSend
+	return s.state.set(state, mask)
 }
 
 // handleData handles data received in a STREAM frame.
@@ -431,7 +520,8 @@
 func (s *Stream) handleStopSending(code uint64) error {
 	// Peer requests that we reset this stream.
 	// https://www.rfc-editor.org/rfc/rfc9000#section-3.5-4
-	s.Reset(code)
+	const userReset = false
+	s.resetInternal(code, userReset)
 	return nil
 }
 
@@ -504,14 +594,12 @@
 	}
 }
 
-// appendInFrames appends STOP_SENDING and MAX_STREAM_DATA frames
+// appendInFramesLocked appends STOP_SENDING and MAX_STREAM_DATA frames
 // to the current packet.
 //
 // It returns true if no more frames need appending,
 // false if not everything fit in the current packet.
-func (s *Stream) appendInFrames(w *packetWriter, pnum packetNumber, pto bool) bool {
-	s.ingate.lock()
-	defer s.inUnlockNoQueue()
+func (s *Stream) appendInFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool {
 	if s.inclosed.shouldSendPTO(pto) {
 		// We don't currently have an API for setting the error code.
 		// Just send zero.
@@ -534,14 +622,12 @@
 	return true
 }
 
-// appendOutFrames appends RESET_STREAM, STREAM_DATA_BLOCKED, and STREAM frames
+// appendOutFramesLocked appends RESET_STREAM, STREAM_DATA_BLOCKED, and STREAM frames
 // to the current packet.
 //
 // It returns true if no more frames need appending,
 // false if not everything fit in the current packet.
-func (s *Stream) appendOutFrames(w *packetWriter, pnum packetNumber, pto bool) bool {
-	s.outgate.lock()
-	defer s.outUnlockNoQueue()
+func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool {
 	if s.outreset.isSet() {
 		// RESET_STREAM
 		if s.outreset.shouldSendPTO(pto) {
diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go
index 79377c6..e22e043 100644
--- a/internal/quic/stream_test.go
+++ b/internal/quic/stream_test.go
@@ -1111,6 +1111,24 @@
 	})
 }
 
+func TestStreamResetInvalidCode(t *testing.T) {
+	tc, s := newTestConnAndLocalStream(t, serverSide, uniStream)
+	s.Reset(1 << 62)
+	tc.wantFrame("reset with invalid code sends a RESET_STREAM anyway",
+		packetType1RTT, debugFrameResetStream{
+			id: s.id,
+			// The code we send here isn't specified,
+			// so this could really be any value.
+			code: (1 << 62) - 1,
+		})
+}
+
+func TestStreamResetReceiveOnly(t *testing.T) {
+	tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream)
+	s.Reset(0)
+	tc.wantIdle("resetting a receive-only stream has no effect")
+}
+
 func TestStreamPeerStopSendingForActiveStream(t *testing.T) {
 	// "An endpoint that receives a STOP_SENDING frame MUST send a RESET_STREAM frame if
 	// the stream is in the "Ready" or "Send" state."
@@ -1145,6 +1163,21 @@
 	})
 }
 
+type streamSide string
+
+const (
+	localStream  = streamSide("local")
+	remoteStream = streamSide("remote")
+)
+
+func newTestConnAndStream(t *testing.T, side connSide, sside streamSide, styp streamType, opts ...any) (*testConn, *Stream) {
+	if sside == localStream {
+		return newTestConnAndLocalStream(t, side, styp, opts...)
+	} else {
+		return newTestConnAndRemoteStream(t, side, styp, opts...)
+	}
+}
+
 func newTestConnAndLocalStream(t *testing.T, side connSide, styp streamType, opts ...any) (*testConn, *Stream) {
 	t.Helper()
 	ctx := canceledContext()