go.crypto/ssh: fix and test port forwarding.
Set maxPacket in forwarded connection, and use the requested port
number as key in forwardList.
R=golang-dev, agl, dave
CC=golang-dev
https://golang.org/cl/9753044
diff --git a/ssh/client.go b/ssh/client.go
index 3e0b4c4..a42d13a 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -335,6 +335,10 @@
// Handle channel open messages from the remote side.
func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
+ if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+ c.sendConnectionFailed(msg.PeersId)
+ }
+
switch msg.ChanType {
case "forwarded-tcpip":
laddr, rest, ok := parseTCPAddr(msg.TypeSpecificData)
@@ -343,8 +347,10 @@
c.sendConnectionFailed(msg.PeersId)
return
}
- l, ok := c.forwardList.lookup(laddr)
+
+ l, ok := c.forwardList.lookup(*laddr)
if !ok {
+ // TODO: print on a more structured log.
fmt.Println("could not find forward list entry for", laddr)
// Section 7.2, implementations MUST reject suprious incoming
// connections.
@@ -360,13 +366,17 @@
ch := c.newChan(c.transport)
ch.remoteId = msg.PeersId
ch.remoteWin.add(msg.PeersWindow)
+ ch.maxPacket = msg.MaxPacketSize
m := channelOpenConfirmMsg{
- PeersId: ch.remoteId,
- MyId: ch.localId,
- MyWindow: 1 << 14,
- MaxPacketSize: 1 << 15, // RFC 4253 6.1
+ PeersId: ch.remoteId,
+ MyId: ch.localId,
+ MyWindow: 1 << 14,
+
+ // As per RFC 4253 6.1, 32k is also the minimum.
+ MaxPacketSize: 1 << 15,
}
+
c.writePacket(marshal(msgChannelOpenConfirm, m))
l <- forward{ch, raddr}
default:
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index c3f16b9..498e341 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -47,8 +47,16 @@
if err != nil {
return nil, err
}
- // fixup laddr. If the original port was 0, then the remote side will
- // supply one in the resp.
+
+ // Register this forward, using the port number we requested.
+ // If we requested port 0 (auto allocated port), we have to
+ // register under 0, since the channelOpenMsg will list 0
+ // rather than the allocated port number.
+ ch := c.forwardList.add(*laddr)
+
+ // If the original port was 0, then the remote side will
+ // supply a real port number in the response.
+ origPort := uint32(laddr.Port)
if laddr.Port == 0 {
port, _, ok := parseUint32(resp.Data)
if !ok {
@@ -57,9 +65,7 @@
laddr.Port = int(port)
}
- // register this forward
- ch := c.forwardList.add(laddr)
- return &tcpListener{laddr, c, ch}, nil
+ return &tcpListener{laddr, origPort, c, ch}, nil
}
// forwardList stores a mapping between remote
@@ -72,17 +78,19 @@
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a tcpListener.
type forwardEntry struct {
- laddr *net.TCPAddr
+ laddr net.TCPAddr
c chan forward
}
-// forward represents an incoming forwarded tcpip connection
+// forward represents an incoming forwarded tcpip connection. The
+// arguments to add/remove/lookup should be address as specified in
+// the original forward-request.
type forward struct {
c *clientChan // the ssh client channel underlying this forward
raddr *net.TCPAddr // the raddr of the incoming connection
}
-func (l *forwardList) add(addr *net.TCPAddr) chan forward {
+func (l *forwardList) add(addr net.TCPAddr) chan forward {
l.Lock()
defer l.Unlock()
f := forwardEntry{
@@ -93,7 +101,7 @@
return f.c
}
-func (l *forwardList) remove(addr *net.TCPAddr) {
+func (l *forwardList) remove(addr net.TCPAddr) {
l.Lock()
defer l.Unlock()
for i, f := range l.entries {
@@ -104,7 +112,7 @@
}
}
-func (l *forwardList) lookup(addr *net.TCPAddr) (chan forward, bool) {
+func (l *forwardList) lookup(addr net.TCPAddr) (chan forward, bool) {
l.Lock()
defer l.Unlock()
for _, f := range l.entries {
@@ -117,8 +125,11 @@
type tcpListener struct {
laddr *net.TCPAddr
- conn *ClientConn
- in <-chan forward
+
+ // The port with which we made the request, which can be 0.
+ origPort uint32
+ conn *ClientConn
+ in <-chan forward
}
// Accept waits for and returns the next connection to the listener.
@@ -144,9 +155,13 @@
"cancel-tcpip-forward",
true,
l.laddr.IP.String(),
- uint32(l.laddr.Port),
+ l.origPort,
}
- l.conn.forwardList.remove(l.laddr)
+ origAddr := net.TCPAddr{
+ IP: l.laddr.IP,
+ Port: int(l.origPort),
+ }
+ l.conn.forwardList.remove(origAddr)
if _, err := l.conn.sendGlobalRequest(m); err != nil {
return err
}
diff --git a/ssh/test/forward_test.go b/ssh/test/forward_test.go
new file mode 100644
index 0000000..9f57bb0
--- /dev/null
+++ b/ssh/test/forward_test.go
@@ -0,0 +1,87 @@
+package test
+
+import (
+ "bytes"
+ "io"
+ "io/ioutil"
+ "math/rand"
+ "net"
+ "testing"
+)
+
+func TestPortForward(t *testing.T) {
+ server := newServer(t)
+ defer server.Shutdown()
+ conn := server.Dial(clientConfig())
+ defer conn.Close()
+
+ sshListener, err := conn.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("conn.Listen failed: %v", err)
+ }
+
+ go func() {
+ sshConn, err := sshListener.Accept()
+ if err != nil {
+ t.Fatalf("listen.Accept failed: %v", err)
+ }
+
+ _, err = io.Copy(sshConn, sshConn)
+ if err != nil && err != io.EOF {
+ t.Fatalf("ssh client copy: %v", err)
+ }
+ sshConn.Close()
+ }()
+
+ forwardedAddr := sshListener.Addr().String()
+ tcpConn, err := net.Dial("tcp", forwardedAddr)
+ if err != nil {
+ t.Fatalf("TCP dial failed: %v", err)
+ }
+
+ readChan := make(chan []byte)
+ go func() {
+ data, _ := ioutil.ReadAll(tcpConn)
+ readChan <- data
+ }()
+
+ // Invent some data.
+ data := make([]byte, 100*1000)
+ for i := range data {
+ data[i] = byte(i % 255)
+ }
+
+ var sent []byte
+ for len(sent) < 1000*1000 {
+ // Send random sized chunks
+ m := rand.Intn(len(data))
+ n, err := tcpConn.Write(data[:m])
+ if err != nil {
+ break
+ }
+ sent = append(sent, data[:n]...)
+ }
+ if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
+ t.Errorf("tcpConn.CloseWrite: %v", err)
+ }
+
+ read := <-readChan
+
+ if len(sent) != len(read) {
+ t.Fatalf("got %d bytes, want %d", len(read), len(sent))
+ }
+ if bytes.Compare(sent, read) != 0 {
+ t.Fatalf("read back data does not match")
+ }
+
+ if err := sshListener.Close(); err != nil {
+ t.Fatalf("sshListener.Close: %v", err)
+ }
+
+ // Check that the forward disappeared.
+ tcpConn, err = net.Dial("tcp", forwardedAddr)
+ if err == nil {
+ tcpConn.Close()
+ t.Errorf("still listening to %s after closing", forwardedAddr)
+ }
+}