internal/socks: add DialWithConn method to Dialer

This change adds DialWithConn method for allowing package users to use
own net.Conn implementations optionally.

Also makes the deprecated Dialer.Dial return a raw transport connection
instead of a forward proxy connection for preserving the backward
compatibility on proxy.Dialer.Dial method.

Fixes golang/go#25104.

Change-Id: I4259cd10e299c1e36406545708e9f6888191705a
Reviewed-on: https://go-review.googlesource.com/110135
Run-TryBot: Mikio Hara <mikioh.mikioh@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/internal/socks/dial_test.go b/internal/socks/dial_test.go
index 93101a6..3a7a31b 100644
--- a/internal/socks/dial_test.go
+++ b/internal/socks/dial_test.go
@@ -17,19 +17,11 @@
 	"golang.org/x/net/internal/sockstest"
 )
 
-const (
-	targetNetwork  = "tcp6"
-	targetHostname = "fqdn.doesnotexist"
-	targetHostIP   = "2001:db8::1"
-	targetPort     = "5963"
-)
-
 func TestDial(t *testing.T) {
 	t.Run("Connect", func(t *testing.T) {
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
 		}
 		defer ss.Close()
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
@@ -41,21 +33,45 @@
 			Username: "username",
 			Password: "password",
 		}).Authenticate
-		c, err := d.Dial(targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
-		if err == nil {
-			c.(*socks.Conn).BoundAddr()
-			c.Close()
-		}
+		c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String())
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
+		}
+		c.(*socks.Conn).BoundAddr()
+		c.Close()
+	})
+	t.Run("ConnectWithConn", func(t *testing.T) {
+		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer ss.Close()
+		c, err := net.Dial(ss.Addr().Network(), ss.Addr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer c.Close()
+		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
+		d.AuthMethods = []socks.AuthMethod{
+			socks.AuthMethodNotRequired,
+			socks.AuthMethodUsernamePassword,
+		}
+		d.Authenticate = (&socks.UsernamePassword{
+			Username: "username",
+			Password: "password",
+		}).Authenticate
+		a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String())
+		if err != nil {
+			t.Fatal(err)
+		}
+		if _, ok := a.(*socks.Addr); !ok {
+			t.Fatalf("got %+v; want socks.Addr", a)
 		}
 	})
 	t.Run("Cancel", func(t *testing.T) {
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
 		}
 		defer ss.Close()
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
@@ -63,7 +79,7 @@
 		defer cancel()
 		dialErr := make(chan error)
 		go func() {
-			c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
+			c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
 			if err == nil {
 				c.Close()
 			}
@@ -73,41 +89,37 @@
 		cancel()
 		err = <-dialErr
 		if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
-			t.Errorf("got %v; want context.Canceled or equivalent", err)
-			return
+			t.Fatalf("got %v; want context.Canceled or equivalent", err)
 		}
 	})
 	t.Run("Deadline", func(t *testing.T) {
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
 		}
 		defer ss.Close()
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
 		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
 		defer cancel()
-		c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
+		c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
 		if err == nil {
 			c.Close()
 		}
 		if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
-			t.Errorf("got %v; want context.DeadlineExceeded or equivalent", err)
-			return
+			t.Fatalf("got %v; want context.DeadlineExceeded or equivalent", err)
 		}
 	})
 	t.Run("WithRogueServer", func(t *testing.T) {
 		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
 		if err != nil {
-			t.Error(err)
-			return
+			t.Fatal(err)
 		}
 		defer ss.Close()
 		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
 		for i := 0; i < 2*len(rogueCmdList); i++ {
 			ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
 			defer cancel()
-			c, err := d.DialContext(ctx, targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
+			c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
 			if err == nil {
 				t.Log(c.(*socks.Conn).BoundAddr())
 				c.Close()
diff --git a/internal/socks/socks.go b/internal/socks/socks.go
index fa38472..d93e699 100644
--- a/internal/socks/socks.go
+++ b/internal/socks/socks.go
@@ -149,20 +149,13 @@
 // See func Dial of the net package of standard library for a
 // description of the network and address parameters.
 func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
-	switch network {
-	case "tcp", "tcp6", "tcp4":
-	default:
+	if err := d.validateTarget(network, address); err != nil {
 		proxy, dst, _ := d.pathAddrs(address)
-		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("network not implemented")}
-	}
-	switch d.cmd {
-	case CmdConnect, cmdBind:
-	default:
-		proxy, dst, _ := d.pathAddrs(address)
-		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("command not implemented")}
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
 	}
 	if ctx == nil {
-		ctx = context.Background()
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
 	}
 	var err error
 	var c net.Conn
@@ -185,11 +178,69 @@
 	return &Conn{Conn: c, boundAddr: a}, nil
 }
 
+// DialWithConn initiates a connection from SOCKS server to the target
+// network and address using the connection c that is already
+// connected to the SOCKS server.
+//
+// It returns the connection's local address assigned by the SOCKS
+// server.
+func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
+	if err := d.validateTarget(network, address); err != nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	if ctx == nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
+	}
+	a, err := d.connect(ctx, c, address)
+	if err != nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	return a, nil
+}
+
 // Dial connects to the provided address on the provided network.
 //
-// Deprecated: Use DialContext instead.
+// Unlike DialContext, it returns a raw transport connection instead
+// of a forward proxy connection.
+//
+// Deprecated: Use DialContext or DialWithConn instead.
 func (d *Dialer) Dial(network, address string) (net.Conn, error) {
-	return d.DialContext(context.Background(), network, address)
+	if err := d.validateTarget(network, address); err != nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	var err error
+	var c net.Conn
+	if d.ProxyDial != nil {
+		c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
+	} else {
+		c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
+	}
+	if err != nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
+		return nil, err
+	}
+	return c, nil
+}
+
+func (d *Dialer) validateTarget(network, address string) error {
+	switch network {
+	case "tcp", "tcp6", "tcp4":
+	default:
+		return errors.New("network not implemented")
+	}
+	switch d.cmd {
+	case CmdConnect, cmdBind:
+	default:
+		return errors.New("command not implemented")
+	}
+	return nil
 }
 
 func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {