// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package sockstest provides utilities for SOCKS testing.
package sockstest

import (
	"errors"
	"io"
	"net"

	"golang.org/x/net/internal/socks"
	"golang.org/x/net/nettest"
)

// An AuthRequest represents an authentication request.
//
// It is the parsed form of the first message a client sends during a
// handshake, as produced by ParseAuthRequest.
type AuthRequest struct {
	// Version is the protocol version taken from the first byte of the
	// request.
	Version int
	// Methods lists the authentication methods carried in the request,
	// in wire order. It is nil when the request carries none.
	Methods []socks.AuthMethod
}

// ParseAuthRequest parses an authentication request.
//
// b must hold a version byte, a method-count byte and at least that
// many method bytes; any bytes following the advertised methods are
// ignored. An error is returned when b is too short or when the version
// byte is not socks.Version5. The Methods field of the result is nil
// when the request advertises no methods.
func ParseAuthRequest(b []byte) (*AuthRequest, error) {
	if len(b) < 2 {
		return nil, errors.New("short auth request")
	}
	if b[0] != socks.Version5 {
		return nil, errors.New("unexpected protocol version")
	}
	// Widen the method count to int before doing arithmetic with it.
	// Computed as a byte, 2+b[1] wraps around for counts of 254 and 255
	// and the slice expression below would panic on an otherwise valid
	// request.
	n := int(b[1])
	if len(b)-2 < n {
		return nil, errors.New("short auth request")
	}
	req := &AuthRequest{Version: int(b[0])}
	if n > 0 {
		req.Methods = make([]socks.AuthMethod, n)
		for i, m := range b[2 : 2+n] {
			req.Methods[i] = socks.AuthMethod(m)
		}
	}
	return req, nil
}

// MarshalAuthReply encodes an authentication reply for the wire.
//
// The result is two bytes long: ver followed by m, each truncated to a
// single byte. The returned error is always nil.
func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) {
	reply := make([]byte, 2)
	reply[0] = byte(ver)
	reply[1] = byte(m)
	return reply, nil
}

// A CmdRequest represents a command request.
//
// It is the parsed form of the message a client sends after
// authentication, as produced by ParseCmdRequest.
type CmdRequest struct {
	// Version is the protocol version taken from the first byte of the
	// request.
	Version int
	// Cmd is the command taken from the second byte of the request.
	Cmd     socks.Command
	// Addr is the address and port carried in the request. Either its
	// IP or its Name is populated, depending on the address type on the
	// wire.
	Addr    socks.Addr
}

// ParseCmdRequest decodes a command request from its wire form.
//
// Only socks.Version5 requests carrying socks.CmdConnect and a zero
// reserved byte are accepted. The address may be IPv4, IPv6 or a
// length-prefixed domain name and is followed by a two-byte big-endian
// port. Bytes after the port are ignored. An error is returned for
// anything shorter than the address type requires or for an address
// type that is not recognized.
func ParseCmdRequest(b []byte) (*CmdRequest, error) {
	switch {
	case len(b) < 7:
		return nil, errors.New("short cmd request")
	case b[0] != socks.Version5:
		return nil, errors.New("unexpected protocol version")
	case socks.Command(b[1]) != socks.CmdConnect:
		return nil, errors.New("unexpected command")
	case b[2] != 0:
		return nil, errors.New("non-zero reserved field")
	}

	// addrStart is the index of the first address byte and addrLen the
	// number of address bytes; the port sits immediately after them.
	addrStart, addrLen := 4, 0
	isName := false
	switch b[3] {
	case socks.AddrTypeIPv4:
		addrLen = net.IPv4len
	case socks.AddrTypeIPv6:
		addrLen = net.IPv6len
	case socks.AddrTypeFQDN:
		// Domain names carry their own length in the byte at index 4.
		addrStart, addrLen = 5, int(b[4])
		isName = true
	default:
		return nil, errors.New("unknown address type")
	}
	portStart := addrStart + addrLen
	if len(b) < portStart+2 {
		return nil, errors.New("short cmd request")
	}

	req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])}
	host := b[addrStart:portStart]
	if isName {
		req.Addr.Name = string(host)
	} else {
		// Copy so that the result does not alias the caller's buffer.
		req.Addr.IP = make(net.IP, addrLen)
		copy(req.Addr.IP, host)
	}
	req.Addr.Port = int(b[portStart])<<8 | int(b[portStart+1])
	return req, nil
}

