// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Test that Resolver.Dial can be a func returning an in-memory net.Conn
// speaking DNS.

package net

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"reflect"
	"sort"
	"testing"
	"time"

	"golang.org/x/net/dns/dnsmessage"
)

// TestResolverDialFunc verifies that a Resolver whose Dial hook returns an
// in-memory conn (built by newResolverDialFunc) can answer A, AAAA and SRV
// lookups purely from the handler callbacks.
func TestResolverDialFunc(t *testing.T) {
	r := &Resolver{
		PreferGo: true,
		Dial: newResolverDialFunc(&resolverDialHandler{
			StartDial: func(network, address string) error {
				t.Logf("StartDial(%q, %q) ...", network, address)
				return nil
			},
			Question: func(h dnsmessage.Header, q dnsmessage.Question) {
				t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
					q.Name.String(), q.Type, q.Class)
			},
			// TODO: add test without HandleA* hooks specified at all, that Go
			// doesn't issue retries; map to something terminal.
			HandleA: func(w AWriter, name string) error {
				w.AddIP([4]byte{1, 2, 3, 4})
				w.AddIP([4]byte{5, 6, 7, 8})
				return nil
			},
			HandleAAAA: func(w AAAAWriter, name string) error {
				w.AddIP([16]byte{1: 1, 15: 15})
				w.AddIP([16]byte{2: 2, 14: 14})
				return nil
			},
			HandleSRV: func(w SRVWriter, name string) error {
				// AddSRV reports invalid targets through its error
				// result; surface it instead of silently producing a
				// short answer section.
				if err := w.AddSRV(1, 2, 80, "foo.bar."); err != nil {
					t.Errorf("AddSRV(foo.bar.): %v", err)
					return err
				}
				if err := w.AddSRV(2, 3, 81, "bar.baz."); err != nil {
					t.Errorf("AddSRV(bar.baz.): %v", err)
					return err
				}
				return nil
			},
		}),
	}
	ctx := context.Background()
	const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."

	t.Run("LookupIP", func(t *testing.T) {
		ips, err := r.LookupIP(ctx, "ip", fakeDomain)
		if err != nil {
			t.Fatal(err)
		}
		// Compare as sorted strings so the check does not depend on the
		// order in which the A and AAAA answers are merged.
		if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
			t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
		}
	})

	t.Run("LookupSRV", func(t *testing.T) {
		_, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
		if err != nil {
			t.Fatal(err)
		}
		want := []*SRV{
			{
				Target:   "foo.bar.",
				Port:     80,
				Priority: 1,
				Weight:   2,
			},
			{
				Target:   "bar.baz.",
				Port:     81,
				Priority: 2,
				Weight:   3,
			},
		}
		if !reflect.DeepEqual(got, want) {
			// The slices hold pointers, so print each element
			// individually to show field values rather than addresses.
			t.Errorf("LookupSRV wrong: got %d records, want %d", len(got), len(want))
			for _, srv := range got {
				t.Logf("  got:  %+v", srv)
			}
			for _, srv := range want {
				t.Logf("  want: %+v", srv)
			}
		}
	})
}

// sortedIPStrings returns the textual form of every address in ips,
// ordered lexically. The result always has len(ips) elements.
func sortedIPStrings(ips []IP) []string {
	out := make([]string, 0, len(ips))
	for _, addr := range ips {
		out = append(out, addr.String())
	}
	sort.Sort(sort.StringSlice(out))
	return out
}

// newResolverDialFunc returns a dial function, with the signature expected
// by Resolver.Dial, whose connections are answered in memory by h.
//
// Each dial first runs h.StartDial when it is set; a non-nil error from the
// hook is returned as-is and no conn is produced. Otherwise a fresh
// resolverFuncConn with a 10 second default TTL is returned.
func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
	dial := func(_ context.Context, network, address string) (Conn, error) {
		if hook := h.StartDial; hook != nil {
			if err := hook(network, address); err != nil {
				return nil, err
			}
		}
		conn := new(resolverFuncConn)
		conn.h = h
		conn.network = network
		conn.address = address
		conn.ttl = 10 // 10 second default if unset
		return conn, nil
	}
	return dial
}

