proxy: add ProxyFromEnvironmentUsing

Updates golang/go#31813

Change-Id: Ic05fcdb5881c9e01967697542228224611b7a73f
Reviewed-on: https://go-review.googlesource.com/c/net/+/175100
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Jacob Blain Christen <dweomer5@gmail.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/proxy/direct.go b/proxy/direct.go
index 26b51c3..3d66bde 100644
--- a/proxy/direct.go
+++ b/proxy/direct.go
@@ -11,9 +11,14 @@
 
 type direct struct{}
 
-// Direct is a direct proxy: one that makes network connections directly.
+// Direct implements Dialer by making network connections directly using net.Dial or net.DialContext.
 var Direct = direct{}
 
+var (
+	_ Dialer        = Direct
+	_ ContextDialer = Direct
+)
+
 // Dial directly invokes net.Dial with the supplied parameters.
 func (direct) Dial(network, addr string) (net.Conn, error) {
 	return net.Dial(network, addr)
diff --git a/proxy/proxy.go b/proxy/proxy.go
index 37d3cab..9ff4b9a 100644
--- a/proxy/proxy.go
+++ b/proxy/proxy.go
@@ -26,21 +26,30 @@
 	User, Password string
 }
 
-// FromEnvironment returns the dialer specified by the proxy related variables in
-// the environment.
+// FromEnvironment returns the dialer specified by the proxy-related
+// variables in the environment and makes underlying connections
+// directly.
 func FromEnvironment() Dialer {
+	return FromEnvironmentUsing(Direct)
+}
+
+// FromEnvironmentUsing returns the dialer specify by the proxy-related
+// variables in the environment and makes underlying connections
+// using the provided forwarding Dialer (for instance, a *net.Dialer
+// with desired configuration).
+func FromEnvironmentUsing(forward Dialer) Dialer {
 	allProxy := allProxyEnv.Get()
 	if len(allProxy) == 0 {
-		return Direct
+		return forward
 	}
 
 	proxyURL, err := url.Parse(allProxy)
 	if err != nil {
-		return Direct
+		return forward
 	}
-	proxy, err := FromURL(proxyURL, Direct)
+	proxy, err := FromURL(proxyURL, forward)
 	if err != nil {
-		return Direct
+		return forward
 	}
 
 	noProxy := noProxyEnv.Get()
@@ -48,7 +57,7 @@
 		return proxy
 	}
 
-	perHost := NewPerHost(proxy, Direct)
+	perHost := NewPerHost(proxy, forward)
 	perHost.AddFromString(noProxy)
 	return perHost
 }
diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go
index d260d69..567fc9c 100644
--- a/proxy/proxy_test.go
+++ b/proxy/proxy_test.go
@@ -6,7 +6,10 @@
 
 import (
 	"bytes"
+	"context"
+	"errors"
 	"fmt"
+	"net"
 	"net/url"
 	"os"
 	"strings"
@@ -110,6 +113,37 @@
 	c.Close()
 }
 
+type funcFailDialer func(context.Context) error
+
+func (f funcFailDialer) Dial(net, addr string) (net.Conn, error) {
+	panic("shouldn't see a call to Dial")
+}
+
+func (f funcFailDialer) DialContext(ctx context.Context, net, addr string) (net.Conn, error) {
+	return nil, f(ctx)
+}
+
+// Check that FromEnvironmentUsing uses our dialer.
+func TestFromEnvironmentUsing(t *testing.T) {
+	ResetProxyEnv()
+	errFoo := errors.New("some error to check our dialer was used)")
+	type key string
+	ctx := context.WithValue(context.Background(), key("foo"), "bar")
+	dialer := FromEnvironmentUsing(funcFailDialer(func(ctx context.Context) error {
+		if got := ctx.Value(key("foo")); got != "bar" {
+			t.Errorf("Resolver context = %T %v, want %q", got, got, "bar")
+		}
+		return errFoo
+	}))
+	_, err := dialer.(ContextDialer).DialContext(ctx, "tcp", "foo.tld:123")
+	if err == nil {
+		t.Fatalf("unexpected success")
+	}
+	if !strings.Contains(err.Error(), errFoo.Error()) {
+		t.Errorf("got unexpected error %q; want substr %q", err, errFoo)
+	}
+}
+
 func ResetProxyEnv() {
 	for _, env := range []*envOnce{allProxyEnv, noProxyEnv} {
 		for _, v := range env.names {