ssh: support forwarding of Unix domain socket connections

This commit implements OpenSSH streamlocal extension, providing the equivalent
of `ssh -L local.sock:remote.sock`.

Change-Id: Idd6287d5a5669c643132bba770c3b4194615e84d
Reviewed-on: https://go-review.googlesource.com/38614
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/client.go b/ssh/client.go
index 667e371..a7e3263 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -14,7 +14,7 @@
 )
 
 // Client implements a traditional SSH client that supports shells,
-// subprocesses, port forwarding and tunneled dialing.
+// subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
 type Client struct {
 	Conn
 
@@ -60,6 +60,7 @@
 		conn.forwards.closeAll()
 	}()
 	go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
+	go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
 	return conn
 }
 
diff --git a/ssh/streamlocal.go b/ssh/streamlocal.go
new file mode 100644
index 0000000..a2dccc6
--- /dev/null
+++ b/ssh/streamlocal.go
@@ -0,0 +1,115 @@
+package ssh
+
+import (
+	"errors"
+	"io"
+	"net"
+)
+
+// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
+// with "direct-streamlocal@openssh.com" string.
+//
+// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
+// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
+type streamLocalChannelOpenDirectMsg struct {
+	socketPath string
+	reserved0  string
+	reserved1  uint32
+}
+
+// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
+// with "forwarded-streamlocal@openssh.com" string.
+type forwardedStreamLocalPayload struct {
+	SocketPath string
+	Reserved0  string
+}
+
+// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
+// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
+type streamLocalChannelForwardMsg struct {
+	socketPath string
+}
+
+// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
+func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
+	m := streamLocalChannelForwardMsg{
+		socketPath,
+	}
+	// send message
+	ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
+	if err != nil {
+		return nil, err
+	}
+	if !ok {
+		return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
+	}
+	ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
+
+	return &unixListener{socketPath, c, ch}, nil
+}
+
+func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
+	msg := streamLocalChannelOpenDirectMsg{
+		socketPath: socketPath,
+	}
+	ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
+	if err != nil {
+		return nil, err
+	}
+	go DiscardRequests(in)
+	return ch, err
+}
+
+type unixListener struct {
+	socketPath string
+
+	conn *Client
+	in   <-chan forward
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (l *unixListener) Accept() (net.Conn, error) {
+	s, ok := <-l.in
+	if !ok {
+		return nil, io.EOF
+	}
+	ch, incoming, err := s.newCh.Accept()
+	if err != nil {
+		return nil, err
+	}
+	go DiscardRequests(incoming)
+
+	return &chanConn{
+		Channel: ch,
+		laddr: &net.UnixAddr{
+			Name: l.socketPath,
+			Net:  "unix",
+		},
+		raddr: &net.UnixAddr{
+			Name: "@",
+			Net:  "unix",
+		},
+	}, nil
+}
+
+// Close closes the listener.
+func (l *unixListener) Close() error {
+	// this also closes the listener.
+	l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
+	m := streamLocalChannelForwardMsg{
+		l.socketPath,
+	}
+	ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
+	if err == nil && !ok {
+		err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
+	}
+	return err
+}
+
+// Addr returns the listener's network address.
+func (l *unixListener) Addr() net.Addr {
+	return &net.UnixAddr{
+		Name: l.socketPath,
+		Net:  "unix",
+	}
+}
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index 6151241..acf1717 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -20,12 +20,20 @@
 // addr. Incoming connections will be available by calling Accept on
 // the returned net.Listener. The listener must be serviced, or the
 // SSH connection may hang.
+// N must be "tcp", "tcp4", "tcp6", or "unix".
 func (c *Client) Listen(n, addr string) (net.Listener, error) {
-	laddr, err := net.ResolveTCPAddr(n, addr)
-	if err != nil {
-		return nil, err
+	switch n {
+	case "tcp", "tcp4", "tcp6":
+		laddr, err := net.ResolveTCPAddr(n, addr)
+		if err != nil {
+			return nil, err
+		}
+		return c.ListenTCP(laddr)
+	case "unix":
+		return c.ListenUnix(addr)
+	default:
+		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
 	}
-	return c.ListenTCP(laddr)
 }
 
 // Automatic port allocation is broken with OpenSSH before 6.0. See
@@ -116,7 +124,7 @@
 	}
 
 	// Register this forward, using the port number we obtained.
-	ch := c.forwards.add(*laddr)
+	ch := c.forwards.add(laddr)
 
 	return &tcpListener{laddr, c, ch}, nil
 }
