| // Copyright 2018 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package socks |
| |
| import ( |
| "context" |
| "errors" |
| "io" |
| "net" |
| "strconv" |
| "time" |
| ) |
| |
| var ( |
| noDeadline = time.Time{} |
| aLongTimeAgo = time.Unix(1, 0) |
| ) |
| |
| func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { |
| host, port, err := splitHostPort(address) |
| if err != nil { |
| return nil, err |
| } |
| if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { |
| c.SetDeadline(deadline) |
| defer c.SetDeadline(noDeadline) |
| } |
| if ctx != context.Background() { |
| errCh := make(chan error, 1) |
| done := make(chan struct{}) |
| defer func() { |
| close(done) |
| if ctxErr == nil { |
| ctxErr = <-errCh |
| } |
| }() |
| go func() { |
| select { |
| case <-ctx.Done(): |
| c.SetDeadline(aLongTimeAgo) |
| errCh <- ctx.Err() |
| case <-done: |
| errCh <- nil |
| } |
| }() |
| } |
| |
| b := make([]byte, 0, 6+len(host)) // the size here is just an estimate |
| b = append(b, Version5) |
| if len(d.AuthMethods) == 0 || d.Authenticate == nil { |
| b = append(b, 1, byte(AuthMethodNotRequired)) |
| } else { |
| ams := d.AuthMethods |
| if len(ams) > 255 { |
| return nil, errors.New("too many authentication methods") |
| } |
| b = append(b, byte(len(ams))) |
| for _, am := range ams { |
| b = append(b, byte(am)) |
| } |
| } |
| if _, ctxErr = c.Write(b); ctxErr != nil { |
| return |
| } |
| |
| if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil { |
| return |
| } |
| if b[0] != Version5 { |
| return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) |
| } |
| am := AuthMethod(b[1]) |
| if am == AuthMethodNoAcceptableMethods { |
| return nil, errors.New("no acceptable authentication methods") |
| } |
| if d.Authenticate != nil { |
| if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { |
| return |
| } |
| } |
| |
| b = b[:0] |
| b = append(b, Version5, byte(d.cmd), 0) |
| if ip := net.ParseIP(host); ip != nil { |
| if ip4 := ip.To4(); ip4 != nil { |
| b = append(b, AddrTypeIPv4) |
| b = append(b, ip4...) |
| } else if ip6 := ip.To16(); ip6 != nil { |
| b = append(b, AddrTypeIPv6) |
| b = append(b, ip6...) |
| } else { |
| return nil, errors.New("unknown address type") |
| } |
| } else { |
| if len(host) > 255 { |
| return nil, errors.New("FQDN too long") |
| } |
| b = append(b, AddrTypeFQDN) |
| b = append(b, byte(len(host))) |
| b = append(b, host...) |
| } |
| b = append(b, byte(port>>8), byte(port)) |
| if _, ctxErr = c.Write(b); ctxErr != nil { |
| return |
| } |
| |
| if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil { |
| return |
| } |
| if b[0] != Version5 { |
| return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) |
| } |
| if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded { |
| return nil, errors.New("unknown error " + cmdErr.String()) |
| } |
| if b[2] != 0 { |
| return nil, errors.New("non-zero reserved field") |
| } |
| l := 2 |
| var a Addr |
| switch b[3] { |
| case AddrTypeIPv4: |
| l += net.IPv4len |
| a.IP = make(net.IP, net.IPv4len) |
| case AddrTypeIPv6: |
| l += net.IPv6len |
| a.IP = make(net.IP, net.IPv6len) |
| case AddrTypeFQDN: |
| if _, err := io.ReadFull(c, b[:1]); err != nil { |
| return nil, err |
| } |
| l += int(b[0]) |
| default: |
| return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) |
| } |
| if cap(b) < l { |
| b = make([]byte, l) |
| } else { |
| b = b[:l] |
| } |
| if _, ctxErr = io.ReadFull(c, b); ctxErr != nil { |
| return |
| } |
| if a.IP != nil { |
| copy(a.IP, b) |
| } else { |
| a.Name = string(b[:len(b)-2]) |
| } |
| a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) |
| return &a, nil |
| } |
| |
| func splitHostPort(address string) (string, int, error) { |
| host, port, err := net.SplitHostPort(address) |
| if err != nil { |
| return "", 0, err |
| } |
| portnum, err := strconv.Atoi(port) |
| if err != nil { |
| return "", 0, err |
| } |
| if 1 > portnum || portnum > 0xffff { |
| return "", 0, errors.New("port number out of range " + port) |
| } |
| return host, portnum, nil |
| } |