ssh: add (*Client).DialContext method
This change adds DialContext to ssh.Client, which opens a TCP-IP
connection tunneled over the SSH connection. This is useful for
proxying network connections, e.g. setting
(net/http.Transport).DialContext.
Fixes golang/go#20288.
Change-Id: I110494c00962424ea803065535ebe2209364ac27
GitHub-Last-Rev: 3176984a71a9a1422702e3a071340ecfff71ff62
GitHub-Pull-Request: golang/crypto#260
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/504735
Run-TryBot: Nicola Murino <nicola.murino@gmail.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Commit-Queue: Nicola Murino <nicola.murino@gmail.com>
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index 80d35f5..ef5059a 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -5,6 +5,7 @@
package ssh
import (
+ "context"
"errors"
"fmt"
"io"
@@ -332,6 +333,40 @@
return l.laddr
}
+// DialContext initiates a connection to the addr from the remote host.
+//
+// The provided Context must be non-nil. If the context expires before the
+// connection is complete, an error is returned. Once successfully connected,
+// any expiration of the context will not affect the connection.
+//
+// See func Dial for additional information.
+func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ type connErr struct {
+ conn net.Conn
+ err error
+ }
+ ch := make(chan connErr)
+ go func() {
+ conn, err := c.Dial(n, addr)
+ select {
+ case ch <- connErr{conn, err}:
+ case <-ctx.Done():
+ if conn != nil {
+ conn.Close()
+ }
+ }
+ }()
+ select {
+ case res := <-ch:
+ return res.conn, res.err
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+}
+
// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) {
diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go
index f1265cb..4d85114 100644
--- a/ssh/tcpip_test.go
+++ b/ssh/tcpip_test.go
@@ -5,7 +5,10 @@
package ssh
import (
+ "context"
+ "net"
"testing"
+ "time"
)
func TestAutoPortListenBroken(t *testing.T) {
@@ -18,3 +21,33 @@
t.Errorf("version %q marked as broken", works)
}
}
+
+func TestClientImplementsDialContext(t *testing.T) {
+ type ContextDialer interface {
+ DialContext(context.Context, string, string) (net.Conn, error)
+ }
+ // Belt and suspenders assertion, since package net does not
+ // declare a ContextDialer type.
+ var _ ContextDialer = &net.Dialer{}
+ var _ ContextDialer = &Client{}
+}
+
+func TestClientDialContextWithCancel(t *testing.T) {
+ c := &Client{}
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ _, err := c.DialContext(ctx, "tcp", "localhost:1000")
+ if err != context.Canceled {
+ t.Errorf("DialContext: got nil error, expected %v", context.Canceled)
+ }
+}
+
+func TestClientDialContextWithDeadline(t *testing.T) {
+ c := &Client{}
+ ctx, cancel := context.WithDeadline(context.Background(), time.Now())
+ defer cancel()
+ _, err := c.DialContext(ctx, "tcp", "localhost:1000")
+ if err != context.DeadlineExceeded {
+ t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
+ }
+}
diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go
index 0a5f5e3..8ec8d50 100644
--- a/ssh/test/dial_unix_test.go
+++ b/ssh/test/dial_unix_test.go
@@ -9,6 +9,7 @@
// direct-tcpip and direct-streamlocal functional tests
import (
+ "context"
"fmt"
"io"
"net"
@@ -46,7 +47,11 @@
}
}()
- conn, err := sshConn.Dial(n, l.Addr().String())
+ ctx, cancel := context.WithCancel(context.Background())
+ conn, err := sshConn.DialContext(ctx, n, l.Addr().String())
+ // Canceling the context after dial should have no effect
+ // on the opened connection.
+ cancel()
if err != nil {
t.Fatalf("Dial: %v", err)
}