proxy: add Dial (with context)

The existing API does not allow client code to take advantage of Dialer implementations that implement DialContext receivers. This a familiar API, see net.Dialer.

Fixes golang/go#27874
Fixes golang/go#19354
Fixes golang/go#17759
Fixes golang/go#13455

Change-Id: I0f247783d2037da28c9917db99adda51db1647bd
GitHub-Last-Rev: b0a372707fc4c45772f19b1b886c8823dd613810
GitHub-Pull-Request: golang/net#38
Reviewed-on: https://go-review.googlesource.com/c/net/+/168921
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/proxy/dial.go b/proxy/dial.go
new file mode 100644
index 0000000..811c2e4
--- /dev/null
+++ b/proxy/dial.go
@@ -0,0 +1,54 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proxy
+
+import (
+	"context"
+	"net"
+)
+
+// A ContextDialer dials using a context.
+type ContextDialer interface {
+	DialContext(ctx context.Context, network, address string) (net.Conn, error)
+}
+
+// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment.
+//
+// The passed ctx is only used for returning the Conn, not the lifetime of the Conn.
+//
+// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
+// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
+//
+// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
+func Dial(ctx context.Context, network, address string) (net.Conn, error) {
+	d := FromEnvironment()
+	if xd, ok := d.(ContextDialer); ok {
+		return xd.DialContext(ctx, network, address)
+	}
+	return dialContext(ctx, d, network, address)
+}
+
+// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout
+// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
+func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) {
+	var (
+		conn net.Conn
+		done = make(chan struct{}, 1)
+		err  error
+	)
+	go func() {
+		conn, err = d.Dial(network, address)
+		close(done)
+		if conn != nil && ctx.Err() != nil {
+			conn.Close()
+		}
+	}()
+	select {
+	case <-ctx.Done():
+		err = ctx.Err()
+	case <-done:
+	}
+	return conn, err
+}
diff --git a/proxy/dial_test.go b/proxy/dial_test.go
new file mode 100644
index 0000000..3edab49
--- /dev/null
+++ b/proxy/dial_test.go
@@ -0,0 +1,131 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proxy
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"os"
+	"testing"
+	"time"
+
+	"golang.org/x/net/internal/sockstest"
+)
+
+func TestDial(t *testing.T) {
+	ResetProxyEnv()
+	t.Run("DirectWithCancel", func(t *testing.T) {
+		defer ResetProxyEnv()
+		l, err := net.Listen("tcp", "127.0.0.1:0")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer l.Close()
+		_, port, err := net.SplitHostPort(l.Addr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+		c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port))
+		if err != nil {
+			t.Fatal(err)
+		}
+		c.Close()
+	})
+	t.Run("DirectWithTimeout", func(t *testing.T) {
+		defer ResetProxyEnv()
+		l, err := net.Listen("tcp", "127.0.0.1:0")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer l.Close()
+		_, port, err := net.SplitHostPort(l.Addr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port))
+		if err != nil {
+			t.Fatal(err)
+		}
+		c.Close()
+	})
+	t.Run("DirectWithTimeoutExceeded", func(t *testing.T) {
+		defer ResetProxyEnv()
+		l, err := net.Listen("tcp", "127.0.0.1:0")
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer l.Close()
+		_, port, err := net.SplitHostPort(l.Addr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
+		time.Sleep(time.Millisecond)
+		defer cancel()
+		c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port))
+		if err == nil {
+			defer c.Close()
+			t.Fatal("failed to timeout")
+		}
+	})
+	t.Run("SOCKS5", func(t *testing.T) {
+		defer ResetProxyEnv()
+		s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer s.Close()
+		if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
+			t.Fatal(err)
+		}
+		c, err := Dial(context.Background(), s.TargetAddr().Network(), s.TargetAddr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		c.Close()
+	})
+	t.Run("SOCKS5WithTimeout", func(t *testing.T) {
+		defer ResetProxyEnv()
+		s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer s.Close()
+		if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
+			t.Fatal(err)
+		}
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		c.Close()
+	})
+	t.Run("SOCKS5WithTimeoutExceeded", func(t *testing.T) {
+		defer ResetProxyEnv()
+		s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer s.Close()
+		if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
+			t.Fatal(err)
+		}
+		ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
+		time.Sleep(time.Millisecond)
+		defer cancel()
+		c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
+		if err == nil {
+			defer c.Close()
+			t.Fatal("failed to timeout")
+		}
+	})
+}
diff --git a/proxy/direct.go b/proxy/direct.go
index 4c5ad88..26b51c3 100644
--- a/proxy/direct.go
+++ b/proxy/direct.go
@@ -5,6 +5,7 @@
 package proxy
 
 import (
+	"context"
 	"net"
 )
 
