go.crypto/ssh: add support for remote tcpip forwarding
Add support for server (remote) forwarded tcpip channels.
See RFC4254 Section 7.1
R=gustav.paul, jeff, agl, lieqiewang
CC=golang-dev
https://golang.org/cl/6038047
diff --git a/ssh/channel.go b/ssh/channel.go
index 20bc710..4f8050a 100644
--- a/ssh/channel.go
+++ b/ssh/channel.go
@@ -122,7 +122,7 @@
reject := channelOpenFailureMsg{
PeersId: c.theirId,
- Reason: uint32(reason),
+ Reason: reason,
Message: message,
Language: "en",
}
diff --git a/ssh/client.go b/ssh/client.go
index 3b29923..ea29675 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -21,8 +21,9 @@
// ClientConn represents the client side of an SSH connection.
type ClientConn struct {
*transport
- config *ClientConfig
- chanlist
+ config *ClientConfig
+ chanlist // channels associated with this connection
+ forwardList // forwared tcpip connections from the remote side
}
// Client returns a new SSH client connection using c as the underlying transport.
@@ -239,7 +240,7 @@
default:
switch msg := decode(packet).(type) {
case *channelOpenMsg:
- c.getChan(msg.PeersId).msg <- msg
+ c.handleChanOpen(msg)
case *channelOpenConfirmMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelOpenFailureMsg:
@@ -281,6 +282,71 @@
}
}
+// Handle channel open messages from the remote side.
+func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
+ switch msg.ChanType {
+ case "forwarded-tcpip":
+ addr, err := parseAddr(msg.TypeSpecificData)
+ if err != nil {
+ // invalid request
+ m := channelOpenFailureMsg{
+ PeersId: msg.PeersId,
+ Reason: ConnectionFailed,
+ Message: fmt.Sprintf("invalid request: %v", err),
+ Language: "en_US.UTF-8",
+ }
+ c.writePacket(marshal(msgChannelOpenFailure, m))
+ return
+ }
+ l, ok := c.forwardList.Lookup(addr)
+ if !ok {
+ // Section 7.2, implementations MUST reject suprious incoming
+ // connections.
+ return
+ }
+ ch := c.newChan(c.transport)
+ ch.peersId = msg.PeersId
+ ch.stdin.win.add(msg.PeersWindow)
+
+ m := channelOpenConfirmMsg{
+ PeersId: ch.peersId,
+ MyId: ch.id,
+ MyWindow: 1 << 14,
+ MaxPacketSize: 1 << 15, // RFC 4253 6.1
+ }
+ c.writePacket(marshal(msgChannelOpenConfirm, m))
+ l <- forward{ch, addr}
+ default:
+ // unknown channel type
+ m := channelOpenFailureMsg{
+ PeersId: msg.PeersId,
+ Reason: UnknownChannelType,
+ Message: fmt.Sprintf("unknown channel type: %v", msg.ChanType),
+ Language: "en_US.UTF-8",
+ }
+ c.writePacket(marshal(msgChannelOpenFailure, m))
+ }
+}
+
+// parseAddr parses the originating address from the remote into a *net.TCPAddr.
+// RFC 4254 section 7.2 is mute on what to do if parsing fails but the forwardlist
+// requires a valid *net.TCPAddr to operate, so we enforce that restriction here.
+func parseAddr(b []byte) (*net.TCPAddr, error) {
+ addr, b, ok := parseString(b)
+ if !ok {
+ return nil, ParseError{msgChannelOpen}
+ }
+ port, _, ok := parseUint32(b)
+ if !ok {
+ return nil, ParseError{msgChannelOpen}
+ }
+ ip := net.ParseIP(string(addr))
+ if ip == nil {
+ return nil, ParseError{msgChannelOpen}
+ }
+ return &net.TCPAddr{ip, int(port)}, nil
+}
+
// Dial connects to the given network address using net.Dial and
// then initiates a SSH handshake, returning the resulting client connection.
func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
@@ -383,6 +449,14 @@
}))
}
+func (c *clientChan) sendWindowAdj(n int) error {
+ msg := windowAdjustMsg{
+ PeersId: c.peersId,
+ AdditionalBytes: uint32(n),
+ }
+ return c.writePacket(marshal(msgChannelWindowAdjust, msg))
+}
+
// Close closes the channel. This does not close the underlying connection.
func (c *clientChan) Close() error {
if !c.weClosed {
@@ -522,11 +596,7 @@
if len(r.buf) > 0 {
n := copy(data, r.buf)
r.buf = r.buf[n:]
- msg := windowAdjustMsg{
- PeersId: r.clientChan.peersId,
- AdditionalBytes: uint32(n),
- }
- return n, r.clientChan.writePacket(marshal(msgChannelWindowAdjust, msg))
+ return n, r.clientChan.sendWindowAdj(n)
}
r.buf, ok = <-r.data
if !ok {
diff --git a/ssh/example_test.go b/ssh/example_test.go
index ea772c2..c8a2de8 100644
--- a/ssh/example_test.go
+++ b/ssh/example_test.go
@@ -8,6 +8,8 @@
"bytes"
"fmt"
"io/ioutil"
+ "log"
+ "net/http"
"code.google.com/p/go.crypto/ssh/terminal"
)
@@ -120,3 +122,30 @@
}
fmt.Println(b.String())
}
+
+func ExampleClientConn_Listen() {
+ config := &ClientConfig{
+ User: "username",
+ Auth: []ClientAuth{
+ ClientAuthPassword(password("password")),
+ },
+ }
+ // Dial your ssh server.
+ conn, err := Dial("tcp", "localhost:22", config)
+ if err != nil {
+ log.Fatalf("unable to connect: %s", err)
+ }
+ defer conn.Close()
+
+ // Request the remote side to open port 8080 on all interfaces.
+ l, err := conn.Listen("tcp", "0.0.0.0:8080")
+ if err != nil {
+ log.Fatalf("unable to register tcp forward: %v", err)
+ }
+ defer l.Close()
+
+ // Serve HTTP with your SSH server acting as a reverse proxy.
+ http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
+ fmt.Fprintf(resp, "Hello world!\n")
+ }))
+}
diff --git a/ssh/messages.go b/ssh/messages.go
index 3efe81f..d8fbaf0 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -139,7 +139,7 @@
// See RFC 4254, section 5.1.
type channelOpenFailureMsg struct {
PeersId uint32
- Reason uint32
+ Reason RejectionReason
Message string
Language string
}
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index e0c47bc..55cd7cc 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -9,9 +9,154 @@
"fmt"
"io"
"net"
+ "sync"
"time"
)
+var (
+ // TODO(dfc) relax this restriction
+ errNoPort = errors.New("A port number must be supplied")
+)
+
+// Listen requests the remote peer open a listening socket
+// on addr. Incoming connections will be available by calling
+// Accept on the returned net.Listener.
+func (c *ClientConn) Listen(n, addr string) (net.Listener, error) {
+ raddr, err := net.ResolveTCPAddr(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ return c.ListenTCP(raddr)
+}
+
+// ListenTCP requests the remote peer open a listening socket
+// on raddr. Incoming connections will be available by calling
+// Accept on the returned net.Listener.
+func (c *ClientConn) ListenTCP(raddr *net.TCPAddr) (net.Listener, error) {
+ if raddr.Port == 0 {
+ return nil, errNoPort
+ }
+ return c.listen(raddr)
+}
+
+// RFC 4254 7.1
+type channelForwardMsg struct {
+ Message string
+ WantReply bool
+ raddr string
+ rport uint32
+}
+
+func (c *ClientConn) listen(addr *net.TCPAddr) (net.Listener, error) {
+ m := channelForwardMsg{
+ "tcpip-forward",
+ false, // can't handle reply message from remote yet
+ addr.IP.String(),
+ uint32(addr.Port),
+ }
+ // register this forward
+ ch := c.forwardList.Add(addr)
+ // send message
+ if err := c.writePacket(marshal(msgGlobalRequest, m)); err != nil {
+ c.forwardList.Remove(addr)
+ return nil, err
+ }
+ return &tcpListener{addr, c, ch}, nil
+}
+
+// forwardList stores a mapping between remote
+// forward requests and the tcpListeners.
+type forwardList struct {
+ sync.Mutex
+ entries []forwardEntry
+}
+
+// forwardEntry represents an established mapping of a laddr on a
+// remote ssh server to a channel connected to a tcpListener.
+type forwardEntry struct {
+ laddr *net.TCPAddr
+ c chan forward
+}
+
+// forward represents an incoming forwarded tcpip connection
+type forward struct {
+ c *clientChan // the ssh client channel underlying this forward
+ raddr *net.TCPAddr // the raddr of the incoming connection
+}
+
+func (l *forwardList) Add(addr *net.TCPAddr) chan forward {
+ l.Lock()
+ defer l.Unlock()
+ f := forwardEntry{
+ addr,
+ make(chan forward, 1),
+ }
+ l.entries = append(l.entries, f)
+ return f.c
+}
+
+func (l *forwardList) Remove(addr *net.TCPAddr) {
+ l.Lock()
+ defer l.Unlock()
+ for i, f := range l.entries {
+ if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
+ l.entries = append(l.entries[:i], l.entries[i+1:]...)
+ return
+ }
+ }
+}
+
+func (l *forwardList) Lookup(addr *net.TCPAddr) (chan forward, bool) {
+ l.Lock()
+ defer l.Unlock()
+ for _, f := range l.entries {
+ if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
+ return f.c, true
+ }
+ }
+ return nil, false
+}
+
+type tcpListener struct {
+ laddr *net.TCPAddr
+ conn *ClientConn
+ in <-chan forward
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (l *tcpListener) Accept() (net.Conn, error) {
+ s, ok := <-l.in
+ if !ok {
+ return nil, io.EOF
+ }
+ return &tcpChanConn{
+ tcpChan: &tcpChan{
+ clientChan: s.c,
+ Reader: s.c.stdout,
+ Writer: s.c.stdin,
+ },
+ laddr: l.laddr,
+ raddr: s.raddr,
+ }, nil
+}
+
+// Close closes the listener.
+func (l *tcpListener) Close() error {
+ m := channelForwardMsg{
+ "cancel-tcpip-forward",
+ false, // TODO(dfc) process reply
+ l.laddr.IP.String(),
+ uint32(l.laddr.Port),
+ }
+ l.conn.forwardList.Remove(l.laddr)
+ return l.conn.writePacket(marshal(msgGlobalRequest, m))
+}
+
+// Addr returns the listener's network address.
+func (l *tcpListener) Addr() net.Addr {
+ return l.laddr
+}
+
// Dial initiates a connection to the addr from the remote host.
// addr is resolved using net.ResolveTCPAddr before connection.
// This could allow an observer to observe the DNS name of the
@@ -38,8 +183,8 @@
if err != nil {
return nil, err
}
- return &tcpchanconn{
- tcpchan: ch,
+ return &tcpChanConn{
+ tcpChan: ch,
laddr: laddr,
raddr: raddr,
}, nil
@@ -59,7 +204,7 @@
// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
// strings and are expected to be resolveable at the remote end.
-func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
+func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpChan, error) {
ch := c.newChan(c.transport)
if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
ChanType: "direct-tcpip",
@@ -78,39 +223,39 @@
c.chanlist.remove(ch.id)
return nil, fmt.Errorf("ssh: unable to open direct tcpip connection: %v", err)
}
- return &tcpchan{
+ return &tcpChan{
clientChan: ch,
Reader: ch.stdout,
Writer: ch.stdin,
}, nil
}
-type tcpchan struct {
+type tcpChan struct {
*clientChan // the backing channel
io.Reader
io.Writer
}
-// tcpchanconn fulfills the net.Conn interface without
-// the tcpchan having to hold laddr or raddr directly.
-type tcpchanconn struct {
- *tcpchan
+// tcpChanConn fulfills the net.Conn interface without
+// the tcpChan having to hold laddr or raddr directly.
+type tcpChanConn struct {
+ *tcpChan
laddr, raddr net.Addr
}
// LocalAddr returns the local network address.
-func (t *tcpchanconn) LocalAddr() net.Addr {
+func (t *tcpChanConn) LocalAddr() net.Addr {
return t.laddr
}
// RemoteAddr returns the remote network address.
-func (t *tcpchanconn) RemoteAddr() net.Addr {
+func (t *tcpChanConn) RemoteAddr() net.Addr {
return t.raddr
}
// SetDeadline sets the read and write deadlines associated
// with the connection.
-func (t *tcpchanconn) SetDeadline(deadline time.Time) error {
+func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
if err := t.SetReadDeadline(deadline); err != nil {
return err
}
@@ -121,12 +266,12 @@
// A zero value for t means Read will not time out.
// After the deadline, the error from Read will implement net.Error
// with Timeout() == true.
-func (t *tcpchanconn) SetReadDeadline(deadline time.Time) error {
- return errors.New("ssh: tcpchan: deadline not supported")
+func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error {
+ return errors.New("ssh: tcpChan: deadline not supported")
}
// SetWriteDeadline exists to satisfy the net.Conn interface
// but is not implemented by this type. It always returns an error.
-func (t *tcpchanconn) SetWriteDeadline(deadline time.Time) error {
- return errors.New("ssh: tcpchan: deadline not supported")
+func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error {
+ return errors.New("ssh: tcpChan: deadline not supported")
}