// MarshalCmdReply encodes a command reply for the wire.
//
// The address in a is written as a domain name when a.Name is set,
// otherwise as IPv4 or IPv6 depending on a.IP, and is followed by the
// low 16 bits of a.Port in big-endian order. An error is returned when
// the name is longer than 255 bytes or when a holds neither a name nor
// a usable IP address. a must not be nil.
func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) {
	// Fixed header: version, reply code, a zero reserved byte and the
	// address type, which is filled in once the address form is known.
	out := []byte{byte(ver), byte(reply), 0, 0}
	switch {
	case a.Name != "":
		// The name length must fit in the single length-prefix byte.
		if len(a.Name) > 255 {
			return nil, errors.New("fqdn too long")
		}
		out[3] = socks.AddrTypeFQDN
		out = append(out, byte(len(a.Name)))
		out = append(out, a.Name...)
	case a.IP.To4() != nil:
		out[3] = socks.AddrTypeIPv4
		out = append(out, a.IP.To4()...)
	case a.IP.To16() != nil:
		out[3] = socks.AddrTypeIPv6
		out = append(out, a.IP.To16()...)
	default:
		return nil, errors.New("unknown address type")
	}
	hi, lo := byte(a.Port>>8), byte(a.Port)
	return append(out, hi, lo), nil
}

// A Server represents a server for handshake testing.
type Server struct {
	// ln is the listener on which the server accepts client
	// connections.
	ln net.Listener
}

// Addr returns a server address.
//
// It is the address reported by the server's underlying listener.
func (s *Server) Addr() net.Addr {
	return s.ln.Addr()
}

// TargetAddr returns a made-up final destination address for use in
// tests against Server.
//
// The result is a loopback TCP address on port 5963 in the same IP
// family as the server's listener. It is nil when the listener address
// is not a TCP address or carries no usable IP.
func (s *Server) TargetAddr() net.Addr {
	tcp, ok := s.ln.Addr().(*net.TCPAddr)
	if !ok {
		return nil
	}
	switch {
	case tcp.IP.To4() != nil:
		return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963}
	case tcp.IP.To16() != nil:
		// To4 already failed above, so this is a genuine IPv6 address.
		return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963}
	}
	return nil
}

// Close closes the server.
//
// It closes the underlying listener and returns the listener's error,
// if any.
func (s *Server) Close() error {
	return s.ln.Close()
}

// serve accepts one connection and runs the two-step handshake on it.
//
// It reads a request and passes it to authFunc, then reads a second
// request and passes it to cmdFunc. The first read or handler error
// ends the exchange. The connection is closed on return. serve returns
// without doing anything when Accept fails.
func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) {
	conn, err := s.ln.Accept()
	if err != nil {
		return
	}
	defer conn.Close()
	// Start accepting the next client right away so that this
	// connection does not hold up others while it is being served.
	go s.serve(authFunc, cmdFunc)

	buf := make([]byte, 512)
	phases := []func(io.ReadWriter, []byte) error{authFunc, cmdFunc}
	for _, handle := range phases {
		// Each phase consumes exactly one Read from the connection.
		n, err := conn.Read(buf)
		if err != nil {
			return
		}
		if err := handle(conn, buf[:n]); err != nil {
			return
		}
	}
}

// NewServer starts a server on a local TCP listener and returns it.
//
// authFunc is called with the first request read from each connection
// and cmdFunc with the second. Each must parse the request it is given
// and write an appropriate reply to the client. An error is returned
// when the listener cannot be created.
func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) {
	ln, err := nettest.NewLocalListener("tcp")
	if err != nil {
		return nil, err
	}
	srv := &Server{ln: ln}
	// Connections are accepted in the background.
	go srv.serve(authFunc, cmdFunc)
	return srv, nil
}

// NoAuthRequired answers an authentication request by selecting
// socks.AuthMethodNotRequired.
//
// It parses b as an authentication request and writes the reply to rw,
// echoing the version from the request. It returns any parse, marshal
// or write error, and reports a short write when rw accepts fewer bytes
// than the reply contains.
func NoAuthRequired(rw io.ReadWriter, b []byte) error {
	req, err := ParseAuthRequest(b)
	if err != nil {
		return err
	}
	reply, err := MarshalAuthReply(req.Version, socks.AuthMethodNotRequired)
	if err != nil {
		return err
	}
	switch n, err := rw.Write(reply); {
	case err != nil:
		return err
	case n != len(reply):
		return errors.New("short write")
	}
	return nil
}

// NoProxyRequired answers a command request with a success reply
// without opening any connection to the final destination.
//
// The address in the reply is derived from the requested one: the port
// is incremented by one, a domain name is replaced with
// "boundaddr.doesnotexist", and an IP address is replaced with the
// loopback address of the same family. It returns any parse, marshal or
// write error, and reports a short write when rw accepts fewer bytes
// than the reply contains.
func NoProxyRequired(rw io.ReadWriter, b []byte) error {
	req, err := ParseCmdRequest(b)
	if err != nil {
		return err
	}
	bound := &req.Addr
	bound.Port++
	switch {
	case bound.Name != "":
		bound.Name = "boundaddr.doesnotexist"
	case bound.IP.To4() != nil:
		bound.IP = net.IPv4(127, 0, 0, 1)
	default:
		bound.IP = net.IPv6loopback
	}
	reply, err := MarshalCmdReply(socks.Version5, socks.StatusSucceeded, bound)
	if err != nil {
		return err
	}
	switch n, err := rw.Write(reply); {
	case err != nil:
		return err
	case n != len(reply):
		return errors.New("short write")
	}
	return nil
}
