netutil: unblock LimitListener.Accept on Close

The net.Listener interface specifies that on Close:
// Any blocked Accept operations will be unblocked and return errors.

Fixes golang/go#24458

Change-Id: I4a61a79db9579a40b536aa65c8077da87aa25156
Reviewed-on: https://go-review.googlesource.com/101535
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/netutil/listen.go b/netutil/listen.go
index 56f43bf..cee46e3 100644
--- a/netutil/listen.go
+++ b/netutil/listen.go
@@ -14,27 +14,53 @@
 // LimitListener returns a Listener that accepts at most n simultaneous
 // connections from the provided Listener.
 func LimitListener(l net.Listener, n int) net.Listener {
-	return &limitListener{l, make(chan struct{}, n)}
+	return &limitListener{
+		Listener: l,
+		sem:      make(chan struct{}, n),
+		done:     make(chan struct{}),
+	}
 }
 
 type limitListener struct {
 	net.Listener
-	sem chan struct{}
+	sem       chan struct{}
+	closeOnce sync.Once     // ensures the done chan is only closed once
+	done      chan struct{} // no values sent; closed when Close is called
 }
 
-func (l *limitListener) acquire() { l.sem <- struct{}{} }
+// acquire acquires the limiting semaphore. Returns true if successfully
+// accquired, false if the listener is closed and the semaphore is not
+// acquired.
+func (l *limitListener) acquire() bool {
+	select {
+	case <-l.done:
+		return false
+	case l.sem <- struct{}{}:
+		return true
+	}
+}
 func (l *limitListener) release() { <-l.sem }
 
 func (l *limitListener) Accept() (net.Conn, error) {
-	l.acquire()
+	acquired := l.acquire()
+	// If the semaphore isn't acquired because the listener was closed, expect
+	// that this call to accept won't block, but immediately return an error.
 	c, err := l.Listener.Accept()
 	if err != nil {
-		l.release()
+		if acquired {
+			l.release()
+		}
 		return nil, err
 	}
 	return &limitListenerConn{Conn: c, release: l.release}, nil
 }
 
+func (l *limitListener) Close() error {
+	err := l.Listener.Close()
+	l.closeOnce.Do(func() { close(l.done) })
+	return err
+}
+
 type limitListenerConn struct {
 	net.Conn
 	releaseOnce sync.Once
diff --git a/netutil/listen_test.go b/netutil/listen_test.go
index 5e07d7b..f40c9aa 100644
--- a/netutil/listen_test.go
+++ b/netutil/listen_test.go
@@ -99,3 +99,49 @@
 		t.Fatal("timeout. deadlock?")
 	}
 }
+
+func TestLimitListenerClose(t *testing.T) {
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+	ln = LimitListener(ln, 1)
+
+	doneCh := make(chan struct{})
+	defer close(doneCh)
+	go func() {
+		c, err := net.Dial("tcp", ln.Addr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer c.Close()
+		<-doneCh
+	}()
+
+	c, err := ln.Accept()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer c.Close()
+
+	acceptDone := make(chan struct{})
+	go func() {
+		c, err := ln.Accept()
+		if err == nil {
+			c.Close()
+			t.Errorf("Unexpected successful Accept()")
+		}
+		close(acceptDone)
+	}()
+
+	// Wait a tiny bit to ensure the Accept() is blocking.
+	time.Sleep(10 * time.Millisecond)
+	ln.Close()
+
+	select {
+	case <-acceptDone:
+	case <-time.After(5 * time.Second):
+		t.Fatalf("Accept() still blocking")
+	}
+}