quic: simplify gate operations

Unify the waitAndLockDate and waitOnDone test hooks into a single
waitUntil, which takes a func param reporting when the operation
is done.

Make gate.waitAndLock take a Context, drop waitAndLockContext.
Everything that locks a gate passes a Context; there's no need
for the context-free variant.

Drop gate.waitWithLock, nothing used it.

Add a connTestHooks parameter to gate.waitAndLock and queue.get.
This parameter is an abstraction layer violation, but pretending
we're not always passing it through is just unnecessary confusion.

For golang/go#58547

Change-Id: Ifefb73b5a4ae0bac9822a5334117f3b3989f019e
Reviewed-on: https://go-review.googlesource.com/c/net/+/524957
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/quic/conn.go b/internal/quic/conn.go
index 642c507..707b335 100644
--- a/internal/quic/conn.go
+++ b/internal/quic/conn.go
@@ -72,8 +72,7 @@
 	nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any)
 	handleTLSEvent(tls.QUICEvent)
 	newConnID(seq int64) ([]byte, error)
-	waitAndLockGate(ctx context.Context, g *gate) error
-	waitOnDone(ctx context.Context, ch <-chan struct{}) error
+	waitUntil(ctx context.Context, until func() bool) error
 }
 
 func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) {
@@ -315,16 +314,16 @@
 	return nil
 }
 
-func (c *Conn) waitAndLockGate(ctx context.Context, g *gate) error {
-	if c.testHooks != nil {
-		return c.testHooks.waitAndLockGate(ctx, g)
-	}
-	return g.waitAndLockContext(ctx)
-}
-
 func (c *Conn) waitOnDone(ctx context.Context, ch <-chan struct{}) error {
 	if c.testHooks != nil {
-		return c.testHooks.waitOnDone(ctx, ch)
+		return c.testHooks.waitUntil(ctx, func() bool {
+			select {
+			case <-ch:
+				return true
+			default:
+			}
+			return false
+		})
 	}
 	// Check the channel before the context.
 	// We always prefer to return results when available,
diff --git a/internal/quic/conn_async_test.go b/internal/quic/conn_async_test.go
index 5b419c4..dc2a57f 100644
--- a/internal/quic/conn_async_test.go
+++ b/internal/quic/conn_async_test.go
@@ -83,10 +83,7 @@
 
 // A blockedAsync is a blocked async operation.
 type blockedAsync struct {
-	// Exactly one of these will be set, depending on the type of blocked operation.
-	g  *gate
-	ch <-chan struct{}
-
+	until func() bool   // when this returns true, the operation is unblocked
 	donec chan struct{} // closed when the operation is unblocked
 }
 
@@ -130,31 +127,12 @@
 	return a
 }
 
