go.crypto/ssh: close channel feeding tcpListener. Close both on closing the listener, and on closing the connection. Test the former case. R=dave CC=golang-dev https://golang.org/cl/11349043
diff --git a/ssh/client.go b/ssh/client.go index 16569a8..a506c77 100644 --- a/ssh/client.go +++ b/ssh/client.go
@@ -210,7 +210,8 @@ func (c *ClientConn) mainLoop() { defer func() { c.Close() - c.closeAll() + c.chanList.closeAll() + c.forwardList.closeAll() }() for {
diff --git a/ssh/tcpip.go b/ssh/tcpip.go index 8ebe262..ad92b43 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go
@@ -101,17 +101,30 @@ return f.c } +// remove removes the forward entry, and the channel feeding its +// listener. func (l *forwardList) remove(addr net.TCPAddr) { l.Lock() defer l.Unlock() for i, f := range l.entries { if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { l.entries = append(l.entries[:i], l.entries[i+1:]...) + close(f.c) return } } } +// closeAll closes and clears all forwards. +func (l *forwardList) closeAll() { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + close(f.c) + } + l.entries = nil +} + func (l *forwardList) lookup(addr net.TCPAddr) (chan forward, bool) { l.Lock() defer l.Unlock()
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go index dc64241..e15fae5 100644 --- a/ssh/test/forward_unix_test.go +++ b/ssh/test/forward_unix_test.go
@@ -14,14 +14,12 @@ "math/rand" "net" "testing" + "time" + + "code.google.com/p/go.crypto/ssh" ) -func TestPortForward(t *testing.T) { - server := newServer(t) - defer server.Shutdown() - conn := server.Dial(clientConfig()) - defer conn.Close() - +func listenSSHAuto(conn *ssh.ClientConn) (net.Listener, error) { var sshListener net.Listener var err error tries := 10 @@ -38,7 +36,21 @@ } if err != nil { - t.Fatalf("conn.Listen failed: %v (after %d tries)", err, tries) + return nil, fmt.Errorf("conn.Listen failed: %v (after %d tries)", err, tries) + } + + return sshListener, nil +} + +func TestPortForward(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + defer conn.Close() + + sshListener, err := listenSSHAuto(conn) + if err != nil { + t.Fatal(err) } go func() { @@ -106,3 +118,37 @@ t.Errorf("still listening to %s after closing", forwardedAddr) } } + +func TestAcceptClose(t *testing.T) { + server := newServer(t) + defer server.Shutdown() + conn := server.Dial(clientConfig()) + + sshListener, err := listenSSHAuto(conn) + if err != nil { + t.Fatal(err) + } + + quit := make(chan error, 1) + go func() { + for { + c, err := sshListener.Accept() + if err != nil { + quit <- err + break + } + c.Close() + } + }() + sshListener.Close() + + select { + case <-time.After(1 * time.Second): + t.Errorf("timeout: listener did not close.") + case err := <-quit: + t.Logf("quit as expected (error %v)", err) + } +} + +// TODO(hanwen): test that closing the connection also +// exits the listeners.