go.crypto/ssh: close channel feeding tcpListener.
Close both on closing the listener, and on closing the
connection. Test the former case.
R=dave
CC=golang-dev
https://golang.org/cl/11349043
diff --git a/ssh/client.go b/ssh/client.go
index 16569a8..a506c77 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -210,7 +210,8 @@
func (c *ClientConn) mainLoop() {
defer func() {
c.Close()
- c.closeAll()
+ c.chanList.closeAll()
+ c.forwardList.closeAll()
}()
for {
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index 8ebe262..ad92b43 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -101,17 +101,30 @@
return f.c
}
+// remove removes the forward entry, and the channel feeding its
+// listener.
func (l *forwardList) remove(addr net.TCPAddr) {
l.Lock()
defer l.Unlock()
for i, f := range l.entries {
if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
l.entries = append(l.entries[:i], l.entries[i+1:]...)
+ close(f.c)
return
}
}
}
+// closeAll closes and clears all forwards.
+func (l *forwardList) closeAll() {
+ l.Lock()
+ defer l.Unlock()
+ for _, f := range l.entries {
+ close(f.c)
+ }
+ l.entries = nil
+}
+
func (l *forwardList) lookup(addr net.TCPAddr) (chan forward, bool) {
l.Lock()
defer l.Unlock()
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go
index dc64241..e15fae5 100644
--- a/ssh/test/forward_unix_test.go
+++ b/ssh/test/forward_unix_test.go
@@ -14,14 +14,12 @@
"math/rand"
"net"
"testing"
+ "time"
+
+ "code.google.com/p/go.crypto/ssh"
)
-func TestPortForward(t *testing.T) {
- server := newServer(t)
- defer server.Shutdown()
- conn := server.Dial(clientConfig())
- defer conn.Close()
-
+func listenSSHAuto(conn *ssh.ClientConn) (net.Listener, error) {
var sshListener net.Listener
var err error
tries := 10
@@ -38,7 +36,21 @@
}
if err != nil {
- t.Fatalf("conn.Listen failed: %v (after %d tries)", err, tries)
+ return nil, fmt.Errorf("conn.Listen failed: %v (after %d tries)", err, tries)
+ }
+
+ return sshListener, nil
+}
+
+func TestPortForward(t *testing.T) {
+ server := newServer(t)
+ defer server.Shutdown()
+ conn := server.Dial(clientConfig())
+ defer conn.Close()
+
+ sshListener, err := listenSSHAuto(conn)
+ if err != nil {
+ t.Fatal(err)
}
go func() {
@@ -106,3 +118,37 @@
t.Errorf("still listening to %s after closing", forwardedAddr)
}
}
+
+func TestAcceptClose(t *testing.T) {
+ server := newServer(t)
+ defer server.Shutdown()
+ conn := server.Dial(clientConfig())
+
+ sshListener, err := listenSSHAuto(conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ quit := make(chan error, 1)
+ go func() {
+ for {
+ c, err := sshListener.Accept()
+ if err != nil {
+ quit <- err
+ break
+ }
+ c.Close()
+ }
+ }()
+ sshListener.Close()
+
+ select {
+ case <-time.After(1 * time.Second):
+ t.Errorf("timeout: listener did not close.")
+ case err := <-quit:
+ t.Logf("quit as expected (error %v)", err)
+ }
+}
+
+// TODO(hanwen): test that closing the connection also
+// exits the listeners.