internal/{socks,sockstest}: new packages

This change factors out the code related to SOCKS protocol version 5
from the golang/x/net/proxy package and provides new SOCKS-specific
API to fix the following:
- inflexbility of forward proxy connection setup; e.g., no support for
  context-based deadline or canceling, no support for dial deadline,
  no support for working with external authentication mechanisms,
- useless error values for troubleshooting.

The new package socks is supposed to be used by the net/http package
of standard library and proxy package of golang.org/x/net repository.

Fixes golang/go#11682.
Updates golang/go#17759.
Updates golang/go#19354.
Updates golang/go#19688.
Fixes golang/go#21333.

Change-Id: I24098ac8522dcbdceb03d534147c5101ec9e7350
Reviewed-on: https://go-review.googlesource.com/38278
Run-TryBot: Mikio Hara <mikioh.mikioh@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/internal/socks/client.go b/internal/socks/client.go
new file mode 100644
index 0000000..3d6f516
--- /dev/null
+++ b/internal/socks/client.go
@@ -0,0 +1,168 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socks
+
+import (
+	"context"
+	"errors"
+	"io"
+	"net"
+	"strconv"
+	"time"
+)
+
+var (
+	noDeadline   = time.Time{}
+	aLongTimeAgo = time.Unix(1, 0)
+)
+
+func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
+	host, port, err := splitHostPort(address)
+	if err != nil {
+		return nil, err
+	}
+	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
+		c.SetDeadline(deadline)
+		defer c.SetDeadline(noDeadline)
+	}
+	if ctx != context.Background() {
+		errCh := make(chan error, 1)
+		done := make(chan struct{})
+		defer func() {
+			close(done)
+			if ctxErr == nil {
+				ctxErr = <-errCh
+			}
+		}()
+		go func() {
+			select {
+			case <-ctx.Done():
+				c.SetDeadline(aLongTimeAgo)
+				errCh <- ctx.Err()
+			case <-done:
+				errCh <- nil
+			}
+		}()
+	}
+
+	b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
+	b = append(b, Version5)
+	if len(d.AuthMethods) == 0 || d.Authenticate == nil {
+		b = append(b, 1, byte(AuthMethodNotRequired))
+	} else {
+		ams := d.AuthMethods
+		if len(ams) > 255 {
+			return nil, errors.New("too many authentication methods")
+		}
+		b = append(b, byte(len(ams)))
+		for _, am := range ams {
+			b = append(b, byte(am))
+		}
+	}
+	if _, ctxErr = c.Write(b); ctxErr != nil {
+		return
+	}
+
+	if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
+		return
+	}
+	if b[0] != Version5 {
+		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
+	}
+	am := AuthMethod(b[1])
+	if am == AuthMethodNoAcceptableMethods {
+		return nil, errors.New("no acceptable authentication methods")
+	}
+	if d.Authenticate != nil {
+		if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
+			return
+		}
+	}
+
+	b = b[:0]
+	b = append(b, Version5, byte(d.cmd), 0)
+	if ip := net.ParseIP(host); ip != nil {
+		if ip4 := ip.To4(); ip4 != nil {
+			b = append(b, AddrTypeIPv4)
+			b = append(b, ip4...)
+		} else if ip6 := ip.To16(); ip6 != nil {
+			b = append(b, AddrTypeIPv6)
+			b = append(b, ip6...)
+		} else {
+			return nil, errors.New("unknown address type")
+		}
+	} else {
+		if len(host) > 255 {
+			return nil, errors.New("FQDN too long")
+		}
+		b = append(b, AddrTypeFQDN)
+		b = append(b, byte(len(host)))
+		b = append(b, host...)
+	}
+	b = append(b, byte(port>>8), byte(port))
+	if _, ctxErr = c.Write(b); ctxErr != nil {
+		return
+	}
+
+	if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
+		return
+	}
+	if b[0] != Version5 {
+		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
+	}
+	if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
+		return nil, errors.New("unknown error " + cmdErr.String())
+	}
+	if b[2] != 0 {
+		return nil, errors.New("non-zero reserved field")
+	}
+	l := 2
+	var a Addr
+	switch b[3] {
+	case AddrTypeIPv4:
+		l += net.IPv4len
+		a.IP = make(net.IP, net.IPv4len)
+	case AddrTypeIPv6:
+		l += net.IPv6len
+		a.IP = make(net.IP, net.IPv6len)
+	case AddrTypeFQDN:
+		if _, err := io.ReadFull(c, b[:1]); err != nil {
+			return nil, err
+		}
+		l += int(b[0])
+	default:
+		return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
+	}
+	if cap(b) < l {
+		b = make([]byte, l)
+	} else {
+		b = b[:l]
+	}
+	if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
+		return
+	}
+	if a.IP != nil {
+		copy(a.IP, b)
+	} else {
+		a.Name = string(b[:len(b)-2])
+	}
+	a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
+	return &a, nil
+}
+
+func splitHostPort(address string) (string, int, error) {
+	host, port, err := net.SplitHostPort(address)
+	if err != nil {
+		return "", 0, err
+	}
+	portnum, err := strconv.Atoi(port)
+	if err != nil {
+		return "", 0, err
+	}
+	if 1 > portnum || portnum > 0xffff {
+		return "", 0, errors.New("port number out of range " + port)
+	}
+	return host, portnum, nil
+}
diff --git a/internal/socks/dial_test.go b/internal/socks/dial_test.go
new file mode 100644
index 0000000..93101a6
--- /dev/null
+++ b/internal/socks/dial_test.go
@@ -0,0 +1,158 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socks_test
+
+import (
+	"context"
+	"io"
+	"math/rand"
+	"net"
+	"os"
+	"testing"
+	"time"
+
+	"golang.org/x/net/internal/socks"
+	"golang.org/x/net/internal/sockstest"
+)
+
+const (
+	targetNetwork  = "tcp6"
+	targetHostname = "fqdn.doesnotexist"
+	targetHostIP   = "2001:db8::1"
+	targetPort     = "5963"
+)
+
+func TestDial(t *testing.T) {
+	t.Run("Connect", func(t *testing.T) {
+		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		defer ss.Close()
+		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
+		d.AuthMethods = []socks.AuthMethod{
+			socks.AuthMethodNotRequired,
+			socks.AuthMethodUsernamePassword,
+		}
+		d.Authenticate = (&socks.UsernamePassword{
+			Username: "username",
+			Password: "password",
+		}).Authenticate
+		c, err := d.Dial(targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
+		if err == nil {
+			c.(*socks.Conn).BoundAddr()
+			c.Close()
+		}
+		if err != nil {
+			t.Error(err)
+			return
+		}
+	})
+	t.Run("Cancel", func(t *testing.T) {
+		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		defer ss.Close()
+		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
+		ctx, cancel := context.WithCancel(context.Background())
+		defer cancel()
+		dialErr := make(chan error)
+		go func() {
+			c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
+			if err == nil {
+				c.Close()
+			}
+			dialErr <- err
+		}()
+		time.Sleep(100 * time.Millisecond)
+		cancel()
+		err = <-dialErr
+		if perr, nerr := parseDialError(err); perr != context.Canceled && nerr == nil {
+			t.Errorf("got %v; want context.Canceled or equivalent", err)
+			return
+		}
+	})
+	t.Run("Deadline", func(t *testing.T) {
+		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, blackholeCmdFunc)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		defer ss.Close()
+		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
+		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
+		defer cancel()
+		c, err := d.DialContext(ctx, ss.TargetAddr().Network(), net.JoinHostPort(targetHostname, targetPort))
+		if err == nil {
+			c.Close()
+		}
+		if perr, nerr := parseDialError(err); perr != context.DeadlineExceeded && nerr == nil {
+			t.Errorf("got %v; want context.DeadlineExceeded or equivalent", err)
+			return
+		}
+	})
+	t.Run("WithRogueServer", func(t *testing.T) {
+		ss, err := sockstest.NewServer(sockstest.NoAuthRequired, rogueCmdFunc)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		defer ss.Close()
+		d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
+		for i := 0; i < 2*len(rogueCmdList); i++ {
+			ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
+			defer cancel()
+			c, err := d.DialContext(ctx, targetNetwork, net.JoinHostPort(targetHostIP, targetPort))
+			if err == nil {
+				t.Log(c.(*socks.Conn).BoundAddr())
+				c.Close()
+				t.Error("should fail")
+			}
+		}
+	})
+}
+
+func blackholeCmdFunc(rw io.ReadWriter, b []byte) error {
+	if _, err := sockstest.ParseCmdRequest(b); err != nil {
+		return err
+	}
+	var bb [1]byte
+	for {
+		if _, err := rw.Read(bb[:]); err != nil {
+			return err
+		}
+	}
+}
+
+func rogueCmdFunc(rw io.ReadWriter, b []byte) error {
+	if _, err := sockstest.ParseCmdRequest(b); err != nil {
+		return err
+	}
+	rw.Write(rogueCmdList[rand.Intn(len(rogueCmdList))])
+	return nil
+}
+
+var rogueCmdList = [][]byte{
+	{0x05},
+	{0x06, 0x00, 0x00, 0x01, 192, 0, 2, 1, 0x17, 0x4b},
+	{0x05, 0x00, 0xff, 0x01, 192, 0, 2, 2, 0x17, 0x4b},
+	{0x05, 0x00, 0x00, 0x01, 192, 0, 2, 3},
+	{0x05, 0x00, 0x00, 0x03, 0x04, 'F', 'Q', 'D', 'N'},
+}
+
+func parseDialError(err error) (perr, nerr error) {
+	if e, ok := err.(*net.OpError); ok {
+		err = e.Err
+		nerr = e
+	}
+	if e, ok := err.(*os.SyscallError); ok {
+		err = e.Err
+	}
+	perr = err
+	return
+}
diff --git a/internal/socks/socks.go b/internal/socks/socks.go
new file mode 100644
index 0000000..9158595
--- /dev/null
+++ b/internal/socks/socks.go
@@ -0,0 +1,265 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package socks provides a SOCKS version 5 client implementation.
+//
+// SOCKS protocol version 5 is defined in RFC 1928.
+// Username/Password authentication for SOCKS version 5 is defined in
+// RFC 1929.
+package socks
+
+import (
+	"context"
+	"errors"
+	"io"
+	"net"
+	"strconv"
+)
+
+// A Command represents a SOCKS command.
+type Command int
+
+func (cmd Command) String() string {
+	switch cmd {
+	case CmdConnect:
+		return "socks connect"
+	case cmdBind:
+		return "socks bind"
+	default:
+		return "socks " + strconv.Itoa(int(cmd))
+	}
+}
+
+// An AuthMethod represents a SOCKS authentication method.
+type AuthMethod int
+
+// A Reply represents a SOCKS command reply code.
+type Reply int
+
+func (code Reply) String() string {
+	switch code {
+	case StatusSucceeded:
+		return "succeeded"
+	case 0x01:
+		return "general SOCKS server failure"
+	case 0x02:
+		return "connection not allowed by ruleset"
+	case 0x03:
+		return "network unreachable"
+	case 0x04:
+		return "host unreachable"
+	case 0x05:
+		return "connection refused"
+	case 0x06:
+		return "TTL expired"
+	case 0x07:
+		return "command not supported"
+	case 0x08:
+		return "address type not supported"
+	default:
+		return "unknown code: " + strconv.Itoa(int(code))
+	}
+}
+
+// Wire protocol constants.
+const (
+	Version5 = 0x05
+
+	AddrTypeIPv4 = 0x01
+	AddrTypeFQDN = 0x03
+	AddrTypeIPv6 = 0x04
+
+	CmdConnect Command = 0x01 // establishes an active-open forward proxy connection
+	cmdBind    Command = 0x02 // establishes a passive-open forward proxy connection
+
+	AuthMethodNotRequired         AuthMethod = 0x00 // no authentication required
+	AuthMethodUsernamePassword    AuthMethod = 0x02 // use username/password
+	AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authetication methods
+
+	StatusSucceeded Reply = 0x00
+)
+
+// An Addr represents a SOCKS-specific address.
+// Either Name or IP is used exclusively.
+type Addr struct {
+	Name string // fully-qualified domain name
+	IP   net.IP
+	Port int
+}
+
+func (a *Addr) Network() string { return "socks" }
+
+func (a *Addr) String() string {
+	if a == nil {
+		return "<nil>"
+	}
+	port := strconv.Itoa(a.Port)
+	if a.IP == nil {
+		return net.JoinHostPort(a.Name, port)
+	}
+	return net.JoinHostPort(a.IP.String(), port)
+}
+
+// A Conn represents a forward proxy connection.
+type Conn struct {
+	net.Conn
+
+	boundAddr net.Addr
+}
+
+// BoundAddr returns the address assigned by the proxy server for
+// connecting to the command target address from the proxy server.
+func (c *Conn) BoundAddr() net.Addr {
+	if c == nil {
+		return nil
+	}
+	return c.boundAddr
+}
+
+// A Dialer holds SOCKS-specific options.
+type Dialer struct {
+	cmd          Command // either CmdConnect or cmdBind
+	proxyNetwork string  // network between a proxy server and a client
+	proxyAddress string  // proxy server address
+
+	// ProxyDial specifies the optional dial function for
+	// establishing the transport connection.
+	ProxyDial func(context.Context, string, string) (net.Conn, error)
+
+	// AuthMethods specifies the list of request authention
+	// methods.
+	// If empty, SOCKS client requests only AuthMethodNotRequired.
+	AuthMethods []AuthMethod
+
+	// Authenticate specifies the optional authentication
+	// function. It must be non-nil when AuthMethods is not empty.
+	// It must return an error when the authentication is failed.
+	Authenticate func(context.Context, io.ReadWriter, AuthMethod) error
+}
+
+// DialContext connects to the provided address on the provided
+// network.
+//
+// The returned error value may be a net.OpError. When the Op field of
+// net.OpError contains "socks", the Source field contains a proxy
+// server address and the Addr field contains a command target
+// address.
+//
+// See func Dial of the net package of standard library for a
+// description of the network and address parameters.
+func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	switch network {
+	case "tcp", "tcp6", "tcp4":
+	default:
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("network not implemented")}
+	}
+	switch d.cmd {
+	case CmdConnect, cmdBind:
+	default:
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("command not implemented")}
+	}
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	var err error
+	var c net.Conn
+	if d.ProxyDial != nil {
+		c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
+	} else {
+		var dd net.Dialer
+		c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
+	}
+	if err != nil {
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	a, err := d.connect(ctx, c, address)
+	if err != nil {
+		c.Close()
+		proxy, dst, _ := d.pathAddrs(address)
+		return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+	}
+	return &Conn{Conn: c, boundAddr: a}, nil
+}
+
+// Dial connects to the provided address on the provided network.
+//
+// Deprecated: Use DialContext instead.
+func (d *Dialer) Dial(network, address string) (net.Conn, error) {
+	return d.DialContext(context.Background(), network, address)
+}
+
+func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
+	for i, s := range []string{d.proxyAddress, address} {
+		host, port, err := splitHostPort(s)
+		if err != nil {
+			return nil, nil, err
+		}
+		a := &Addr{Port: port}
+		a.IP = net.ParseIP(host)
+		if a.IP == nil {
+			a.Name = host
+		}
+		if i == 0 {
+			proxy = a
+		} else {
+			dst = a
+		}
+	}
+	return
+}
+
+// NewDialer returns a new Dialer that dials through the provided
+// proxy server's network and address.
+func NewDialer(network, address string) *Dialer {
+	return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect}
+}
+
+const (
+	authUsernamePasswordVersion = 0x01
+	authStatusSucceeded         = 0x00
+)
+
+// UsernamePassword are the credentials for the username/password
+// authentication method.
+type UsernamePassword struct {
+	Username string
+	Password string
+}
+
+// Authenticate authenticates a pair of username and password with the
+// proxy server.
+func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error {
+	switch auth {
+	case AuthMethodNotRequired:
+		return nil
+	case AuthMethodUsernamePassword:
+		if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 {
+			return errors.New("invalid username/password")
+		}
+		b := []byte{authUsernamePasswordVersion}
+		b = append(b, byte(len(up.Username)))
+		b = append(b, up.Username...)
+		b = append(b, byte(len(up.Password)))
+		b = append(b, up.Password...)
+		// TODO(mikio): handle IO deadlines and cancelation if
+		// necessary
+		if _, err := rw.Write(b); err != nil {
+			return err
+		}
+		if _, err := io.ReadFull(rw, b[:2]); err != nil {
+			return err
+		}
+		if b[0] != authUsernamePasswordVersion {
+			return errors.New("invalid username/password version")
+		}
+		if b[1] != authStatusSucceeded {
+			return errors.New("username/password authentication failed")
+		}
+		return nil
+	}
+	return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
+}
diff --git a/internal/sockstest/server.go b/internal/sockstest/server.go
new file mode 100644
index 0000000..3c6e9e9
--- /dev/null
+++ b/internal/sockstest/server.go
@@ -0,0 +1,241 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package sockstest provides utilities for SOCKS testing.
+package sockstest
+
+import (
+	"errors"
+	"io"
+	"net"
+
+	"golang.org/x/net/internal/nettest"
+	"golang.org/x/net/internal/socks"
+)
+
+// An AuthRequest represents an authentication request.
+type AuthRequest struct {
+	Version int
+	Methods []socks.AuthMethod
+}
+
+// ParseAuthRequest parses an authentication request.
+func ParseAuthRequest(b []byte) (*AuthRequest, error) {
+	if len(b) < 2 {
+		return nil, errors.New("short auth request")
+	}
+	if b[0] != socks.Version5 {
+		return nil, errors.New("unexpected protocol version")
+	}
+	if len(b)-2 < int(b[1]) {
+		return nil, errors.New("short auth request")
+	}
+	req := &AuthRequest{Version: int(b[0])}
+	if b[1] > 0 {
+		req.Methods = make([]socks.AuthMethod, b[1])
+		for i, m := range b[2 : 2+b[1]] {
+			req.Methods[i] = socks.AuthMethod(m)
+		}
+	}
+	return req, nil
+}
+
+// MarshalAuthReply returns an authentication reply in wire format.
+func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) {
+	return []byte{byte(ver), byte(m)}, nil
+}
+
+// A CmdRequest repesents a command request.
+type CmdRequest struct {
+	Version int
+	Cmd     socks.Command
+	Addr    socks.Addr
+}
+
+// ParseCmdRequest parses a command request.
+func ParseCmdRequest(b []byte) (*CmdRequest, error) {
+	if len(b) < 7 {
+		return nil, errors.New("short cmd request")
+	}
+	if b[0] != socks.Version5 {
+		return nil, errors.New("unexpected protocol version")
+	}
+	if socks.Command(b[1]) != socks.CmdConnect {
+		return nil, errors.New("unexpected command")
+	}
+	if b[2] != 0 {
+		return nil, errors.New("non-zero reserved field")
+	}
+	req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])}
+	l := 2
+	off := 4
+	switch b[3] {
+	case socks.AddrTypeIPv4:
+		l += net.IPv4len
+		req.Addr.IP = make(net.IP, net.IPv4len)
+	case socks.AddrTypeIPv6:
+		l += net.IPv6len
+		req.Addr.IP = make(net.IP, net.IPv6len)
+	case socks.AddrTypeFQDN:
+		l += int(b[4])
+		off = 5
+	default:
+		return nil, errors.New("unknown address type")
+	}
+	if len(b[off:]) < l {
+		return nil, errors.New("short cmd request")
+	}
+	if req.Addr.IP != nil {
+		copy(req.Addr.IP, b[off:])
+	} else {
+		req.Addr.Name = string(b[off : off+l-2])
+	}
+	req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1])
+	return req, nil
+}
+
+// MarshalCmdReply returns a command reply in wire format.
+func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) {
+	b := make([]byte, 4)
+	b[0] = byte(ver)
+	b[1] = byte(reply)
+	if a.Name != "" {
+		if len(a.Name) > 255 {
+			return nil, errors.New("fqdn too long")
+		}
+		b[3] = socks.AddrTypeFQDN
+		b = append(b, byte(len(a.Name)))
+		b = append(b, a.Name...)
+	} else if ip4 := a.IP.To4(); ip4 != nil {
+		b[3] = socks.AddrTypeIPv4
+		b = append(b, ip4...)
+	} else if ip6 := a.IP.To16(); ip6 != nil {
+		b[3] = socks.AddrTypeIPv6
+		b = append(b, ip6...)
+	} else {
+		return nil, errors.New("unknown address type")
+	}
+	b = append(b, byte(a.Port>>8), byte(a.Port))
+	return b, nil
+}
+
+// A Server repesents a server for handshake testing.
+type Server struct {
+	ln net.Listener
+}
+
+// Addr rerurns a server address.
+func (s *Server) Addr() net.Addr {
+	return s.ln.Addr()
+}
+
+// TargetAddr returns a fake final destination address.
+//
+// The returned address is only valid for testing with Server.
+func (s *Server) TargetAddr() net.Addr {
+	a := s.ln.Addr()
+	switch a := a.(type) {
+	case *net.TCPAddr:
+		if a.IP.To4() != nil {
+			return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963}
+		}
+		if a.IP.To16() != nil && a.IP.To4() == nil {
+			return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963}
+		}
+	}
+	return nil
+}
+
+// Close closes the server.
+func (s *Server) Close() error {
+	return s.ln.Close()
+}
+
+func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) {
+	c, err := s.ln.Accept()
+	if err != nil {
+		return
+	}
+	defer c.Close()
+	go s.serve(authFunc, cmdFunc)
+	b := make([]byte, 512)
+	n, err := c.Read(b)
+	if err != nil {
+		return
+	}
+	if err := authFunc(c, b[:n]); err != nil {
+		return
+	}
+	n, err = c.Read(b)
+	if err != nil {
+		return
+	}
+	if err := cmdFunc(c, b[:n]); err != nil {
+		return
+	}
+}
+
+// NewServer returns a new server.
+//
+// The provided authFunc and cmdFunc must parse requests and return
+// appropriate replies to clients.
+func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) {
+	var err error
+	s := new(Server)
+	s.ln, err = nettest.NewLocalListener("tcp")
+	if err != nil {
+		return nil, err
+	}
+	go s.serve(authFunc, cmdFunc)
+	return s, nil
+}
+
+// NoAuthRequired handles a no-authentication-required signaling.
+func NoAuthRequired(rw io.ReadWriter, b []byte) error {
+	req, err := ParseAuthRequest(b)
+	if err != nil {
+		return err
+	}
+	b, err = MarshalAuthReply(req.Version, socks.AuthMethodNotRequired)
+	if err != nil {
+		return err
+	}
+	n, err := rw.Write(b)
+	if err != nil {
+		return err
+	}
+	if n != len(b) {
+		return errors.New("short write")
+	}
+	return nil
+}
+
+// NoProxyRequired handles a command signaling without constructing a
+// proxy connection to the final destination.
+func NoProxyRequired(rw io.ReadWriter, b []byte) error {
+	req, err := ParseCmdRequest(b)
+	if err != nil {
+		return err
+	}
+	req.Addr.Port += 1
+	if req.Addr.Name != "" {
+		req.Addr.Name = "boundaddr.doesnotexist"
+	} else if req.Addr.IP.To4() != nil {
+		req.Addr.IP = net.IPv4(127, 0, 0, 1)
+	} else {
+		req.Addr.IP = net.IPv6loopback
+	}
+	b, err = MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &req.Addr)
+	if err != nil {
+		return err
+	}
+	n, err := rw.Write(b)
+	if err != nil {
+		return err
+	}
+	if n != len(b) {
+		return errors.New("short write")
+	}
+	return nil
+}
diff --git a/internal/sockstest/server_test.go b/internal/sockstest/server_test.go
new file mode 100644
index 0000000..2b02d81
--- /dev/null
+++ b/internal/sockstest/server_test.go
@@ -0,0 +1,103 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sockstest
+
+import (
+	"net"
+	"reflect"
+	"testing"
+
+	"golang.org/x/net/internal/socks"
+)
+
+func TestParseAuthRequest(t *testing.T) {
+	for i, tt := range []struct {
+		wire []byte
+		req  *AuthRequest
+	}{
+		{
+			[]byte{0x05, 0x00},
+			&AuthRequest{
+				socks.Version5,
+				nil,
+			},
+		},
+		{
+			[]byte{0x05, 0x01, 0xff},
+			&AuthRequest{
+				socks.Version5,
+				[]socks.AuthMethod{
+					socks.AuthMethodNoAcceptableMethods,
+				},
+			},
+		},
+		{
+			[]byte{0x05, 0x02, 0x00, 0xff},
+			&AuthRequest{
+				socks.Version5,
+				[]socks.AuthMethod{
+					socks.AuthMethodNotRequired,
+					socks.AuthMethodNoAcceptableMethods,
+				},
+			},
+		},
+
+		// corrupted requests
+		{nil, nil},
+		{[]byte{0x00, 0x01}, nil},
+		{[]byte{0x06, 0x00}, nil},
+		{[]byte{0x05, 0x02, 0x00}, nil},
+	} {
+		req, err := ParseAuthRequest(tt.wire)
+		if !reflect.DeepEqual(req, tt.req) {
+			t.Errorf("#%d: got %v, %v; want %v", i, req, err, tt.req)
+			continue
+		}
+	}
+}
+
+func TestParseCmdRequest(t *testing.T) {
+	for i, tt := range []struct {
+		wire []byte
+		req  *CmdRequest
+	}{
+		{
+			[]byte{0x05, 0x01, 0x00, 0x01, 192, 0, 2, 1, 0x17, 0x4b},
+			&CmdRequest{
+				socks.Version5,
+				socks.CmdConnect,
+				socks.Addr{
+					IP:   net.IP{192, 0, 2, 1},
+					Port: 5963,
+				},
+			},
+		},
+		{
+			[]byte{0x05, 0x01, 0x00, 0x03, 0x04, 'F', 'Q', 'D', 'N', 0x17, 0x4b},
+			&CmdRequest{
+				socks.Version5,
+				socks.CmdConnect,
+				socks.Addr{
+					Name: "FQDN",
+					Port: 5963,
+				},
+			},
+		},
+
+		// corrupted requests
+		{nil, nil},
+		{[]byte{0x05}, nil},
+		{[]byte{0x06, 0x01, 0x00, 0x01, 192, 0, 2, 2, 0x17, 0x4b}, nil},
+		{[]byte{0x05, 0x01, 0xff, 0x01, 192, 0, 2, 3}, nil},
+		{[]byte{0x05, 0x01, 0x00, 0x01, 192, 0, 2, 4}, nil},
+		{[]byte{0x05, 0x01, 0x00, 0x03, 0x04, 'F', 'Q', 'D', 'N'}, nil},
+	} {
+		req, err := ParseCmdRequest(tt.wire)
+		if !reflect.DeepEqual(req, tt.req) {
+			t.Errorf("#%d: got %v, %v; want %v", i, req, err, tt.req)
+			continue
+		}
+	}
+}
diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go
index 0f31e21..0be1b42 100644
--- a/proxy/proxy_test.go
+++ b/proxy/proxy_test.go
@@ -7,14 +7,12 @@
 import (
 	"bytes"
 	"fmt"
-	"io"
-	"net"
 	"net/url"
 	"os"
-	"strconv"
 	"strings"
-	"sync"
 	"testing"
+
+	"golang.org/x/net/internal/sockstest"
 )
 
 type proxyFromEnvTest struct {
@@ -73,131 +71,41 @@
 }
 
 func TestFromURL(t *testing.T) {
-	endSystem, err := net.Listen("tcp", "127.0.0.1:0")
+	ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
 	if err != nil {
-		t.Fatalf("net.Listen failed: %v", err)
+		t.Fatal(err)
 	}
-	defer endSystem.Close()
-	gateway, err := net.Listen("tcp", "127.0.0.1:0")
+	defer ss.Close()
+	url, err := url.Parse("socks5://user:password@" + ss.Addr().String())
 	if err != nil {
-		t.Fatalf("net.Listen failed: %v", err)
+		t.Fatal(err)
 	}
-	defer gateway.Close()
-
-	var wg sync.WaitGroup
-	wg.Add(1)
-	go socks5Gateway(t, gateway, endSystem, socks5Domain, &wg)
-
-	url, err := url.Parse("socks5://user:password@" + gateway.Addr().String())
+	proxy, err := FromURL(url, nil)
 	if err != nil {
-		t.Fatalf("url.Parse failed: %v", err)
+		t.Fatal(err)
 	}
-	proxy, err := FromURL(url, Direct)
+	c, err := proxy.Dial("tcp", "fqdn.doesnotexist:5963")
 	if err != nil {
-		t.Fatalf("FromURL failed: %v", err)
+		t.Fatal(err)
 	}
-	_, port, err := net.SplitHostPort(endSystem.Addr().String())
-	if err != nil {
-		t.Fatalf("net.SplitHostPort failed: %v", err)
-	}
-	if c, err := proxy.Dial("tcp", "localhost:"+port); err != nil {
-		t.Fatalf("FromURL.Dial failed: %v", err)
-	} else {
-		c.Close()
-	}
-
-	wg.Wait()
+	c.Close()
 }
 
 func TestSOCKS5(t *testing.T) {
-	endSystem, err := net.Listen("tcp", "127.0.0.1:0")
+	ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
 	if err != nil {
-		t.Fatalf("net.Listen failed: %v", err)
+		t.Fatal(err)
 	}
-	defer endSystem.Close()
-	gateway, err := net.Listen("tcp", "127.0.0.1:0")
+	defer ss.Close()
+	proxy, err := SOCKS5("tcp", ss.Addr().String(), nil, nil)
 	if err != nil {
-		t.Fatalf("net.Listen failed: %v", err)
+		t.Fatal(err)
 	}
-	defer gateway.Close()
-
-	var wg sync.WaitGroup
-	wg.Add(1)
-	go socks5Gateway(t, gateway, endSystem, socks5IP4, &wg)
-
-	proxy, err := SOCKS5("tcp", gateway.Addr().String(), nil, Direct)
+	c, err := proxy.Dial("tcp", ss.TargetAddr().String())
 	if err != nil {
-		t.Fatalf("SOCKS5 failed: %v", err)
+		t.Fatal(err)
 	}
-	if c, err := proxy.Dial("tcp", endSystem.Addr().String()); err != nil {
-		t.Fatalf("SOCKS5.Dial failed: %v", err)
-	} else {
-		c.Close()
-	}
-
-	wg.Wait()
-}
-
-func socks5Gateway(t *testing.T, gateway, endSystem net.Listener, typ byte, wg *sync.WaitGroup) {
-	defer wg.Done()
-
-	c, err := gateway.Accept()
-	if err != nil {
-		t.Errorf("net.Listener.Accept failed: %v", err)
-		return
-	}
-	defer c.Close()
-
-	b := make([]byte, 32)
-	var n int
-	if typ == socks5Domain {
-		n = 4
-	} else {
-		n = 3
-	}
-	if _, err := io.ReadFull(c, b[:n]); err != nil {
-		t.Errorf("io.ReadFull failed: %v", err)
-		return
-	}
-	if _, err := c.Write([]byte{socks5Version, socks5AuthNone}); err != nil {
-		t.Errorf("net.Conn.Write failed: %v", err)
-		return
-	}
-	if typ == socks5Domain {
-		n = 16
-	} else {
-		n = 10
-	}
-	if _, err := io.ReadFull(c, b[:n]); err != nil {
-		t.Errorf("io.ReadFull failed: %v", err)
-		return
-	}
-	if b[0] != socks5Version || b[1] != socks5Connect || b[2] != 0x00 || b[3] != typ {
-		t.Errorf("got an unexpected packet: %#02x %#02x %#02x %#02x", b[0], b[1], b[2], b[3])
-		return
-	}
-	if typ == socks5Domain {
-		copy(b[:5], []byte{socks5Version, 0x00, 0x00, socks5Domain, 9})
-		b = append(b, []byte("localhost")...)
-	} else {
-		copy(b[:4], []byte{socks5Version, 0x00, 0x00, socks5IP4})
-	}
-	host, port, err := net.SplitHostPort(endSystem.Addr().String())
-	if err != nil {
-		t.Errorf("net.SplitHostPort failed: %v", err)
-		return
-	}
-	b = append(b, []byte(net.ParseIP(host).To4())...)
-	p, err := strconv.Atoi(port)
-	if err != nil {
-		t.Errorf("strconv.Atoi failed: %v", err)
-		return
-	}
-	b = append(b, []byte{byte(p >> 8), byte(p)}...)
-	if _, err := c.Write(b); err != nil {
-		t.Errorf("net.Conn.Write failed: %v", err)
-		return
-	}
+	c.Close()
 }
 
 func ResetProxyEnv() {
diff --git a/proxy/socks5.go b/proxy/socks5.go
index 3fed38e..56345ec 100644
--- a/proxy/socks5.go
+++ b/proxy/socks5.go
@@ -5,210 +5,32 @@
 package proxy
 
 import (
-	"errors"
-	"io"
+	"context"
 	"net"
-	"strconv"
+
+	"golang.org/x/net/internal/socks"
 )
 
-// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
-// with an optional username and password. See RFC 1928 and RFC 1929.
-func SOCKS5(network, addr string, auth *Auth, forward Dialer) (Dialer, error) {
-	s := &socks5{
-		network: network,
-		addr:    addr,
-		forward: forward,
+// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given
+// address with an optional username and password.
+// See RFC 1928 and RFC 1929.
+func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) {
+	d := socks.NewDialer(network, address)
+	if forward != nil {
+		d.ProxyDial = func(_ context.Context, network string, address string) (net.Conn, error) {
+			return forward.Dial(network, address)
+		}
 	}
 	if auth != nil {
-		s.user = auth.User
-		s.password = auth.Password
-	}
-
-	return s, nil
-}
-
-type socks5 struct {
-	user, password string
-	network, addr  string
-	forward        Dialer
-}
-
-const socks5Version = 5
-
-const (
-	socks5AuthNone     = 0
-	socks5AuthPassword = 2
-)
-
-const socks5Connect = 1
-
-const (
-	socks5IP4    = 1
-	socks5Domain = 3
-	socks5IP6    = 4
-)
-
-var socks5Errors = []string{
-	"",
-	"general failure",
-	"connection forbidden",
-	"network unreachable",
-	"host unreachable",
-	"connection refused",
-	"TTL expired",
-	"command not supported",
-	"address type not supported",
-}
-
-// Dial connects to the address addr on the given network via the SOCKS5 proxy.
-func (s *socks5) Dial(network, addr string) (net.Conn, error) {
-	switch network {
-	case "tcp", "tcp6", "tcp4":
-	default:
-		return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
-	}
-
-	conn, err := s.forward.Dial(s.network, s.addr)
-	if err != nil {
-		return nil, err
-	}
-	if err := s.connect(conn, addr); err != nil {
-		conn.Close()
-		return nil, err
-	}
-	return conn, nil
-}
-
-// connect takes an existing connection to a socks5 proxy server,
-// and commands the server to extend that connection to target,
-// which must be a canonical address with a host and port.
-func (s *socks5) connect(conn net.Conn, target string) error {
-	host, portStr, err := net.SplitHostPort(target)
-	if err != nil {
-		return err
-	}
-
-	port, err := strconv.Atoi(portStr)
-	if err != nil {
-		return errors.New("proxy: failed to parse port number: " + portStr)
-	}
-	if port < 1 || port > 0xffff {
-		return errors.New("proxy: port number out of range: " + portStr)
-	}
-
-	// the size here is just an estimate
-	buf := make([]byte, 0, 6+len(host))
-
-	buf = append(buf, socks5Version)
-	if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
-		buf = append(buf, 2 /* num auth methods */, socks5AuthNone, socks5AuthPassword)
-	} else {
-		buf = append(buf, 1 /* num auth methods */, socks5AuthNone)
-	}
-
-	if _, err := conn.Write(buf); err != nil {
-		return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
-	}
-
-	if _, err := io.ReadFull(conn, buf[:2]); err != nil {
-		return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
-	}
-	if buf[0] != 5 {
-		return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
-	}
-	if buf[1] == 0xff {
-		return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
-	}
-
-	// See RFC 1929
-	if buf[1] == socks5AuthPassword {
-		buf = buf[:0]
-		buf = append(buf, 1 /* password protocol version */)
-		buf = append(buf, uint8(len(s.user)))
-		buf = append(buf, s.user...)
-		buf = append(buf, uint8(len(s.password)))
-		buf = append(buf, s.password...)
-
-		if _, err := conn.Write(buf); err != nil {
-			return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+		up := socks.UsernamePassword{
+			Username: auth.User,
+			Password: auth.Password,
 		}
-
-		if _, err := io.ReadFull(conn, buf[:2]); err != nil {
-			return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+		d.AuthMethods = []socks.AuthMethod{
+			socks.AuthMethodNotRequired,
+			socks.AuthMethodUsernamePassword,
 		}
-
-		if buf[1] != 0 {
-			return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
-		}
+		d.Authenticate = up.Authenticate
 	}
-
-	buf = buf[:0]
-	buf = append(buf, socks5Version, socks5Connect, 0 /* reserved */)
-
-	if ip := net.ParseIP(host); ip != nil {
-		if ip4 := ip.To4(); ip4 != nil {
-			buf = append(buf, socks5IP4)
-			ip = ip4
-		} else {
-			buf = append(buf, socks5IP6)
-		}
-		buf = append(buf, ip...)
-	} else {
-		if len(host) > 255 {
-			return errors.New("proxy: destination host name too long: " + host)
-		}
-		buf = append(buf, socks5Domain)
-		buf = append(buf, byte(len(host)))
-		buf = append(buf, host...)
-	}
-	buf = append(buf, byte(port>>8), byte(port))
-
-	if _, err := conn.Write(buf); err != nil {
-		return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
-	}
-
-	if _, err := io.ReadFull(conn, buf[:4]); err != nil {
-		return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
-	}
-
-	failure := "unknown error"
-	if int(buf[1]) < len(socks5Errors) {
-		failure = socks5Errors[buf[1]]
-	}
-
-	if len(failure) > 0 {
-		return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
-	}
-
-	bytesToDiscard := 0
-	switch buf[3] {
-	case socks5IP4:
-		bytesToDiscard = net.IPv4len
-	case socks5IP6:
-		bytesToDiscard = net.IPv6len
-	case socks5Domain:
-		_, err := io.ReadFull(conn, buf[:1])
-		if err != nil {
-			return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
-		}
-		bytesToDiscard = int(buf[0])
-	default:
-		return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
-	}
-
-	if cap(buf) < bytesToDiscard {
-		buf = make([]byte, bytesToDiscard)
-	} else {
-		buf = buf[:bytesToDiscard]
-	}
-	if _, err := io.ReadFull(conn, buf); err != nil {
-		return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
-	}
-
-	// Also need to discard the port number
-	if _, err := io.ReadFull(conn, buf[:2]); err != nil {
-		return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
-	}
-
-	return nil
+	return d, nil
 }