@@ -13,6 +14,13 @@
 // Direct is a direct proxy: one that makes network connections directly.
 var Direct = direct{}
 
+// Dial directly invokes net.Dial with the supplied parameters.
 func (direct) Dial(network, addr string) (net.Conn, error) {
 	return net.Dial(network, addr)
 }
+
+// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters.
+func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+	var d net.Dialer
+	return d.DialContext(ctx, network, addr)
+}
diff --git a/proxy/per_host.go b/proxy/per_host.go
index 0689bb6..573fe79 100644
--- a/proxy/per_host.go
+++ b/proxy/per_host.go
@@ -5,6 +5,7 @@
 package proxy
 
 import (
+	"context"
 	"net"
 	"strings"
 )
@@ -41,6 +42,20 @@
 	return p.dialerForRequest(host).Dial(network, addr)
 }
 
+// DialContext connects to the address addr on the given network through either
+// defaultDialer or bypass.
+func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
+	host, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		return nil, err
+	}
+	d := p.dialerForRequest(host)
+	if x, ok := d.(ContextDialer); ok {
+		return x.DialContext(ctx, network, addr)
+	}
+	return dialContext(ctx, d, network, addr)
+}
+
 func (p *PerHost) dialerForRequest(host string) Dialer {
 	if ip := net.ParseIP(host); ip != nil {
 		for _, net := range p.bypassNetworks {
diff --git a/proxy/per_host_test.go b/proxy/per_host_test.go
index a7d8095..0447eb4 100644
--- a/proxy/per_host_test.go
+++ b/proxy/per_host_test.go
@@ -5,6 +5,7 @@
 package proxy
 
 import (
+	"context"
 	"errors"
 	"net"
 	"reflect"
@@ -21,10 +22,6 @@
 }
 
 func TestPerHost(t *testing.T) {
-	var def, bypass recordingProxy
-	perHost := NewPerHost(&def, &bypass)
-	perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
-
 	expectedDef := []string{
 		"example.com:123",
 		"1.2.3.4:123",
@@ -39,17 +36,41 @@
 		"[1000::]:123",
 	}
 
-	for _, addr := range expectedDef {
-		perHost.Dial("tcp", addr)
-	}
-	for _, addr := range expectedBypass {
-		perHost.Dial("tcp", addr)
-	}
+	t.Run("Dial", func(t *testing.T) {
+		var def, bypass recordingProxy
+		perHost := NewPerHost(&def, &bypass)
+		perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
+		for _, addr := range expectedDef {
+			perHost.Dial("tcp", addr)
+		}
+		for _, addr := range expectedBypass {
+			perHost.Dial("tcp", addr)
+		}
 
-	if !reflect.DeepEqual(expectedDef, def.addrs) {
-		t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
-	}
-	if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
-		t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
-	}
+		if !reflect.DeepEqual(expectedDef, def.addrs) {
+			t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
+		}
+		if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
+			t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
+		}
+	})
+
+	t.Run("DialContext", func(t *testing.T) {
+		var def, bypass recordingProxy
+		perHost := NewPerHost(&def, &bypass)
+		perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
+		for _, addr := range expectedDef {
+			perHost.DialContext(context.Background(), "tcp", addr)
+		}
+		for _, addr := range expectedBypass {
+			perHost.DialContext(context.Background(), "tcp", addr)
+		}
+
+		if !reflect.DeepEqual(expectedDef, def.addrs) {
+			t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
+		}
+		if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
+			t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
+		}
+	})
 }
diff --git a/proxy/proxy.go b/proxy/proxy.go
index f6026b9..37d3cab 100644
--- a/proxy/proxy.go
+++ b/proxy/proxy.go
@@ -15,6 +15,7 @@
 )
 
 // A Dialer is a means to establish a connection.
+// Custom dialers should also implement ContextDialer.
 type Dialer interface {
 	// Dial connects to the given address via the proxy.
 	Dial(network, addr string) (c net.Conn, err error)
diff --git a/proxy/socks5.go b/proxy/socks5.go
index 56345ec..c91651f 100644
--- a/proxy/socks5.go
+++ b/proxy/socks5.go
@@ -17,8 +17,14 @@
 func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) {
 	d := socks.NewDialer(network, address)
 	if forward != nil {
-		d.ProxyDial = func(_ context.Context, network string, address string) (net.Conn, error) {
-			return forward.Dial(network, address)
+		if f, ok := forward.(ContextDialer); ok {
+			d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
+				return f.DialContext(ctx, network, address)
+			}
+		} else {
+			d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
+				return dialContext(ctx, forward, network, address)
+			}
 		}
 	}
 	if auth != nil {