netutil: unblock LimitListener.Accept on Close
The net.Listener interface specifies that on Close:
// Any blocked Accept operations will be unblocked and return errors.
Fixes golang/go#24458
Change-Id: I4a61a79db9579a40b536aa65c8077da87aa25156
Reviewed-on: https://go-review.googlesource.com/101535
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/netutil/listen.go b/netutil/listen.go
index 56f43bf..cee46e3 100644
--- a/netutil/listen.go
+++ b/netutil/listen.go
@@ -14,27 +14,53 @@
// LimitListener returns a Listener that accepts at most n simultaneous
// connections from the provided Listener.
func LimitListener(l net.Listener, n int) net.Listener {
- return &limitListener{l, make(chan struct{}, n)}
+ return &limitListener{
+ Listener: l,
+ sem: make(chan struct{}, n),
+ done: make(chan struct{}),
+ }
}
type limitListener struct {
net.Listener
- sem chan struct{}
+ sem chan struct{}
+ closeOnce sync.Once // ensures the done chan is only closed once
+ done chan struct{} // no values sent; closed when Close is called
}
-func (l *limitListener) acquire() { l.sem <- struct{}{} }
+// acquire acquires the limiting semaphore. Returns true if successfully
+// accquired, false if the listener is closed and the semaphore is not
+// acquired.
+func (l *limitListener) acquire() bool {
+ select {
+ case <-l.done:
+ return false
+ case l.sem <- struct{}{}:
+ return true
+ }
+}
func (l *limitListener) release() { <-l.sem }
func (l *limitListener) Accept() (net.Conn, error) {
- l.acquire()
+ acquired := l.acquire()
+ // If the semaphore isn't acquired because the listener was closed, expect
+ // that this call to accept won't block, but immediately return an error.
c, err := l.Listener.Accept()
if err != nil {
- l.release()
+ if acquired {
+ l.release()
+ }
return nil, err
}
return &limitListenerConn{Conn: c, release: l.release}, nil
}
+func (l *limitListener) Close() error {
+ err := l.Listener.Close()
+ l.closeOnce.Do(func() { close(l.done) })
+ return err
+}
+
type limitListenerConn struct {
net.Conn
releaseOnce sync.Once
diff --git a/netutil/listen_test.go b/netutil/listen_test.go
index 5e07d7b..f40c9aa 100644
--- a/netutil/listen_test.go
+++ b/netutil/listen_test.go
@@ -99,3 +99,49 @@
t.Fatal("timeout. deadlock?")
}
}
+
+func TestLimitListenerClose(t *testing.T) {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ ln = LimitListener(ln, 1)
+
+ doneCh := make(chan struct{})
+ defer close(doneCh)
+ go func() {
+ c, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ <-doneCh
+ }()
+
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ acceptDone := make(chan struct{})
+ go func() {
+ c, err := ln.Accept()
+ if err == nil {
+ c.Close()
+ t.Errorf("Unexpected successful Accept()")
+ }
+ close(acceptDone)
+ }()
+
+ // Wait a tiny bit to ensure the Accept() is blocking.
+ time.Sleep(10 * time.Millisecond)
+ ln.Close()
+
+ select {
+ case <-acceptDone:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Accept() still blocking")
+ }
+}