-// waitAndLockGate replaces gate.waitAndLock in tests.
-func (as *asyncTestState) waitAndLockGate(ctx context.Context, g *gate) error {
-	if g.lockIfSet() {
-		// Gate can be acquired without blocking.
+// waitUntil waits for a blocked async operation to complete.
+// The operation is complete when the until func returns true.
+func (as *asyncTestState) waitUntil(ctx context.Context, until func() bool) error {
+	if until() {
 		return nil
 	}
-	return as.block(ctx, &blockedAsync{
-		g: g,
-	})
-}
-
-// waitOnDone replaces receiving from a chan struct{} in tests.
-func (as *asyncTestState) waitOnDone(ctx context.Context, ch <-chan struct{}) error {
-	select {
-	case <-ch:
-		return nil // read without blocking
-	default:
-	}
-	return as.block(ctx, &blockedAsync{
-		ch: ch,
-	})
-}
-
-// block waits for a blocked async operation to complete.
-func (as *asyncTestState) block(ctx context.Context, b *blockedAsync) error {
 	if err := ctx.Err(); err != nil {
 		// Context has already expired.
 		return err
@@ -166,7 +144,10 @@
 		// which may have unpredictable results.
 		panic("blocking async point with unexpected Context")
 	}
-	b.donec = make(chan struct{})
+	b := &blockedAsync{
+		until: until,
+		donec: make(chan struct{}),
+	}
 	// Record this as a pending blocking operation.
 	as.mu.Lock()
 	as.blocked[b] = struct{}{}
@@ -188,20 +169,9 @@
 	as.mu.Lock()
 	var woken *blockedAsync
 	for w := range as.blocked {
-		switch {
-		case w.g != nil:
-			if w.g.lockIfSet() {
-				woken = w
-			}
-		case w.ch != nil:
-			select {
-			case <-w.ch:
-				woken = w
-			default:
-			}
-		}
-		if woken != nil {
-			delete(as.blocked, woken)
+		if w.until() {
+			woken = w
+			delete(as.blocked, w)
 			break
 		}
 	}
diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go
index 716ed2d..9ec2fa0 100644
--- a/internal/quic/conn_streams.go
+++ b/internal/quic/conn_streams.go
@@ -47,7 +47,7 @@
 
 // AcceptStream waits for and returns the next stream created by the peer.
 func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) {
-	return c.streams.queue.getWithHooks(ctx, c.testHooks)
+	return c.streams.queue.get(ctx, c.testHooks)
 }
 
 // NewStream creates a stream.
diff --git a/internal/quic/gate.go b/internal/quic/gate.go
index 27ab07a..a2fb537 100644
--- a/internal/quic/gate.go
+++ b/internal/quic/gate.go
@@ -47,13 +47,11 @@
 }
 
 // waitAndLock waits until the condition is set before acquiring the gate.
-func (g *gate) waitAndLock() {
-	<-g.set
-}
-
-// waitAndLockContext waits until the condition is set before acquiring the gate.
-// If the context expires, waitAndLockContext returns an error and does not acquire the gate.
-func (g *gate) waitAndLockContext(ctx context.Context) error {
+// If the context expires, waitAndLock returns an error and does not acquire the gate.
+func (g *gate) waitAndLock(ctx context.Context, testHooks connTestHooks) error {
+	if testHooks != nil {
+		return testHooks.waitUntil(ctx, g.lockIfSet)
+	}
 	select {
 	case <-g.set:
 		return nil
@@ -67,23 +65,6 @@
 	}
 }
 
-// waitWithLock releases an acquired gate until the condition is set.
-// The caller must have previously acquired the gate.
-// Upon return from waitWithLock, the gate will still be held.
-// If waitWithLock returns nil, the condition is set.
-func (g *gate) waitWithLock(ctx context.Context) error {
-	g.unlock(false)
-	err := g.waitAndLockContext(ctx)
-	if err != nil {
-		if g.lock() {
-			// The condition was set in between the context expiring
-			// and us reacquiring the gate.
-			err = nil
-		}
-	}
-	return err
-}
-
 // lockIfSet acquires the gate if and only if the condition is set.
 func (g *gate) lockIfSet() (acquired bool) {
 	select {
diff --git a/internal/quic/gate_test.go b/internal/quic/gate_test.go
index 0122e39..9e84a84 100644
--- a/internal/quic/gate_test.go
+++ b/internal/quic/gate_test.go
@@ -41,82 +41,35 @@
 	}
 }
 
-func TestGateWaitAndLock(t *testing.T) {
+func TestGateWaitAndLockContext(t *testing.T) {
 	g := newGate()
+	// waitAndLock is canceled
+	ctx, cancel := context.WithCancel(context.Background())
+	go func() {
+		time.Sleep(1 * time.Millisecond)
+		cancel()
+	}()
+	if err := g.waitAndLock(ctx, nil); err != context.Canceled {
+		t.Errorf("g.waitAndLock() = %v, want context.Canceled", err)
+	}
+	// waitAndLock succeeds
 	set := false
 	go func() {
-		for i := 0; i < 3; i++ {
-			g.lock()
-			g.unlock(false)
-			time.Sleep(1 * time.Millisecond)
-		}
+		time.Sleep(1 * time.Millisecond)
 		g.lock()
 		set = true
 		g.unlock(true)
 	}()
-	g.waitAndLock()
+	if err := g.waitAndLock(context.Background(), nil); err != nil {
+		t.Errorf("g.waitAndLock() = %v, want nil", err)
+	}
 	if !set {
 		t.Errorf("g.waitAndLock() returned before gate was set")
 	}
-}
-
-func TestGateWaitAndLockContext(t *testing.T) {
-	g := newGate()
-	// waitAndLockContext is canceled
-	ctx, cancel := context.WithCancel(context.Background())
-	go func() {
-		time.Sleep(1 * time.Millisecond)
-		cancel()
-	}()
-	if err := g.waitAndLockContext(ctx); err != context.Canceled {
-		t.Errorf("g.waitAndLockContext() = %v, want context.Canceled", err)
-	}
-	// waitAndLockContext succeeds
-	set := false
-	go func() {
-		time.Sleep(1 * time.Millisecond)
-		g.lock()
-		set = true
-		g.unlock(true)
-	}()
-	if err := g.waitAndLockContext(context.Background()); err != nil {
-		t.Errorf("g.waitAndLockContext() = %v, want nil", err)
-	}
-	if !set {
-		t.Errorf("g.waitAndLockContext() returned before gate was set")
-	}
 	g.unlock(true)
-	// waitAndLockContext succeeds when the gate is set and the context is canceled
-	if err := g.waitAndLockContext(ctx); err != nil {
-		t.Errorf("g.waitAndLockContext() = %v, want nil", err)
-	}
-}
-
-func TestGateWaitWithLock(t *testing.T) {
-	g := newGate()
-	// waitWithLock is canceled
-	ctx, cancel := context.WithCancel(context.Background())
-	go func() {
-		time.Sleep(1 * time.Millisecond)
-		cancel()
-	}()
-	g.lock()
-	if err := g.waitWithLock(ctx); err != context.Canceled {
-		t.Errorf("g.waitWithLock() = %v, want context.Canceled", err)
-	}
-	// waitWithLock succeeds
-	set := false
-	go func() {
-		g.lock()
-		set = true
-		g.unlock(true)
-	}()
-	time.Sleep(1 * time.Millisecond)
-	if err := g.waitWithLock(context.Background()); err != nil {
-		t.Errorf("g.waitWithLock() = %v, want nil", err)
-	}
-	if !set {
-		t.Errorf("g.waitWithLock() returned before gate was set")
+	// waitAndLock succeeds when the gate is set and the context is canceled
+	if err := g.waitAndLock(ctx, nil); err != nil {
+		t.Errorf("g.waitAndLock() = %v, want nil", err)
 	}
 }
 
@@ -138,5 +91,5 @@
 		g.lock()
 		defer g.unlockFunc(func() bool { return true })
 	}()
-	g.waitAndLock()
+	g.waitAndLock(context.Background(), nil)
 }
diff --git a/internal/quic/queue.go b/internal/quic/queue.go
index 489721a..7085e57 100644
--- a/internal/quic/queue.go
+++ b/internal/quic/queue.go
@@ -44,21 +44,9 @@
 
 // get removes the first item from the queue, blocking until ctx is done, an item is available,
 // or the queue is closed.
-func (q *queue[T]) get(ctx context.Context) (T, error) {
-	return q.getWithHooks(ctx, nil)
-}
-
-// getWithHooks is get, but uses testHooks for locking when non-nil.
-// This is a bit of an layer violation, but a simplification overall.
-func (q *queue[T]) getWithHooks(ctx context.Context, testHooks connTestHooks) (T, error) {
+func (q *queue[T]) get(ctx context.Context, testHooks connTestHooks) (T, error) {
 	var zero T
-	var err error
-	if testHooks != nil {
-		err = testHooks.waitAndLockGate(ctx, &q.gate)
-	} else {
-		err = q.gate.waitAndLockContext(ctx)
-	}
-	if err != nil {
+	if err := q.gate.waitAndLock(ctx, testHooks); err != nil {
 		return zero, err
 	}
 	defer q.unlock()
diff --git a/internal/quic/queue_test.go b/internal/quic/queue_test.go
index 8debeff..d78216b 100644
--- a/internal/quic/queue_test.go
+++ b/internal/quic/queue_test.go
@@ -18,7 +18,7 @@
 	cancel()
 
 	q := newQueue[int]()
-	if got, err := q.get(nonblocking); err != context.Canceled {
+	if got, err := q.get(nonblocking, nil); err != context.Canceled {
 		t.Fatalf("q.get() = %v, %v, want nil, contex.Canceled", got, err)
 	}
 
@@ -28,13 +28,13 @@
 	if !q.put(2) {
 		t.Fatalf("q.put(2) = false, want true")
 	}
-	if got, err := q.get(nonblocking); got != 1 || err != nil {
+	if got, err := q.get(nonblocking, nil); got != 1 || err != nil {
 		t.Fatalf("q.get() = %v, %v, want 1, nil", got, err)
 	}
-	if got, err := q.get(nonblocking); got != 2 || err != nil {
+	if got, err := q.get(nonblocking, nil); got != 2 || err != nil {
 		t.Fatalf("q.get() = %v, %v, want 2, nil", got, err)
 	}
-	if got, err := q.get(nonblocking); err != context.Canceled {
+	if got, err := q.get(nonblocking, nil); err != context.Canceled {
 		t.Fatalf("q.get() = %v, %v, want nil, contex.Canceled", got, err)
 	}
 
@@ -42,7 +42,7 @@
 		time.Sleep(1 * time.Millisecond)
 		q.put(3)
 	}()
-	if got, err := q.get(context.Background()); got != 3 || err != nil {
+	if got, err := q.get(context.Background(), nil); got != 3 || err != nil {
 		t.Fatalf("q.get() = %v, %v, want 3, nil", got, err)
 	}
 
@@ -50,7 +50,7 @@
 		t.Fatalf("q.put(2) = false, want true")
 	}
 	q.close(io.EOF)
-	if got, err := q.get(context.Background()); got != 0 || err != io.EOF {
+	if got, err := q.get(context.Background(), nil); got != 0 || err != io.EOF {
 		t.Fatalf("q.get() = %v, %v, want 0, io.EOF", got, err)
 	}
 	if q.put(5) {
diff --git a/internal/quic/stream.go b/internal/quic/stream.go
index b759e40..d2f2cd7 100644
--- a/internal/quic/stream.go
+++ b/internal/quic/stream.go
@@ -133,7 +133,7 @@
 		return 0, errors.New("read from write-only stream")
 	}
 	// Wait until data is available.
-	if err := s.conn.waitAndLockGate(ctx, &s.ingate); err != nil {
+	if err := s.ingate.waitAndLock(ctx, s.conn.testHooks); err != nil {
 		return 0, err
 	}
 	defer s.inUnlock()
@@ -211,7 +211,7 @@
 				s.outblocked.setUnsent()
 			}
 			s.outUnlock()
-			if err := s.conn.waitAndLockGate(ctx, &s.outgate); err != nil {
+			if err := s.outgate.waitAndLock(ctx, s.conn.testHooks); err != nil {
 				return n, err
 			}
 			// Successfully returning from waitAndLockGate means we are no longer
diff --git a/internal/quic/stream_limits.go b/internal/quic/stream_limits.go
index 5ea7146..db3ab22 100644
--- a/internal/quic/stream_limits.go
+++ b/internal/quic/stream_limits.go
@@ -31,7 +31,7 @@
 // open creates a new local stream, blocking until MAX_STREAMS quota is available.
 func (lim *localStreamLimits) open(ctx context.Context, c *Conn) (num int64, err error) {
 	// TODO: Send a STREAMS_BLOCKED when blocked.
-	if err := c.waitAndLockGate(ctx, &lim.gate); err != nil {
+	if err := lim.gate.waitAndLock(ctx, c.testHooks); err != nil {
 		return 0, err
 	}
 	n := lim.opened