// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.21

package quic

import (
	"bytes"
	"context"
	"crypto/tls"
	"io"
	"net"
	"net/netip"
	"reflect"
	"testing"
	"time"
)

// TestConnect checks that two loopback endpoints can be connected: Dial on the
// client endpoint and Accept on the server endpoint must both succeed.
func TestConnect(t *testing.T) {
	newLocalConnPair(t, new(Config), new(Config))
}

// TestStreamTransfer writes 1 MiB over a client-initiated send-only stream
// between two loopback endpoints. It checks that the server reads exactly the
// bytes that were written and that both ends close the stream cleanly.
func TestStreamTransfer(t *testing.T) {
	ctx := context.Background()
	cli, srv := newLocalConnPair(t, &Config{}, &Config{})
	data := makeTestData(1 << 20)

	srvdone := make(chan struct{})
	go func() {
		defer close(srvdone)
		s, err := srv.AcceptStream(ctx)
		if err != nil {
			t.Errorf("AcceptStream: %v", err)
			return
		}
		b, err := io.ReadAll(s)
		if err != nil {
			t.Errorf("io.ReadAll(s): %v", err)
			return
		}
		if !bytes.Equal(b, data) {
			t.Errorf("read data mismatch (got %v bytes, want %v)", len(b), len(data))
		}
		if err := s.Close(); err != nil {
			t.Errorf("s.Close() = %v", err)
		}
	}()

	s, err := cli.NewSendOnlyStream(ctx)
	if err != nil {
		t.Fatalf("NewSendOnlyStream: %v", err)
	}
	n, err := io.Copy(s, bytes.NewBuffer(data))
	if n != int64(len(data)) || err != nil {
		t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
	}
	if err := s.Close(); err != nil {
		t.Fatalf("s.Close() = %v", err)
	}

	// Wait for the server side to finish its checks. Without this wait, the
	// test could return before the server goroutine has verified the data, and
	// any t.Errorf the goroutine made after the test completed would panic.
	<-srvdone
}

// newLocalConnPair creates a server endpoint configured by conf1 and a client
// endpoint configured by conf2, both bound to 127.0.0.1. It dials the server
// from the client, accepts on the server, and returns the two Conns.
// Any error fails the test immediately.
func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
	t.Helper()
	ctx := context.Background()
	srvEndpoint := newLocalEndpoint(t, serverSide, conf1)
	cliEndpoint := newLocalEndpoint(t, clientSide, conf2)

	var err error
	clientConn, err = cliEndpoint.Dial(ctx, "udp", srvEndpoint.LocalAddr().String())
	if err != nil {
		t.Fatal(err)
	}
	serverConn, err = srvEndpoint.Accept(ctx)
	if err != nil {
		t.Fatal(err)
	}
	return clientConn, serverConn
}

// newLocalEndpoint starts an Endpoint listening on a random UDP port on
// 127.0.0.1. When conf has no TLSConfig, a test TLS config for side is
// installed on a copy, so the caller's Config is never modified.
// The Endpoint is closed when the test finishes.
func newLocalEndpoint(t *testing.T, side connSide, conf *Config) *Endpoint {
	t.Helper()
	if conf.TLSConfig == nil {
		withTLS := *conf
		withTLS.TLSConfig = newTestTLSConfig(side)
		conf = &withTLS
	}
	endpoint, err := Listen("udp", "127.0.0.1:0", conf)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { endpoint.Close(context.Background()) })
	return endpoint
}

// testEndpoint is a test harness wrapping an Endpoint (e). The same value acts
// as the Endpoint's UDP connection (as testEndpointUDPConn) and its test hooks
// (as testEndpointHooks). This lets the test control time and observe every
// datagram.
//
// Fields:
//   - now: the fake clock returned by the timeNow hook; moved forward by
//     advance/advanceTo.
//   - recvc: datagrams delivered to the Endpoint through ReadMsgUDPAddrPort;
//     closed by the fake conn's Close.
//   - idlec: wait sends on it. A send succeeds only while a reader is blocked
//     in ReadMsgUDPAddrPort.
//   - conns: testConn wrappers for each Conn, keyed by Conn and added by the
//     newConn hook.
//   - acceptQueue: conns handed out by accept. NOTE(review): it is populated
//     outside this chunk, presumably when a server conn is created.
//   - configTransportParams, configTestConn, peerTLSConn: not used in this
//     chunk. They are presumably applied to new testConns; confirm in
//     newTestConnForConn.
//   - sentDatagrams: copies of datagrams written by the Endpoint, consumed
//     FIFO by read.
//   - lastInitialDstConnID: destination conn ID of the most recent Initial
//     packet passed to writeDatagram.
type testEndpoint struct {
	t                     *testing.T
	e                     *Endpoint
	now                   time.Time
	recvc                 chan *datagram
	idlec                 chan struct{}
	conns                 map[*Conn]*testConn
	acceptQueue           []*testConn
	configTransportParams []func(*transportParameters)
	configTestConn        []func(*testConn)
	sentDatagrams         [][]byte
	peerTLSConn           *tls.QUICConn
	lastInitialDstConnID  []byte // for parsing Retry packets
}