// resolverDialHandler is the set of callbacks that drive the in-memory DNS
// conn created by newResolverDialFunc. Every callback is optional; a nil
// callback is simply not invoked.
type resolverDialHandler struct {
	// StartDial, if non-nil, is called when Go first calls Resolver.Dial.
	// Any error returned aborts the dial and is returned unwrapped.
	StartDial func(network, address string) error

	// Question, if non-nil, is called with the request header and the
	// question of each query whose question section parsed successfully,
	// before any Handle* callback runs.
	Question func(dnsmessage.Header, dnsmessage.Question)

	// HandleA, HandleAAAA and HandleSRV, if non-nil, are called for a
	// question of the matching type with a writer for adding answers and
	// the queried name. The returned error is passed to mapRCode.
	//
	// err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
	// A nil error means success.
	HandleA    func(w AWriter, name string) error
	HandleAAAA func(w AAAAWriter, name string) error
	HandleSRV  func(w SRVWriter, name string) error
}

// ResponseWriter is the part shared by the typed answer writers. It wraps
// the conn whose response is currently being built.
type ResponseWriter struct {
	a *resolverFuncConn
}

// header returns the resource header for the next answer record: it echoes
// the name, type and class of the question being answered and carries the
// conn's current TTL.
func (w ResponseWriter) header() dnsmessage.ResourceHeader {
	conn := w.a
	var rh dnsmessage.ResourceHeader
	rh.Name = conn.q.Name
	rh.Type = conn.q.Type
	rh.Class = conn.q.Class
	rh.TTL = conn.ttl
	return rh
}

// SetTTL sets the TTL, in seconds, used for resources written afterwards.
//
// It only has an effect while nothing has been written yet: before the
// first resource is added the most recent call wins, and once any resource
// has been added further calls are no-ops. All answers in a response
// therefore share one TTL; retroactively changing already written records,
// or using a different TTL per answer, is not supported.
func (w ResponseWriter) SetTTL(seconds uint32) {
	if conn := w.a; !conn.wrote {
		conn.ttl = seconds
	}
}

// AWriter is the writer passed to HandleA for adding A records.
type AWriter struct {
	ResponseWriter
}

// AddIP appends an A record holding v4 to the response under construction
// and marks the response as written to. It panics if the message builder
// rejects the record.
func (w AWriter) AddIP(v4 [4]byte) {
	conn := w.a
	conn.wrote = true
	rec := dnsmessage.AResource{A: v4}
	if err := conn.builder.AResource(w.header(), rec); err != nil {
		panic(err)
	}
}

// AAAAWriter is the writer passed to HandleAAAA for adding AAAA records.
type AAAAWriter struct {
	ResponseWriter
}

// AddIP appends an AAAA record holding v6 to the response under
// construction and marks the response as written to. It panics if the
// message builder rejects the record.
func (w AAAAWriter) AddIP(v6 [16]byte) {
	conn := w.a
	conn.wrote = true
	rec := dnsmessage.AAAAResource{AAAA: v6}
	if err := conn.builder.AAAAResource(w.header(), rec); err != nil {
		panic(err)
	}
}

// SRVWriter is the writer passed to HandleSRV for adding SRV records.
type SRVWriter struct{ ResponseWriter }