@@ -131,7 +139,7 @@
 // forwardEntry represents an established mapping of a laddr on a
 // remote ssh server to a channel connected to a tcpListener.
 type forwardEntry struct {
-	laddr net.TCPAddr
+	laddr net.Addr
 	c     chan forward
 }
 
@@ -139,16 +147,16 @@
 // arguments to add/remove/lookup should be address as specified in
 // the original forward-request.
 type forward struct {
-	newCh NewChannel   // the ssh client channel underlying this forward
-	raddr *net.TCPAddr // the raddr of the incoming connection
+	newCh NewChannel // the ssh client channel underlying this forward
+	raddr net.Addr   // the raddr of the incoming connection
 }
 
-func (l *forwardList) add(addr net.TCPAddr) chan forward {
+func (l *forwardList) add(addr net.Addr) chan forward {
 	l.Lock()
 	defer l.Unlock()
 	f := forwardEntry{
-		addr,
-		make(chan forward, 1),
+		laddr: addr,
+		c:     make(chan forward, 1),
 	}
 	l.entries = append(l.entries, f)
 	return f.c
@@ -176,44 +184,69 @@
 
 func (l *forwardList) handleChannels(in <-chan NewChannel) {
 	for ch := range in {
-		var payload forwardedTCPPayload
-		if err := Unmarshal(ch.ExtraData(), &payload); err != nil {
-			ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
-			continue
-		}
+		var (
+			laddr net.Addr
+			raddr net.Addr
+			err   error
+		)
+		switch channelType := ch.ChannelType(); channelType {
+		case "forwarded-tcpip":
+			var payload forwardedTCPPayload
+			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
+				ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
+				continue
+			}
 
-		// RFC 4254 section 7.2 specifies that incoming
-		// addresses should list the address, in string
-		// format. It is implied that this should be an IP
-		// address, as it would be impossible to connect to it
-		// otherwise.
-		laddr, err := parseTCPAddr(payload.Addr, payload.Port)
-		if err != nil {
-			ch.Reject(ConnectionFailed, err.Error())
-			continue
-		}
-		raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort)
-		if err != nil {
-			ch.Reject(ConnectionFailed, err.Error())
-			continue
-		}
+			// RFC 4254 section 7.2 specifies that incoming
+			// addresses should list the address, in string
+			// format. It is implied that this should be an IP
+			// address, as it would be impossible to connect to it
+			// otherwise.
+			laddr, err = parseTCPAddr(payload.Addr, payload.Port)
+			if err != nil {
+				ch.Reject(ConnectionFailed, err.Error())
+				continue
+			}
+			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
+			if err != nil {
+				ch.Reject(ConnectionFailed, err.Error())
+				continue
+			}
 
-		if ok := l.forward(*laddr, *raddr, ch); !ok {
+		case "forwarded-streamlocal@openssh.com":
+			var payload forwardedStreamLocalPayload
+			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
+				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
+				continue
+			}
+			laddr = &net.UnixAddr{
+				Name: payload.SocketPath,
+				Net:  "unix",
+			}
+			raddr = &net.UnixAddr{
+				Name: "@",
+				Net:  "unix",
+			}
+		default:
+			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
+		}
+		if ok := l.forward(laddr, raddr, ch); !ok {
 			// Section 7.2, implementations MUST reject spurious incoming
 			// connections.
 			ch.Reject(Prohibited, "no forward for address")
 			continue
 		}
+
 	}
 }
 
 // remove removes the forward entry, and the channel feeding its
 // listener.
-func (l *forwardList) remove(addr net.TCPAddr) {
+func (l *forwardList) remove(addr net.Addr) {
 	l.Lock()
 	defer l.Unlock()
 	for i, f := range l.entries {
-		if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
+		if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
 			l.entries = append(l.entries[:i], l.entries[i+1:]...)
 			close(f.c)
 			return
@@ -231,12 +264,12 @@
 	l.entries = nil
 }
 
-func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool {
+func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
 	l.Lock()
 	defer l.Unlock()
 	for _, f := range l.entries {
-		if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port {
-			f.c <- forward{ch, &raddr}
+		if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
+			f.c <- forward{newCh: ch, raddr: raddr}
 			return true
 		}
 	}
@@ -262,7 +295,7 @@
 	}
 	go DiscardRequests(incoming)
 
-	return &tcpChanConn{
+	return &chanConn{
 		Channel: ch,
 		laddr:   l.laddr,
 		raddr:   s.raddr,
@@ -277,7 +310,7 @@
 	}
 
 	// this also closes the listener.
-	l.conn.forwards.remove(*l.laddr)
+	l.conn.forwards.remove(l.laddr)
 	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
 	if err == nil && !ok {
 		err = errors.New("ssh: cancel-tcpip-forward failed")
@@ -293,29 +326,52 @@
 // Dial initiates a connection to the addr from the remote host.
 // The resulting connection has a zero LocalAddr() and RemoteAddr().
 func (c *Client) Dial(n, addr string) (net.Conn, error) {
-	// Parse the address into host and numeric port.
-	host, portString, err := net.SplitHostPort(addr)
-	if err != nil {
-		return nil, err
+	var ch Channel
+	switch n {
+	case "tcp", "tcp4", "tcp6":
+		// Parse the address into host and numeric port.
+		host, portString, err := net.SplitHostPort(addr)
+		if err != nil {
+			return nil, err
+		}
+		port, err := strconv.ParseUint(portString, 10, 16)
+		if err != nil {
+			return nil, err
+		}
+		ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
+		if err != nil {
+			return nil, err
+		}
+		// Use a zero address for local and remote address.
+		zeroAddr := &net.TCPAddr{
+			IP:   net.IPv4zero,
+			Port: 0,
+		}
+		return &chanConn{
+			Channel: ch,
+			laddr:   zeroAddr,
+			raddr:   zeroAddr,
+		}, nil
+	case "unix":
+		var err error
+		ch, err = c.dialStreamLocal(addr)
+		if err != nil {
+			return nil, err
+		}
+		return &chanConn{
+			Channel: ch,
+			laddr: &net.UnixAddr{
+				Name: "@",
+				Net:  "unix",
+			},
+			raddr: &net.UnixAddr{
+				Name: addr,
+				Net:  "unix",
+			},
+		}, nil
+	default:
+		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
 	}
-	port, err := strconv.ParseUint(portString, 10, 16)
-	if err != nil {
-		return nil, err
-	}
-	// Use a zero address for local and remote address.
-	zeroAddr := &net.TCPAddr{
-		IP:   net.IPv4zero,
-		Port: 0,
-	}
-	ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
-	if err != nil {
-		return nil, err
-	}
-	return &tcpChanConn{
-		Channel: ch,
-		laddr:   zeroAddr,
-		raddr:   zeroAddr,
-	}, nil
 }
 
 // DialTCP connects to the remote address raddr on the network net,
@@ -332,7 +388,7 @@
 	if err != nil {
 		return nil, err
 	}
-	return &tcpChanConn{
+	return &chanConn{
 		Channel: ch,
 		laddr:   laddr,
 		raddr:   raddr,
@@ -366,26 +422,26 @@
 	Channel // the backing channel
 }
 
-// tcpChanConn fulfills the net.Conn interface without
+// chanConn fulfills the net.Conn interface without
 // the tcpChan having to hold laddr or raddr directly.
-type tcpChanConn struct {
+type chanConn struct {
 	Channel
 	laddr, raddr net.Addr
 }
 
 // LocalAddr returns the local network address.
-func (t *tcpChanConn) LocalAddr() net.Addr {
+func (t *chanConn) LocalAddr() net.Addr {
 	return t.laddr
 }
 
 // RemoteAddr returns the remote network address.
-func (t *tcpChanConn) RemoteAddr() net.Addr {
+func (t *chanConn) RemoteAddr() net.Addr {
 	return t.raddr
 }
 
 // SetDeadline sets the read and write deadlines associated
 // with the connection.
-func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
+func (t *chanConn) SetDeadline(deadline time.Time) error {
 	if err := t.SetReadDeadline(deadline); err != nil {
 		return err
 	}
@@ -396,12 +452,14 @@
 // A zero value for t means Read will not time out.
 // After the deadline, the error from Read will implement net.Error
 // with Timeout() == true.
-func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error {
+func (t *chanConn) SetReadDeadline(deadline time.Time) error {
+	// for compatibility with previous version,
+	// the error message contains "tcpChan"
 	return errors.New("ssh: tcpChan: deadline not supported")
 }
 
 // SetWriteDeadline exists to satisfy the net.Conn interface
 // but is not implemented by this type.  It always returns an error.
-func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error {
+func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
 	return errors.New("ssh: tcpChan: deadline not supported")
 }
diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go
new file mode 100644
index 0000000..091e48c
--- /dev/null
+++ b/ssh/test/dial_unix_test.go
@@ -0,0 +1,128 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !windows
+
+package test
+
+// direct-tcpip and direct-streamlocal functional tests
+
+import (
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net"
+	"strings"
+	"testing"
+)
+
+type dialTester interface {
+	TestServerConn(t *testing.T, c net.Conn)
+	TestClientConn(t *testing.T, c net.Conn)
+}
+
+func testDial(t *testing.T, n, listenAddr string, x dialTester) {
+	server := newServer(t)
+	defer server.Shutdown()
+	sshConn := server.Dial(clientConfig())
+	defer sshConn.Close()
+
+	l, err := net.Listen(n, listenAddr)
+	if err != nil {
+		t.Fatalf("Listen: %v", err)
+	}
+	defer l.Close()
+
+	testData := fmt.Sprintf("hello from %s, %s", n, listenAddr)
+	go func() {
+		for {
+			c, err := l.Accept()
+			if err != nil {
+				break
+			}
+			x.TestServerConn(t, c)
+
+			io.WriteString(c, testData)
+			c.Close()
+		}
+	}()
+
+	conn, err := sshConn.Dial(n, l.Addr().String())
+	if err != nil {
+		t.Fatalf("Dial: %v", err)
+	}
+	x.TestClientConn(t, conn)
+	defer conn.Close()
+	b, err := ioutil.ReadAll(conn)
+	if err != nil {
+		t.Fatalf("ReadAll: %v", err)
+	}
+	t.Logf("got %q", string(b))
+	if string(b) != testData {
+		t.Fatalf("expected %q, got %q", testData, string(b))
+	}
+}
+
+type tcpDialTester struct {
+	listenAddr string
+}
+
+func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) {
+	host := strings.Split(x.listenAddr, ":")[0]
+	prefix := host + ":"
+	if !strings.HasPrefix(c.LocalAddr().String(), prefix) {
+		t.Fatalf("expected to start with %q, got %q", prefix, c.LocalAddr().String())
+	}
+	if !strings.HasPrefix(c.RemoteAddr().String(), prefix) {
+		t.Fatalf("expected to start with %q, got %q", prefix, c.RemoteAddr().String())
+	}
+}
+
+func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) {
+	// we use zero addresses. see *Client.Dial.
+	if c.LocalAddr().String() != "0.0.0.0:0" {
+		t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String())
+	}
+	if c.RemoteAddr().String() != "0.0.0.0:0" {
+		t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String())
+	}
+}
+
+func TestDialTCP(t *testing.T) {
+	x := &tcpDialTester{
+		listenAddr: "127.0.0.1:0",
+	}
+	testDial(t, "tcp", x.listenAddr, x)
+}
+
+type unixDialTester struct {
+	listenAddr string
+}
+
+func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) {
+	if c.LocalAddr().String() != x.listenAddr {
+		t.Fatalf("expected %q, got %q", x.listenAddr, c.LocalAddr().String())
+	}
+	if c.RemoteAddr().String() != "@" {
+		t.Fatalf("expected \"@\", got %q", c.RemoteAddr().String())
+	}
+}
+
+func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) {
+	if c.RemoteAddr().String() != x.listenAddr {
+		t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String())
+	}
+	if c.LocalAddr().String() != "@" {
+		t.Fatalf("expected \"@\", got %q", c.LocalAddr().String())
+	}
+}
+
+func TestDialUnix(t *testing.T) {
+	addr, cleanup := newTempSocket(t)
+	defer cleanup()
+	x := &unixDialTester{
+		listenAddr: addr,
+	}
+	testDial(t, "unix", x.listenAddr, x)
+}
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go
index 877a88c..ea81937 100644
--- a/ssh/test/forward_unix_test.go
+++ b/ssh/test/forward_unix_test.go
@@ -16,13 +16,17 @@
 	"time"
 )
 
-func TestPortForward(t *testing.T) {
+type closeWriter interface {
+	CloseWrite() error
+}
+
+func testPortForward(t *testing.T, n, listenAddr string) {
 	server := newServer(t)
 	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 	defer conn.Close()
 
-	sshListener, err := conn.Listen("tcp", "localhost:0")
+	sshListener, err := conn.Listen(n, listenAddr)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -41,14 +45,14 @@
 	}()
 
 	forwardedAddr := sshListener.Addr().String()
-	tcpConn, err := net.Dial("tcp", forwardedAddr)
+	netConn, err := net.Dial(n, forwardedAddr)
 	if err != nil {
-		t.Fatalf("TCP dial failed: %v", err)
+		t.Fatalf("net dial failed: %v", err)
 	}
 
 	readChan := make(chan []byte)
 	go func() {
-		data, _ := ioutil.ReadAll(tcpConn)
+		data, _ := ioutil.ReadAll(netConn)
 		readChan <- data
 	}()
 
@@ -62,14 +66,14 @@
 	for len(sent) < 1000*1000 {
 		// Send random sized chunks
 		m := rand.Intn(len(data))
-		n, err := tcpConn.Write(data[:m])
+		n, err := netConn.Write(data[:m])
 		if err != nil {
 			break
 		}
 		sent = append(sent, data[:n]...)
 	}
-	if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
-		t.Errorf("tcpConn.CloseWrite: %v", err)
+	if err := netConn.(closeWriter).CloseWrite(); err != nil {
+		t.Errorf("netConn.CloseWrite: %v", err)
 	}
 
 	read := <-readChan
@@ -86,19 +90,29 @@
 	}
 
 	// Check that the forward disappeared.
-	tcpConn, err = net.Dial("tcp", forwardedAddr)
+	netConn, err = net.Dial(n, forwardedAddr)
 	if err == nil {
-		tcpConn.Close()
+		netConn.Close()
 		t.Errorf("still listening to %s after closing", forwardedAddr)
 	}
 }
 
-func TestAcceptClose(t *testing.T) {
+func TestPortForwardTCP(t *testing.T) {
+	testPortForward(t, "tcp", "localhost:0")
+}
+
+func TestPortForwardUnix(t *testing.T) {
+	addr, cleanup := newTempSocket(t)
+	defer cleanup()
+	testPortForward(t, "unix", addr)
+}
+
+func testAcceptClose(t *testing.T, n, listenAddr string) {
 	server := newServer(t)
 	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 
-	sshListener, err := conn.Listen("tcp", "localhost:0")
+	sshListener, err := conn.Listen(n, listenAddr)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -124,13 +138,23 @@
 	}
 }
 
+func TestAcceptCloseTCP(t *testing.T) {
+	testAcceptClose(t, "tcp", "localhost:0")
+}
+
+func TestAcceptCloseUnix(t *testing.T) {
+	addr, cleanup := newTempSocket(t)
+	defer cleanup()
+	testAcceptClose(t, "unix", addr)
+}
+
 // Check that listeners exit if the underlying client transport dies.
-func TestPortForwardConnectionClose(t *testing.T) {
+func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
 	server := newServer(t)
 	defer server.Shutdown()
 	conn := server.Dial(clientConfig())
 
-	sshListener, err := conn.Listen("tcp", "localhost:0")
+	sshListener, err := conn.Listen(n, listenAddr)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -158,3 +182,13 @@
 		t.Logf("quit as expected (error %v)", err)
 	}
 }
+
+func TestPortForwardConnectionCloseTCP(t *testing.T) {
+	testPortForwardConnectionClose(t, "tcp", "localhost:0")
+}
+
+func TestPortForwardConnectionCloseUnix(t *testing.T) {
+	addr, cleanup := newTempSocket(t)
+	defer cleanup()
+	testPortForwardConnectionClose(t, "unix", addr)
+}
diff --git a/ssh/test/tcpip_test.go b/ssh/test/tcpip_test.go
deleted file mode 100644
index a2eb935..0000000
--- a/ssh/test/tcpip_test.go
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2012 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build !windows
-
-package test
-
-// direct-tcpip functional tests
-
-import (
-	"io"
-	"net"
-	"testing"
-)
-
-func TestDial(t *testing.T) {
-	server := newServer(t)
-	defer server.Shutdown()
-	sshConn := server.Dial(clientConfig())
-	defer sshConn.Close()
-
-	l, err := net.Listen("tcp", "127.0.0.1:0")
-	if err != nil {
-		t.Fatalf("Listen: %v", err)
-	}
-	defer l.Close()
-
-	go func() {
-		for {
-			c, err := l.Accept()
-			if err != nil {
-				break
-			}
-
-			io.WriteString(c, c.RemoteAddr().String())
-			c.Close()
-		}
-	}()
-
-	conn, err := sshConn.Dial("tcp", l.Addr().String())
-	if err != nil {
-		t.Fatalf("Dial: %v", err)
-	}
-	defer conn.Close()
-}
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go
index 3bfd881..dd9ff40 100644
--- a/ssh/test/test_unix_test.go
+++ b/ssh/test/test_unix_test.go
@@ -266,3 +266,13 @@
 		},
 	}
 }
+
+func newTempSocket(t *testing.T) (string, func()) {
+	dir, err := ioutil.TempDir("", "socket")
+	if err != nil {
+		t.Fatal(err)
+	}
+	deferFunc := func() { os.RemoveAll(dir) }
+	addr := filepath.Join(dir, "sock")
+	return addr, deferFunc
+}