netutil: make LimitListener tests more robust

In CL 372495 I cleaned up TestLimitListener so that it would not fail
spuriously. However, upon further thought I realized that the original
test was actually checking two different properties (steady-state
saturation, and actual overload), and the cleaned-up test was only
checking one of those (overload).

This change adds a separate test for steady-state saturation, and
makes the overload test more robust to spurious connections (which
could occur, for example, if another test running on the machine
accidentally dials this test's open port).

The test cleanup also revealed a bad interaction with an existing bug
in the js/wasm net.TCPListener implementation (filed as
golang/go#50216), for which I have added a workaround in
(*limitListener).Accept.

For golang/go#22926

Change-Id: I727050a8254f527c7455de296ed3525b6dc90141
Reviewed-on: https://go-review.googlesource.com/c/net/+/372714
Trust: Bryan Mills <bcmills@google.com>
Run-TryBot: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
diff --git a/netutil/listen.go b/netutil/listen.go
index cee46e3..d5dfbab 100644
--- a/netutil/listen.go
+++ b/netutil/listen.go
@@ -42,14 +42,27 @@
 func (l *limitListener) release() { <-l.sem }
 
 func (l *limitListener) Accept() (net.Conn, error) {
-	acquired := l.acquire()
-	// If the semaphore isn't acquired because the listener was closed, expect
-	// that this call to accept won't block, but immediately return an error.
+	if !l.acquire() {
+		// If the semaphore isn't acquired because the listener was closed, expect
+		// that this call to Accept won't block, but will immediately return an error.
+		// If it instead returns a spurious connection (due to a bug in the
+		// Listener, such as https://golang.org/issue/50216), we immediately close
+		// it and try again. Some buggy Listener implementations (like the one in
+		// the aforementioned issue) seem to assume that Accept will be called to
+		// completion, and may otherwise fail to clean up the client end of pending
+		// connections.
+		for {
+			c, err := l.Listener.Accept()
+			if err != nil {
+				return nil, err
+			}
+			c.Close()
+		}
+	}
+
 	c, err := l.Listener.Accept()
 	if err != nil {
-		if acquired {
-			l.release()
-		}
+		l.release()
 		return nil, err
 	}
 	return &limitListenerConn{Conn: c, release: l.release}, nil
diff --git a/netutil/listen_test.go b/netutil/listen_test.go
index ab8e599..793a91d 100644
--- a/netutil/listen_test.go
+++ b/netutil/listen_test.go
@@ -15,7 +15,7 @@
 	"time"
 )
 
-func TestLimitListener(t *testing.T) {
+func TestLimitListenerOverload(t *testing.T) {
 	const (
 		max      = 5
 		attempts = max * 2
@@ -30,6 +30,7 @@
 
 	var wg sync.WaitGroup
 	wg.Add(1)
+	saturated := make(chan struct{})
 	go func() {
 		defer wg.Done()
 
@@ -40,69 +41,174 @@
 				break
 			}
 			accepted++
+			if accepted == max {
+				close(saturated)
+			}
 			io.WriteString(c, msg)
 
-			defer c.Close() // Leave c open until the listener is closed.
+			// Leave c open until the listener is closed.
+			defer c.Close()
 		}
-		if accepted > max {
-			t.Errorf("accepted %d simultaneous connections; want at most %d", accepted, max)
+		t.Logf("with limit %d, accepted %d simultaneous connections", max, accepted)
+		// The listener accounts for open connections based on Listener-side Close
+		// calls, so even if the client hangs up early (for example, because it
+		// was a random dial from another process instead of from this test), we
+		// should not end up accepting more connections than expected.
+		if accepted != max {
+			t.Errorf("want exactly %d", max)
 		}
 	}()
 
-	// connc keeps the client end of the dialed connections alive until the
-	// test completes.
-	connc := make(chan []net.Conn, 1)
-	connc <- nil
-
 	dialCtx, cancelDial := context.WithCancel(context.Background())
 	defer cancelDial()
 	dialer := &net.Dialer{}
 
-	var served int32
+	var dialed, served int32
+	var pendingDials sync.WaitGroup
 	for n := attempts; n > 0; n-- {
 		wg.Add(1)
+		pendingDials.Add(1)
 		go func() {
 			defer wg.Done()
 
 			c, err := dialer.DialContext(dialCtx, l.Addr().Network(), l.Addr().String())
+			pendingDials.Done()
 			if err != nil {
 				t.Log(err)
 				return
 			}
+			atomic.AddInt32(&dialed, 1)
 			defer c.Close()
 
-			// Keep this end of the connection alive until after the Listener
-			// finishes.
-			conns := append(<-connc, c)
-			if len(conns) == max {
-				go func() {
-					// Give the server a bit of time to make sure it doesn't exceed its
-					// limit after serving this connection, then cancel the remaining
-					// Dials (if any).
-					time.Sleep(10 * time.Millisecond)
-					cancelDial()
-					l.Close()
-				}()
-			}
-			connc <- conns
-
-			b := make([]byte, len(msg))
-			if n, err := c.Read(b); n < len(b) {
+			// The kernel may queue more than max connections (allowing their dials to
+			// succeed), but only max of them should actually be accepted by the
+			// server. We can distinguish the two based on whether the listener writes
+			// anything to the connection — a connection that was queued but not
+			// accepted will be closed without transferring any data.
+			if b, err := io.ReadAll(c); len(b) < len(msg) {
 				t.Log(err)
 				return
 			}
 			atomic.AddInt32(&served, 1)
 		}()
 	}
+
+	// Give the server a bit of time after it saturates to make sure it doesn't
+	// accept any connections beyond its limit, then cancel the remaining
+	// dials (if any).
+	<-saturated
+	time.Sleep(10 * time.Millisecond)
+	cancelDial()
+	// Wait for the dials to complete to ensure that the port isn't reused before
+	// the dials are actually attempted.
+	pendingDials.Wait()
+	l.Close()
 	wg.Wait()
 
-	conns := <-connc
-	for _, c := range conns {
-		c.Close()
+	t.Logf("served %d simultaneous connections (of %d dialed, %d attempted)", served, dialed, attempts)
+
+	// If some other process (such as a port scan or another test) happens to dial
+	// the listener at the same time, the listener could end up burning its quota
+	// on that, resulting in fewer than max test connections being served.
+	// But the number served certainly cannot be greater.
+	if served > max {
+		t.Errorf("expected at most %d served", max)
 	}
-	t.Logf("with limit %d, served %d connections (of %d dialed, %d attempted)", max, served, len(conns), attempts)
-	if served != max {
-		t.Errorf("expected exactly %d served", max)
+}
+
+func TestLimitListenerSaturation(t *testing.T) {
+	const (
+		max             = 5
+		attemptsPerWave = max * 2
+		waves           = 10
+		msg             = "bye\n"
+	)
+
+	l, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	l = LimitListener(l, max)
+
+	acceptDone := make(chan struct{})
+	defer func() {
+		l.Close()
+		<-acceptDone
+	}()
+	go func() {
+		defer close(acceptDone)
+
+		var open, peakOpen int32
+		var (
+			saturated     = make(chan struct{})
+			saturatedOnce sync.Once
+		)
+		var wg sync.WaitGroup
+		for {
+			c, err := l.Accept()
+			if err != nil {
+				break
+			}
+			if n := atomic.AddInt32(&open, 1); n > peakOpen {
+				peakOpen = n
+				if n == max {
+					saturatedOnce.Do(func() {
+						// Wait a bit to make sure the listener doesn't exceed its limit
+						// after accepting this connection, then allow the in-flight
+						// connections to write out and close.
+						time.AfterFunc(10*time.Millisecond, func() { close(saturated) })
+					})
+				}
+			}
+			wg.Add(1)
+			go func() {
+				<-saturated
+				io.WriteString(c, msg)
+				atomic.AddInt32(&open, -1)
+				c.Close()
+				wg.Done()
+			}()
+		}
+		wg.Wait()
+
+		t.Logf("with limit %d, accepted a peak of %d simultaneous connections", max, peakOpen)
+		if peakOpen > max {
+			t.Errorf("want at most %d", max)
+		}
+	}()
+
+	for wave := 0; wave < waves; wave++ {
+		var dialed, served int32
+		var wg sync.WaitGroup
+		for n := attemptsPerWave; n > 0; n-- {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+
+				c, err := net.Dial(l.Addr().Network(), l.Addr().String())
+				if err != nil {
+					t.Log(err)
+					return
+				}
+				atomic.AddInt32(&dialed, 1)
+				defer c.Close()
+
+				if b, err := io.ReadAll(c); len(b) < len(msg) {
+					t.Log(err)
+					return
+				}
+				atomic.AddInt32(&served, 1)
+			}()
+		}
+		wg.Wait()
+
+		t.Logf("served %d connections (of %d dialed, %d attempted)", served, dialed, attemptsPerWave)
+		// We expect that the kernel can queue at least attemptsPerWave
+		// connections at a time (since it's only a small number), so every
+		// connection should eventually be served.
+		if served != attemptsPerWave {
+			t.Errorf("expected %d served", attemptsPerWave)
+		}
 	}
 }
 
@@ -160,9 +266,7 @@
 
 	// Allow the subsequent Accept to block before closing the listener.
 	// (Accept should unblock and return.)
-	timer := time.AfterFunc(10*time.Millisecond, func() {
-		ln.Close()
-	})
+	timer := time.AfterFunc(10*time.Millisecond, func() { ln.Close() })
 
 	c, err = ln.Accept()
 	if err == nil {