// newTestEndpoint builds a testEndpoint and an Endpoint whose UDP conn and test
// hooks are both backed by that testEndpoint. The fake clock starts at
// 2020-01-01 00:00:00 UTC. The Endpoint is closed via cleanup when the test
// ends.
func newTestEndpoint(t *testing.T, config *Config) *testEndpoint {
	te := &testEndpoint{
		t:     t,
		now:   time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
		recvc: make(chan *datagram),
		idlec: make(chan struct{}),
		conns: make(map[*Conn]*testConn),
	}
	udpConn := (*testEndpointUDPConn)(te)
	hooks := (*testEndpointHooks)(te)
	e, err := newEndpoint(udpConn, config, hooks)
	if err != nil {
		t.Fatal(err)
	}
	te.e = e
	t.Cleanup(te.cleanup)
	return te
}

// cleanup closes the Endpoint under test using an already-canceled context.
// NOTE(review): presumably the canceled context keeps Close from waiting on
// peers that never respond in this harness; confirm.
func (te *testEndpoint) cleanup() {
	ctx := canceledContext()
	te.e.Close(ctx)
}

// wait blocks until a reader is parked in ReadMsgUDPAddrPort (the only place
// idlec is received from) or the Endpoint has closed. It then waits on each
// connection in turn.
func (te *testEndpoint) wait() {
	select {
	case <-te.e.closec:
	case te.idlec <- struct{}{}:
	}
	for _, c := range te.conns {
		c.wait()
	}
}

// accept returns a server connection from the endpoint.
// Unlike Endpoint.Accept, connections are available as soon as they are created.
// It fails the test if no connection is queued. Connections are returned in
// FIFO order.
func (te *testEndpoint) accept() *testConn {
	// Report failures at the caller's line, not here.
	te.t.Helper()
	if len(te.acceptQueue) == 0 {
		te.t.Fatalf("accept: expected available conn, but found none")
	}
	tc := te.acceptQueue[0]
	te.acceptQueue = te.acceptQueue[1:]
	return tc
}

// write hands d to the Endpoint's fake UDP read path and then waits until the
// Endpoint and its connections have processed it.
func (te *testEndpoint) write(d *datagram) {
	inbound := te.recvc
	inbound <- d
	te.wait()
}

// testClientAddr (10.0.0.1:8000) is the source address used by writeDatagram
// for datagrams that do not specify one.
var testClientAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{10, 0, 0, 1}), 8000)

// writeDatagram encodes d and delivers it to the endpoint under test as if it
// had arrived from the network, then waits for the endpoint to go idle.
//
// While encoding, it does three things:
//   - For each non-Retry packet that targets a known conn, it advances that
//     testConn's peerNextPacketNum for the packet's number space to one past
//     the packet's number.
//   - It remembers the destination conn ID of Initial packets.
//   - It zero-fills the datagram up to d.paddedSize.
//
// Datagrams with no valid address are sent from testClientAddr.
func (te *testEndpoint) writeDatagram(d *testDatagram) {
	te.t.Helper()
	logDatagram(te.t, "<- endpoint under test receives", d)
	var buf []byte
	for _, p := range d.packets {
		tc := te.connForDestination(p.dstConnID)
		// Retry packets are skipped (presumably because they carry no packet number).
		if tc != nil && p.ptype != packetTypeRetry {
			space := spaceForPacketType(p.ptype)
			if p.num >= tc.peerNextPacketNum[space] {
				tc.peerNextPacketNum[space] = p.num + 1
			}
		}
		if p.ptype == packetTypeInitial {
			te.lastInitialDstConnID = p.dstConnID
		}
		// Only a 1-RTT packet is given the remaining datagram space as its pad.
		// NOTE(review): presumably because short-header packets have no length
		// field and so must absorb padding themselves; confirm.
		var pad int
		if p.ptype == packetType1RTT {
			pad = d.paddedSize - len(buf)
		}
		buf = append(buf, encodeTestPacket(te.t, tc, p, pad)...)
	}
	// Zero-fill whatever padding the packets did not already provide.
	if missing := d.paddedSize - len(buf); missing > 0 {
		buf = append(buf, make([]byte, missing)...)
	}
	from := d.addr
	if !from.IsValid() {
		from = testClientAddr
	}
	te.write(&datagram{b: buf, addr: from})
}

// connForDestination returns the testConn that owns dstConnID as one of its
// local connection IDs, or nil if no conn does.
func (te *testEndpoint) connForDestination(dstConnID []byte) *testConn {
	for _, tc := range te.conns {
		ids := tc.conn.connIDState.local
		for i := range ids {
			if bytes.Equal(ids[i].cid, dstConnID) {
				return tc
			}
		}
	}
	return nil
}

// connForSource returns the testConn that has srcConnID among its remote
// connection IDs, or nil if no conn does.
func (te *testEndpoint) connForSource(srcConnID []byte) *testConn {
	for _, tc := range te.conns {
		ids := tc.conn.connIDState.remote
		for i := range ids {
			if bytes.Equal(ids[i].cid, srcConnID) {
				return tc
			}
		}
	}
	return nil
}