// AddSRV adds a SRV record to the response under construction.
//
// The target must be a fully-qualified name, i.e. end in a period, and must
// be accepted by dnsmessage.NewName. An invalid target is reported through
// the returned error and nothing is written; the response is only marked as
// written to once the target has been validated.
//
// AddSRV panics if the message builder rejects the record for any other
// reason, since that indicates a fault in this test helper rather than in
// the caller's input.
func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
	// A name lacking the trailing period is only rejected when the record
	// is packed, which would otherwise surface as the panic below even
	// though it is a caller mistake. Check it up front instead.
	if len(target) == 0 || target[len(target)-1] != '.' {
		return fmt.Errorf("SRV target %q does not end in a period", target)
	}
	targetName, err := dnsmessage.NewName(target)
	if err != nil {
		return err
	}
	w.a.wrote = true
	err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
		Priority: priority,
		Weight:   weight,
		Port:     port,
		Target:   targetName,
	})
	if err != nil {
		panic(err) // internal fault, not user
	}
	return nil
}

// Sentinel errors that Handle* callbacks can return to select the DNS
// response code; mapRCode performs the translation.
var (
	// ErrNotExist reports that the queried name does not exist.
	ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN
	// ErrRefused reports that the query was refused.
	ErrRefused  = errors.New("refused")             // maps to RCode5, REFUSED
)

// resolverFuncConn is the in-memory conn returned by the dial function from
// newResolverDialFunc. Write parses a DNS query and queues the response
// built from the handler callbacks; Read drains the queued response bytes.
type resolverFuncConn struct {
	// h holds the callbacks consulted while answering a query.
	h       *resolverDialHandler
	// network and address record the arguments of the dial call.
	network string
	address string
	// builder is the message builder for the response currently being
	// assembled in Write; the answer writers append records through it.
	builder *dnsmessage.Builder
	// q is the question being answered; it supplies the name, type and
	// class of each answer's resource header.
	q       dnsmessage.Question
	// ttl is the TTL, in seconds, placed in answer resource headers.
	ttl     uint32
	// wrote is set once any answer has been added; SetTTL is a no-op
	// afterwards.
	wrote   bool

	// rbuf holds response bytes that have not been consumed by Read yet.
	rbuf bytes.Buffer
}

// Close does nothing and always succeeds.
func (*resolverFuncConn) Close() error {
	return nil
}

// LocalAddr returns a placeholder address.
func (*resolverFuncConn) LocalAddr() Addr {
	var local someaddr
	return local
}

// RemoteAddr returns a placeholder address.
func (*resolverFuncConn) RemoteAddr() Addr {
	var remote someaddr
	return remote
}

// SetDeadline, SetReadDeadline and SetWriteDeadline accept and ignore the
// requested deadline; they always succeed.

func (*resolverFuncConn) SetDeadline(_ time.Time) error {
	return nil
}

func (*resolverFuncConn) SetReadDeadline(_ time.Time) error {
	return nil
}

func (*resolverFuncConn) SetWriteDeadline(_ time.Time) error {
	return nil
}

// Read copies queued response bytes into p, returning whatever the
// underlying buffer's Read reports.
func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
	n, err = a.rbuf.Read(p)
	return n, err
}

