internal/socks: add DialWithConn method to Dialer
This change adds DialWithConn method for allowing package users to use
own net.Conn implementations optionally.
Also makes the deprecated Dialer.Dial return a raw transport connection
instead of a forward proxy connection for preserving the backward
compatibility on proxy.Dialer.Dial method.
Fixes golang/go#25104.
Change-Id: I4259cd10e299c1e36406545708e9f6888191705a
Reviewed-on: https://go-review.googlesource.com/110135
Run-TryBot: Mikio Hara <mikioh.mikioh@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/internal/socks/dial_test.go b/internal/socks/dial_test.go
index 93101a6..3a7a31b 100644
--- a/internal/socks/dial_test.go
+++ b/internal/socks/dial_test.go
@@ -17,19 +17,11 @@
"golang.org/x/net/internal/sockstest"
)
-const (
- targetNetwork = "tcp6"
- targetHostname = "fqdn.doesnotexist"
- targetHostIP = "2001:db8::1"
- targetPort = "5963"
-)
-
func TestDial(t *testing.T) {
t.Run("Connect", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
if err != nil {
- t.Error(err)
- return
+ t.Fatal(err)
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
@@ -41,21 +33,45 @@
Username: "username",
Password: "password",
}).Authenticate
- c, err := d.Dial(targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
- if err == nil {
- c.(*socks.Conn).BoundAddr()
- c.Close()
- }
+ c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String())
if err != nil {
- t.Error(err)
- return
+ t.Fatal(err)
+ }
+ c.(*socks.Conn).BoundAddr()
+ c.Close()
+ })
+ t.Run("ConnectWithConn", func(t *testing.T) {
+ ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ss.Close()
+ c, err := net.Dial(ss.Addr().Network(), ss.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
+ d.AuthMethods = []socks.AuthMethod{
+ socks.AuthMethodNotRequired,
+ socks.AuthMethodUsernamePassword,
+ }
+ d.Authenticate = (&socks.UsernamePassword{
+ Username: "username",
+ Password: "password",
+ }).Authenticate
+ a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, ok := a.(*socks.Addr); !ok {
+ t.Fatalf("got %+v; want socks.Addr", a)
}
})
t.Run("Cancel", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
if err != nil {
- t.Error(err)
- return
+ t.Fatal(err)
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
@@ -63,7 +79,7 @@
defer cancel()
dialErr := make(chan error)
go func() {
- c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
+ c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
if err == nil {
c.Close()
}
@@ -73,41 +89,37 @@
cancel()
err = <-dialErr
if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
- t.Errorf("got %v; want context.Canceled or equivalent", err)
- return
+ t.Fatalf("got %v; want context.Canceled or equivalent", err)
}
})
t.Run("Deadline", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
if err != nil {
- t.Error(err)
- return
+ t.Fatal(err)
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel()
- c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
+ c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
if err == nil {
c.Close()
}
if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
- t.Errorf("got %v; want context.DeadlineExceeded or equivalent", err)
- return
+ t.Fatalf("got %v; want context.DeadlineExceeded or equivalent", err)
}
})
t.Run("WithRogueServer", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
if err != nil {
- t.Error(err)
- return
+ t.Fatal(err)
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
for i := 0; i < 2*len(rogueCmdList); i++ {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel()
- c, err := d.DialContext(ctx, targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
+ c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
if err == nil {
t.Log(c.(*socks.Conn).BoundAddr())
c.Close()
diff --git a/internal/socks/socks.go b/internal/socks/socks.go
index fa38472..d93e699 100644
--- a/internal/socks/socks.go
+++ b/internal/socks/socks.go
@@ -149,20 +149,13 @@
// See func Dial of the net package of standard library for a
// description of the network and address parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
- switch network {
- case "tcp", "tcp6", "tcp4":
- default:
+ if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
- return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("network not implemented")}
- }
- switch d.cmd {
- case CmdConnect, cmdBind:
- default:
- proxy, dst, _ := d.pathAddrs(address)
- return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("command not implemented")}
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if ctx == nil {
- ctx = context.Background()
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
}
var err error
var c net.Conn
@@ -185,11 +178,69 @@
return &Conn{Conn: c, boundAddr: a}, nil
}
+// DialWithConn initiates a connection from SOCKS server to the target
+// network and address using the connection c that is already
+// connected to the SOCKS server.
+//
+// It returns the connection's local address assigned by the SOCKS
+// server.
+func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if ctx == nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
+ }
+ a, err := d.connect(ctx, c, address)
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ return a, nil
+}
+
// Dial connects to the provided address on the provided network.
//
-// Deprecated: Use DialContext instead.
+// Unlike DialContext, it returns a raw transport connection instead
+// of a forward proxy connection.
+//
+// Deprecated: Use DialContext or DialWithConn instead.
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
- return d.DialContext(context.Background(), network, address)
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ var err error
+ var c net.Conn
+ if d.ProxyDial != nil {
+ c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
+ } else {
+ c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
+ }
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
+ return nil, err
+ }
+ return c, nil
+}
+
+func (d *Dialer) validateTarget(network, address string) error {
+ switch network {
+ case "tcp", "tcp6", "tcp4":
+ default:
+ return errors.New("network not implemented")
+ }
+ switch d.cmd {
+ case CmdConnect, cmdBind:
+ default:
+ return errors.New("command not implemented")
+ }
+ return nil
}
func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {