ssh: allow to bind to a hostname in remote forwarding

To avoid breaking backwards compatibility, we fix Listen, which
receives the address as a string, while ListenTCP can still only
be used with IP addresses.

Fixes golang/go#33227
Fixes golang/go#37239

Change-Id: I4d45b40fdcb0d6012ed8da59a02149fa37e7db50
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/599995
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Reviewed-by: Bishakh Ghosh <ghoshbishakh@gmail.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Michael Pratt <mpratt@google.com>
diff --git a/ssh/streamlocal.go b/ssh/streamlocal.go
index b171b33..152470f 100644
--- a/ssh/streamlocal.go
+++ b/ssh/streamlocal.go
@@ -44,7 +44,7 @@
 	if !ok {
 		return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
 	}
-	ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
+	ch := c.forwards.add("unix", socketPath)
 
 	return &unixListener{socketPath, c, ch}, nil
 }
@@ -96,7 +96,7 @@
 // Close closes the listener.
 func (l *unixListener) Close() error {
 	// this also closes the listener.
-	l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
+	l.conn.forwards.remove("unix", l.socketPath)
 	m := streamLocalChannelForwardMsg{
 		l.socketPath,
 	}
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index 93d844f..78c41fe 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -11,6 +11,7 @@
 	"io"
 	"math/rand"
 	"net"
+	"net/netip"
 	"strconv"
 	"strings"
 	"sync"
@@ -22,14 +23,21 @@
 // the returned net.Listener. The listener must be serviced, or the
 // SSH connection may hang.
 // N must be "tcp", "tcp4", "tcp6", or "unix".
+//
+// If the address is a hostname, it is sent to the remote peer as-is, without
+// being resolved locally, and the Listener Addr method will return a zero IP.
 func (c *Client) Listen(n, addr string) (net.Listener, error) {
 	switch n {
 	case "tcp", "tcp4", "tcp6":
-		laddr, err := net.ResolveTCPAddr(n, addr)
+		host, portStr, err := net.SplitHostPort(addr)
 		if err != nil {
 			return nil, err
 		}
-		return c.ListenTCP(laddr)
+		port, err := strconv.ParseInt(portStr, 10, 32)
+		if err != nil {
+			return nil, err
+		}
+		return c.listenTCPInternal(host, int(port))
 	case "unix":
 		return c.ListenUnix(addr)
 	default:
@@ -102,15 +110,24 @@
 // ListenTCP requests the remote peer open a listening socket
 // on laddr. Incoming connections will be available by calling
 // Accept on the returned net.Listener.
+//
+// ListenTCP accepts an IP address, to provide a hostname use [Client.Listen]
+// with "tcp", "tcp4", or "tcp6" network instead.
 func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
 	c.handleForwardsOnce.Do(c.handleForwards)
 	if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
 		return c.autoPortListenWorkaround(laddr)
 	}
 
+	return c.listenTCPInternal(laddr.IP.String(), laddr.Port)
+}
+
+func (c *Client) listenTCPInternal(host string, port int) (net.Listener, error) {
+	c.handleForwardsOnce.Do(c.handleForwards)
+
 	m := channelForwardMsg{
-		laddr.IP.String(),
-		uint32(laddr.Port),
+		host,
+		uint32(port),
 	}
 	// send message
 	ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
@@ -123,20 +140,33 @@
 
 	// If the original port was 0, then the remote side will
 	// supply a real port number in the response.
-	if laddr.Port == 0 {
+	if port == 0 {
 		var p struct {
 			Port uint32
 		}
 		if err := Unmarshal(resp, &p); err != nil {
 			return nil, err
 		}
-		laddr.Port = int(p.Port)
+		port = int(p.Port)
 	}
+	// Construct a local address placeholder for the remote listener. If the
+	// original host is an IP address, preserve it so that Listener.Addr()
+	// reports the same IP. If the host is a hostname or cannot be parsed as an
+	// IP, fall back to IPv4zero. The port field is always set, even if the
+	// original port was 0, because in that case the remote server will assign
+	// one, allowing callers to determine which port was selected.
+	ip := net.IPv4zero
+	if parsed, err := netip.ParseAddr(host); err == nil {
+		ip = net.IP(parsed.AsSlice())
+	}
+	laddr := &net.TCPAddr{
+		IP:   ip,
+		Port: port,
+	}
+	addr := net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
+	ch := c.forwards.add("tcp", addr)
 
-	// Register this forward, using the port number we obtained.
-	ch := c.forwards.add(laddr)
-
-	return &tcpListener{laddr, c, ch}, nil
+	return &tcpListener{laddr, addr, c, ch}, nil
 }
 
 // forwardList stores a mapping between remote
@@ -149,8 +179,9 @@
 // forwardEntry represents an established mapping of a laddr on a
 // remote ssh server to a channel connected to a tcpListener.
 type forwardEntry struct {
-	laddr net.Addr
-	c     chan forward
+	addr    string // host:port or socket path
+	network string // tcp or unix
+	c       chan forward
 }
 
 // forward represents an incoming forwarded tcpip connection. The
@@ -161,12 +192,13 @@
 	raddr net.Addr   // the raddr of the incoming connection
 }
 
-func (l *forwardList) add(addr net.Addr) chan forward {
+func (l *forwardList) add(n, addr string) chan forward {
 	l.Lock()
 	defer l.Unlock()
 	f := forwardEntry{
-		laddr: addr,
-		c:     make(chan forward, 1),
+		addr:    addr,
+		network: n,
+		c:       make(chan forward, 1),
 	}
 	l.entries = append(l.entries, f)
 	return f.c
@@ -185,19 +217,20 @@
 	if port == 0 || port > 65535 {
 		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
 	}
-	ip := net.ParseIP(string(addr))
-	if ip == nil {
+	ip, err := netip.ParseAddr(addr)
+	if err != nil {
 		return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
 	}
-	return &net.TCPAddr{IP: ip, Port: int(port)}, nil
+	return &net.TCPAddr{IP: net.IP(ip.AsSlice()), Port: int(port)}, nil
 }
 
 func (l *forwardList) handleChannels(in <-chan NewChannel) {
 	for ch := range in {
 		var (
-			laddr net.Addr
-			raddr net.Addr
-			err   error
+			addr    string
+			network string
+			raddr   net.Addr
+			err     error
 		)
 		switch channelType := ch.ChannelType(); channelType {
 		case "forwarded-tcpip":
@@ -207,40 +240,34 @@
 				continue
 			}
 
-			// RFC 4254 section 7.2 specifies that incoming
-			// addresses should list the address, in string
-			// format. It is implied that this should be an IP
-			// address, as it would be impossible to connect to it
-			// otherwise.
-			laddr, err = parseTCPAddr(payload.Addr, payload.Port)
-			if err != nil {
-				ch.Reject(ConnectionFailed, err.Error())
-				continue
-			}
+			// RFC 4254 section 7.2 specifies that incoming addresses should
+			// list the address that was connected, in string format. It is the
+			// same address used in the tcpip-forward request. The originator
+			// address is an IP address instead.
+			addr = net.JoinHostPort(payload.Addr, strconv.FormatUint(uint64(payload.Port), 10))
+
 			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
 			if err != nil {
 				ch.Reject(ConnectionFailed, err.Error())
 				continue
 			}
-
+			network = "tcp"
 		case "forwarded-streamlocal@openssh.com":
 			var payload forwardedStreamLocalPayload
 			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
 				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
 				continue
 			}
-			laddr = &net.UnixAddr{
-				Name: payload.SocketPath,
-				Net:  "unix",
-			}
+			addr = payload.SocketPath
 			raddr = &net.UnixAddr{
 				Name: "@",
 				Net:  "unix",
 			}
+			network = "unix"
 		default:
 			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
 		}
-		if ok := l.forward(laddr, raddr, ch); !ok {
+		if ok := l.forward(network, addr, raddr, ch); !ok {
 			// Section 7.2, implementations MUST reject spurious incoming
 			// connections.
 			ch.Reject(Prohibited, "no forward for address")
@@ -252,11 +279,11 @@
 
 // remove removes the forward entry, and the channel feeding its
 // listener.
-func (l *forwardList) remove(addr net.Addr) {
+func (l *forwardList) remove(n, addr string) {
 	l.Lock()
 	defer l.Unlock()
 	for i, f := range l.entries {
-		if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
+		if n == f.network && addr == f.addr {
 			l.entries = append(l.entries[:i], l.entries[i+1:]...)
 			close(f.c)
 			return
@@ -274,11 +301,11 @@
 	l.entries = nil
 }
 
-func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
+func (l *forwardList) forward(n, addr string, raddr net.Addr, ch NewChannel) bool {
 	l.Lock()
 	defer l.Unlock()
 	for _, f := range l.entries {
-		if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
+		if n == f.network && addr == f.addr {
 			f.c <- forward{newCh: ch, raddr: raddr}
 			return true
 		}
@@ -288,6 +315,7 @@
 
 type tcpListener struct {
 	laddr *net.TCPAddr
+	addr  string
 
 	conn *Client
 	in   <-chan forward
@@ -314,13 +342,21 @@
 
 // Close closes the listener.
 func (l *tcpListener) Close() error {
+	host, port, err := net.SplitHostPort(l.addr)
+	if err != nil {
+		return err
+	}
+	rport, err := strconv.ParseUint(port, 10, 32)
+	if err != nil {
+		return err
+	}
 	m := channelForwardMsg{
-		l.laddr.IP.String(),
-		uint32(l.laddr.Port),
+		host,
+		uint32(rport),
 	}
 
 	// this also closes the listener.
-	l.conn.forwards.remove(l.laddr)
+	l.conn.forwards.remove("tcp", l.addr)
 	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
 	if err == nil && !ok {
 		err = errors.New("ssh: cancel-tcpip-forward failed")
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go
index c10d1d0..549d46c 100644
--- a/ssh/test/forward_unix_test.go
+++ b/ssh/test/forward_unix_test.go
@@ -51,6 +51,8 @@
 		}
 	}()
 
+	// The forwarded address match the listen address because we run the tests
+	// on the same host.
 	forwardedAddr := sshListener.Addr().String()
 	netConn, err := net.Dial(n, forwardedAddr)
 	if err != nil {
@@ -111,6 +113,8 @@
 }
 
 func TestPortForwardTCP(t *testing.T) {
+	testPortForward(t, "tcp", ":0")
+	testPortForward(t, "tcp", "[::]:0")
 	testPortForward(t, "tcp", "localhost:0")
 }