// Write accepts exactly one length-prefixed DNS query (a 2-byte big-endian
// length followed by the message), answers it using the handler callbacks
// and queues the length-prefixed response for Read.
//
// The response echoes the request header with Response and Authoritative
// set. A query without a question is answered with RCodeNotImplemented.
// Otherwise the question is echoed, the Handle* callback matching the
// question type (if any) adds the answers, and the error it returns selects
// the response code via mapRCode.
//
// On success Write returns len(packet). It returns an error, and queues
// nothing, if the framing is wrong, the query cannot be parsed, or the
// response cannot be built.
func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
	const lenPrefix = 2 // size of the big-endian uint16 length prefix

	if len(packet) < lenPrefix {
		return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
	}
	reqLen := int(packet[0])<<8 | int(packet[1])
	req := packet[lenPrefix:]
	if len(req) != reqLen {
		return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
	}

	var parser dnsmessage.Parser
	h, err := parser.Start(req)
	if err != nil {
		// TODO: hook
		return 0, err
	}
	q, err := parser.Question()
	// ErrSectionDone just means the query carried no question.
	if err != nil && err != dnsmessage.ErrSectionDone {
		return 0, err
	}
	hadQ := err == nil
	if hadQ && a.h.Question != nil {
		a.h.Question(h, q)
	}

	resh := h
	resh.Response = true
	resh.Authoritative = true
	resh.RCode = dnsmessage.RCodeSuccess
	if !hadQ {
		resh.RCode = dnsmessage.RCodeNotImplemented
	}

	// Build the response in a private buffer whose first two bytes are
	// reserved for the length prefix. Building inside rbuf's own storage
	// would break the framing whenever the builder outgrows the spare
	// capacity (the prefix would be written to a reallocated array) and
	// would mix in any response bytes that have not been read yet.
	buf := make([]byte, lenPrefix, 514)
	builder := dnsmessage.NewBuilder(buf, resh)
	a.builder = &builder

	rcode := resh.RCode
	if hadQ {
		a.q = q
		if err := a.builder.StartQuestions(); err != nil {
			return 0, fmt.Errorf("StartQuestions: %w", err)
		}
		if err := a.builder.Question(q); err != nil {
			return 0, fmt.Errorf("Question: %w", err)
		}
		if err := a.builder.StartAnswers(); err != nil {
			return 0, fmt.Errorf("StartAnswers: %w", err)
		}
		name := q.Name.String()
		switch q.Type {
		case dnsmessage.TypeA:
			if a.h.HandleA != nil {
				rcode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, name))
			}
		case dnsmessage.TypeAAAA:
			if a.h.HandleAAAA != nil {
				rcode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, name))
			}
		case dnsmessage.TypeSRV:
			if a.h.HandleSRV != nil {
				rcode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, name))
			}
		}
	}
	tcpRes, err := builder.Finish()
	if err != nil {
		return 0, fmt.Errorf("Finish: %w", err)
	}

	// Finish returns the whole buffer, including the reserved prefix.
	msg := tcpRes[lenPrefix:]
	if len(msg) > 0xFFFF {
		return 0, fmt.Errorf("response of %d bytes does not fit the 16-bit length prefix", len(msg))
	}

	// The builder received the header by value before the handler ran, so
	// the handler's response code has to be patched into the packed
	// header: RCODE is the low four bits of the fourth header byte.
	const (
		rcodeOffset = 3
		rcodeMask   = 0x0F
	)
	msg[rcodeOffset] = msg[rcodeOffset]&^rcodeMask | byte(rcode)&rcodeMask

	tcpRes[0] = byte(len(msg) >> 8)
	tcpRes[1] = byte(len(msg))
	a.rbuf.Write(tcpRes)

	return len(packet), nil
}

// someaddr is a placeholder address type; both of its methods return fixed
// strings.
type someaddr struct{}

// Network returns the fixed placeholder network name.
func (someaddr) Network() string {
	return "unused"
}

// String returns the fixed placeholder address text.
func (someaddr) String() string {
	return "unused-someaddr"
}

// mapRCode translates the error returned by a Handle* callback into a DNS
// response code:
//
//   - nil maps to RCodeSuccess;
//   - ErrNotExist maps to RCodeNameError (NXDOMAIN);
//   - ErrRefused maps to RCodeRefused;
//   - any other error maps to RCodeServerFailure (SERVFAIL).
//
// The sentinels are matched with errors.Is, so a handler may wrap them
// (for example with fmt.Errorf and %w) to add context.
func mapRCode(err error) dnsmessage.RCode {
	switch {
	case err == nil:
		return dnsmessage.RCodeSuccess
	case errors.Is(err, ErrNotExist):
		return dnsmessage.RCodeNameError
	case errors.Is(err, ErrRefused):
		return dnsmessage.RCodeRefused
	default:
		return dnsmessage.RCodeServerFailure
	}
}