// read waits for the endpoint to go idle, then removes and returns the oldest
// datagram the endpoint has sent. It returns nil if nothing is pending.
func (te *testEndpoint) read() []byte {
	te.t.Helper()
	te.wait()
	if len(te.sentDatagrams) == 0 {
		return nil
	}
	next, rest := te.sentDatagrams[0], te.sentDatagrams[1:]
	te.sentDatagrams = rest
	return next
}

// readDatagram reads the next datagram sent by the endpoint and parses it into
// a testDatagram, logging the result. It returns nil if nothing was sent.
func (te *testEndpoint) readDatagram() *testDatagram {
	te.t.Helper()
	raw := te.read()
	if raw == nil {
		return nil
	}
	// The parse error is deliberately ignored. The conn is looked up by the
	// header's destination ID (the peer-side ID, hence a "source" lookup).
	// NOTE(review): when raw does not start with a long-header packet,
	// hdr.dstConnID is presumably empty; confirm.
	hdr, _ := parseGenericLongHeaderPacket(raw)
	d := parseTestDatagram(te.t, te, te.connForSource(hdr.dstConnID), raw)
	logDatagram(te.t, "-> endpoint under test sends", d)
	return d
}

// wantDatagram asserts that the next datagram sent by the Endpoint deeply
// equals want, failing the test with expectation as context otherwise.
func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) {
	te.t.Helper()
	if got := te.readDatagram(); !reflect.DeepEqual(got, want) {
		te.t.Fatalf("%v:\ngot datagram:  %v\nwant datagram: %v", expectation, got, want)
	}
}

// wantIdle indicates that we expect the Endpoint to not send any more datagrams.
// It fails the test, citing expectation, if a datagram is pending.
func (te *testEndpoint) wantIdle(expectation string) {
	// Attribute failures to the caller, matching wantDatagram.
	te.t.Helper()
	if got := te.readDatagram(); got != nil {
		te.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got)
	}
}

// advance moves the fake clock forward by d. Timer handling is delegated to
// advanceTo.
func (te *testEndpoint) advance(d time.Duration) {
	te.t.Helper()
	target := te.now.Add(d)
	te.advanceTo(target)
}

// advanceTo sets the fake clock to now, failing the test if that would move
// time backwards. Every connection whose timer is at or before the new time
// is sent a timerEvent and waited on.
func (te *testEndpoint) advanceTo(now time.Time) {
	te.t.Helper()
	if now.Before(te.now) {
		te.t.Fatalf("time moved backwards: %v -> %v", te.now, now)
	}
	te.now = now
	for _, tc := range te.conns {
		if tc.timer.After(te.now) {
			continue // timer not yet due
		}
		tc.conn.sendMsg(timerEvent{})
		tc.wait()
	}
}

// testEndpointHooks implements endpointTestHooks on top of a testEndpoint,
// so the Endpoint under test reads the harness's fake clock.
type testEndpointHooks testEndpoint

// timeNow reports the harness's fake current time.
func (te *testEndpointHooks) timeNow() time.Time {
	current := te.now
	return current
}

// newConn wraps a newly created Conn in a testConn and records it in the
// harness's conns map, keyed by c.
func (te *testEndpointHooks) newConn(c *Conn) {
	owner := (*testEndpoint)(te)
	te.conns[c] = newTestConnForConn(te.t, owner, c)
}

// testEndpointUDPConn implements UDPConn on top of a testEndpoint's channels
// and sent-datagram log; no real socket is involved.
type testEndpointUDPConn testEndpoint

// Close closes recvc, which makes ReadMsgUDPAddrPort return io.EOF.
func (te *testEndpointUDPConn) Close() error {
	inbound := te.recvc
	close(inbound)
	return nil
}

// LocalAddr always reports the fixed fake address 127.0.0.1:443.
func (te *testEndpointUDPConn) LocalAddr() net.Addr {
	ap := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), 443)
	return net.UDPAddrFromAddrPort(ap)
}

// ReadMsgUDPAddrPort blocks until a datagram arrives on recvc. It then copies
// the datagram into b (truncating if b is too short) and returns its address;
// no control data or flags are ever produced.
//
// A receive on idlec means testEndpoint.wait has observed this reader as idle;
// the loop simply keeps waiting. Once recvc is closed, it returns io.EOF.
func (te *testEndpointUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) {
	for {
		select {
		case <-te.idlec:
			continue
		case d, ok := <-te.recvc:
			if !ok {
				return 0, 0, 0, netip.AddrPort{}, io.EOF
			}
			return copy(b, d.b), 0, 0, d.addr, nil
		}
	}
}

// WriteToUDPAddrPort appends a private copy of b to sentDatagrams and reports
// the whole buffer as written. addr is ignored. The copy is kept presumably
// because the caller may reuse b.
func (te *testEndpointUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
	sent := append([]byte(nil), b...)
	te.sentDatagrams = append(te.sentDatagrams, sent)
	return len(b), nil
}
