proxy: split SOCKS5 Dial method in two

Split off a new SOCKS5 connect() method from Dial.

connect() takes an existing connection to a socks5 server, and
commands the server to extend that connection to a given target
address and port.

Change-Id: I5dbba58a67a0d884bda3d3ac194dc18bdebe74ab
Reviewed-on: https://go-review.googlesource.com/36643
Reviewed-by: Adam Langley <agl@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/proxy/socks5.go b/proxy/socks5.go
index 9b96282..973f57f 100644
--- a/proxy/socks5.go
+++ b/proxy/socks5.go
@@ -72,24 +72,28 @@
 	if err != nil {
 		return nil, err
 	}
-	closeConn := &conn
-	defer func() {
-		if closeConn != nil {
-			(*closeConn).Close()
-		}
-	}()
-
-	host, portStr, err := net.SplitHostPort(addr)
-	if err != nil {
+	if err := s.connect(conn, addr); err != nil {
+		conn.Close()
 		return nil, err
 	}
+	return conn, nil
+}
+
+// connect takes an existing connection to a socks5 proxy server,
+// and commands the server to extend that connection to target,
+// which must be a canonical address with a host and port.
+func (s *socks5) connect(conn net.Conn, target string) error {
+	host, portStr, err := net.SplitHostPort(target)
+	if err != nil {
+		return err
+	}
 
 	port, err := strconv.Atoi(portStr)
 	if err != nil {
-		return nil, errors.New("proxy: failed to parse port number: " + portStr)
+		return errors.New("proxy: failed to parse port number: " + portStr)
 	}
 	if port < 1 || port > 0xffff {
-		return nil, errors.New("proxy: port number out of range: " + portStr)
+		return errors.New("proxy: port number out of range: " + portStr)
 	}
 
 	// the size here is just an estimate
@@ -103,17 +107,17 @@
 	}
 
 	if _, err := conn.Write(buf); err != nil {
-		return nil, errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+		return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
 	}
 
 	if _, err := io.ReadFull(conn, buf[:2]); err != nil {
-		return nil, errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+		return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
 	}
 	if buf[0] != 5 {
-		return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
+		return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
 	}
 	if buf[1] == 0xff {
-		return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
+		return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
 	}
 
 	if buf[1] == socks5AuthPassword {
@@ -125,15 +129,15 @@
 		buf = append(buf, s.password...)
 
 		if _, err := conn.Write(buf); err != nil {
-			return nil, errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+			return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
 		}
 
 		if _, err := io.ReadFull(conn, buf[:2]); err != nil {
-			return nil, errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+			return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
 		}
 
 		if buf[1] != 0 {
-			return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
+			return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
 		}
 	}
 
@@ -150,7 +154,7 @@
 		buf = append(buf, ip...)
 	} else {
 		if len(host) > 255 {
-			return nil, errors.New("proxy: destination hostname too long: " + host)
+			return errors.New("proxy: destination hostname too long: " + host)
 		}
 		buf = append(buf, socks5Domain)
 		buf = append(buf, byte(len(host)))
@@ -159,11 +163,11 @@
 	buf = append(buf, byte(port>>8), byte(port))
 
 	if _, err := conn.Write(buf); err != nil {
-		return nil, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
+		return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
 	}
 
 	if _, err := io.ReadFull(conn, buf[:4]); err != nil {
-		return nil, errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+		return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
 	}
 
 	failure := "unknown error"
@@ -172,7 +176,7 @@
 	}
 
 	if len(failure) > 0 {
-		return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
+		return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
 	}
 
 	bytesToDiscard := 0
@@ -184,11 +188,11 @@
 	case socks5Domain:
 		_, err := io.ReadFull(conn, buf[:1])
 		if err != nil {
-			return nil, errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+			return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
 		}
 		bytesToDiscard = int(buf[0])
 	default:
-		return nil, errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
+		return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
 	}
 
 	if cap(buf) < bytesToDiscard {
@@ -197,14 +201,13 @@
 		buf = buf[:bytesToDiscard]
 	}
 	if _, err := io.ReadFull(conn, buf); err != nil {
-		return nil, errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+		return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
 	}
 
 	// Also need to discard the port number
 	if _, err := io.ReadFull(conn, buf[:2]); err != nil {
-		return nil, errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
+		return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
 	}
 
-	closeConn = nil
-	return conn, nil
+	return nil
 }