ssh: add (*Client).DialContext method

This change adds DialContext to ssh.Client, which opens a TCP-IP
connection tunneled over the SSH connection. This is useful for
proxying network connections, e.g. setting
(net/http.Transport).DialContext.

Fixes golang/go#20288.

Change-Id: I110494c00962424ea803065535ebe2209364ac27
GitHub-Last-Rev: 3176984a71a9a1422702e3a071340ecfff71ff62
GitHub-Pull-Request: golang/crypto#260
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/504735
Run-TryBot: Nicola Murino <nicola.murino@gmail.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Commit-Queue: Nicola Murino <nicola.murino@gmail.com>
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index 80d35f5..ef5059a 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -5,6 +5,7 @@
 package ssh
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -332,6 +333,40 @@
 	return l.laddr
 }
 
+// DialContext initiates a connection to the addr from the remote host.
+//
+// The provided Context must be non-nil. If the context expires before the
+// connection is complete, an error is returned. Once successfully connected,
+// any expiration of the context will not affect the connection.
+//
+// See func Dial for additional information.
+func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	type connErr struct {
+		conn net.Conn
+		err  error
+	}
+	ch := make(chan connErr)
+	go func() {
+		conn, err := c.Dial(n, addr)
+		select {
+		case ch <- connErr{conn, err}:
+		case <-ctx.Done():
+			if conn != nil {
+				conn.Close()
+			}
+		}
+	}()
+	select {
+	case res := <-ch:
+		return res.conn, res.err
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+}
+
 // Dial initiates a connection to the addr from the remote host.
 // The resulting connection has a zero LocalAddr() and RemoteAddr().
 func (c *Client) Dial(n, addr string) (net.Conn, error) {
diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go
index f1265cb..4d85114 100644
--- a/ssh/tcpip_test.go
+++ b/ssh/tcpip_test.go
@@ -5,7 +5,10 @@
 package ssh
 
 import (
+	"context"
+	"net"
 	"testing"
+	"time"
 )
 
 func TestAutoPortListenBroken(t *testing.T) {
@@ -18,3 +21,33 @@
 		t.Errorf("version %q marked as broken", works)
 	}
 }
+
+func TestClientImplementsDialContext(t *testing.T) {
+	type ContextDialer interface {
+		DialContext(context.Context, string, string) (net.Conn, error)
+	}
+	// Belt and suspenders assertion, since package net does not
+	// declare a ContextDialer type.
+	var _ ContextDialer = &net.Dialer{}
+	var _ ContextDialer = &Client{}
+}
+
+func TestClientDialContextWithCancel(t *testing.T) {
+	c := &Client{}
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	_, err := c.DialContext(ctx, "tcp", "localhost:1000")
+	if err != context.Canceled {
+		t.Errorf("DialContext: got nil error, expected %v", context.Canceled)
+	}
+}
+
+func TestClientDialContextWithDeadline(t *testing.T) {
+	c := &Client{}
+	ctx, cancel := context.WithDeadline(context.Background(), time.Now())
+	defer cancel()
+	_, err := c.DialContext(ctx, "tcp", "localhost:1000")
+	if err != context.DeadlineExceeded {
+		t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
+	}
+}
diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go
index 0a5f5e3..8ec8d50 100644
--- a/ssh/test/dial_unix_test.go
+++ b/ssh/test/dial_unix_test.go
@@ -9,6 +9,7 @@
 // direct-tcpip and direct-streamlocal functional tests
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"net"
@@ -46,7 +47,11 @@
 		}
 	}()
 
-	conn, err := sshConn.Dial(n, l.Addr().String())
+	ctx, cancel := context.WithCancel(context.Background())
+	conn, err := sshConn.DialContext(ctx, n, l.Addr().String())
+	// Canceling the context after dial should have no effect
+	// on the opened connection.
+	cancel()
 	if err != nil {
 		t.Fatalf("Dial: %v", err)
 	}