// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package middleware

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/alicebob/miniredis/v2"
	"github.com/go-redis/redis/v8"
	"github.com/google/go-cmp/cmp"
	"go.opencensus.io/stats/view"
	"golang.org/x/pkgsite/internal/config"
)

// TestLegacyQuota checks the LegacyQuota middleware with RecordOnly set to
// false: requests over the configured limit receive 429 Too Many Requests
// and never reach the wrapped handler, and each request is reflected in the
// QuotaResultCount view.
func TestLegacyQuota(t *testing.T) {
	mw := LegacyQuota(config.QuotaSettings{QPS: 1, Burst: 2, MaxEntries: 1, RecordOnly: boolptr(false)})
	// npass counts the requests that reach the wrapped handler.
	var npass int
	h := func(w http.ResponseWriter, r *http.Request) {
		npass++
	}
	ts := httptest.NewServer(mw(http.HandlerFunc(h)))
	defer ts.Close()
	c := ts.Client()
	if err := view.Register(QuotaResultCount); err != nil {
		t.Fatal(err)
	}
	defer view.Unregister(QuotaResultCount)

	// check sends five requests with the same X-Forwarded-For header and
	// expects the first nwant of them to succeed and the rest to be rejected
	// with 429. msg labels any failure output.
	check := func(msg string, nwant int) {
		t.Helper()
		npass = 0
		for i := 0; i < 5; i++ {
			req, err := http.NewRequest("GET", ts.URL, nil)
			if err != nil {
				t.Fatalf("%s: %v", msg, err)
			}
			req.Header.Add("X-Forwarded-For", "1.2.3.4, and more")
			res, err := c.Do(req)
			if err != nil {
				t.Fatalf("%s: %v", msg, err)
			}
			res.Body.Close()
			want := http.StatusOK
			if i >= nwant {
				want = http.StatusTooManyRequests
			}
			if got := res.StatusCode; got != want {
				t.Errorf("%s, #%d: got %d, want %d", msg, i, got, want)
			}
		}
		if npass != nwant {
			t.Errorf("%s: got %d requests to pass, want %d", msg, npass, nwant)
		}
	}

	// When making multiple requests in quick succession from the same IP,
	// only the first two get through; the rest are blocked.
	check("before", 2)
	// After a second (and a bit more), we should have one token back, meaning
	// we can serve one request.
	time.Sleep(1100 * time.Millisecond)
	check("after", 1)

	// Check the metric. The map key is true for blocked requests.
	got := collectViewData(t)
	want := map[bool]int{true: 7, false: 3} // only 3 requests of the ten we sent get through.
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("mismatch (-want +got):\n%s", diff)
	}
}

// TestLegacyQuotaRecordOnly is like TestLegacyQuota, but in RecordOnly mode
// nothing is actually blocked: every request reaches the handler, while the
// QuotaResultCount view still records which requests would have been blocked.
func TestLegacyQuotaRecordOnly(t *testing.T) {
	mw := LegacyQuota(config.QuotaSettings{QPS: 1, Burst: 2, MaxEntries: 1, RecordOnly: boolptr(true)})
	// npass counts the requests that reach the wrapped handler.
	npass := 0
	h := func(w http.ResponseWriter, r *http.Request) {
		npass++
	}
	ts := httptest.NewServer(mw(http.HandlerFunc(h)))
	defer ts.Close()
	c := ts.Client()
	if err := view.Register(QuotaResultCount); err != nil {
		t.Fatal(err)
	}
	defer view.Unregister(QuotaResultCount)

	const nreq = 100
	for i := 0; i < nreq; i++ {
		req, err := http.NewRequest("GET", ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Add("X-Forwarded-For", "1.2.3.4, and more")
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
	}
	if npass != nreq {
		t.Errorf("%d passed, want %d", npass, nreq)
	}
	// The map key is true for requests recorded as blocked.
	got := collectViewData(t)
	want := map[bool]int{true: nreq - 2, false: 2} // record as if blocking occurred
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("mismatch (-want +got):\n%s", diff)
	}
}

// TestLegacyQuotaBadKey verifies that requests whose X-Forwarded-For header
// does not start with a valid IP address are not blocked, and that none of
// them is recorded as blocked in the QuotaResultCount view.
func TestLegacyQuotaBadKey(t *testing.T) {
	mw := LegacyQuota(config.QuotaSettings{QPS: 1, Burst: 2, MaxEntries: 1, RecordOnly: boolptr(true)})
	// npass counts the requests that reach the wrapped handler.
	npass := 0
	h := func(w http.ResponseWriter, r *http.Request) {
		npass++
	}
	ts := httptest.NewServer(mw(http.HandlerFunc(h)))
	defer ts.Close()
	c := ts.Client()
	if err := view.Register(QuotaResultCount); err != nil {
		t.Fatal(err)
	}
	defer view.Unregister(QuotaResultCount)

	const nreq = 100
	for i := 0; i < nreq; i++ {
		req, err := http.NewRequest("GET", ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Add("X-Forwarded-For", "not.a.valid.ip, and more")
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
	}
	if npass != nreq {
		t.Errorf("%d passed, want %d", npass, nreq)
	}
	got := collectViewData(t)
	want := map[bool]int{false: nreq} // no blocking occurred
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("mismatch (-want +got):\n%s", diff)
	}
}

// collectViewData retrieves the rows currently recorded for the
// QuotaResultCount view and returns their counts in a map. The key is true
// for a row whose first tag value is "blocked" and false for any other row.
// It fails the test immediately if the data cannot be retrieved or a row
// does not have the expected shape.
func collectViewData(t *testing.T) map[bool]int {
	t.Helper()
	rows, err := view.RetrieveData(QuotaResultCount.Name)
	if err != nil {
		t.Fatalf("collectViewData: %v", err)
	}
	m := make(map[bool]int, len(rows))
	for _, row := range rows {
		// Report a malformed row as a test failure rather than panicking
		// on the index or the type assertion below.
		if len(row.Tags) == 0 {
			t.Fatal("collectViewData: row has no tags")
		}
		data, ok := row.Data.(*view.CountData)
		if !ok {
			t.Fatalf("collectViewData: got data of type %T, want *view.CountData", row.Data)
		}
		blocked := row.Tags[0].Value == "blocked"
		m[blocked] = int(data.Value)
	}
	return m
}

// TestIPKey checks the key that ipKey derives from an X-Forwarded-For style
// header value, covering empty and malformed input, IPv4, surrounding
// whitespace with trailing fields, and IPv6.
func TestIPKey(t *testing.T) {
	cases := []struct {
		in   string
		want interface{}
	}{
		{in: "", want: ""},
		{in: "1.2.3", want: ""},
		{in: "128.197.17.3", want: "128.197.17.0"},
		{in: "  128.197.17.3, foo  ", want: "128.197.17.0"},
		{in: "2001:db8::ff00:42:8329", want: "2001:db8::ff00:42:8300"},
	}
	for _, tc := range cases {
		if got := ipKey(tc.in); got != tc.want {
			t.Errorf("%q: got %v, want %v", tc.in, got, tc.want)
		}
	}
}

// boolptr returns a pointer to a new bool variable holding the value b.
func boolptr(b bool) *bool {
	v := b
	return &v
}

// TestEnforceQuota runs enforceQuota against an in-process miniredis server
// and checks that requests are allowed up to qps per key and rejected after
// that, where addresses differing only in the final byte share a key.
func TestEnforceQuota(t *testing.T) {
	srv, err := miniredis.Run()
	if err != nil {
		t.Fatal(err)
	}
	defer srv.Close()

	client := redis.NewClient(&redis.Options{Addr: srv.Addr()})
	defer client.Close()

	ctx := context.Background()
	const qps = 5

	// check calls enforceQuota count times for the given address and
	// reports every call whose allowed/blocked outcome differs from
	// wantAllowed.
	check := func(count int, addr string, wantAllowed bool) {
		t.Helper()
		header := addr + ",x"
		for i := 0; i < count; i++ {
			blocked, reason := enforceQuota(ctx, client, qps, header, []byte{1, 2, 3, 4})
			if allowed := !blocked; allowed != wantAllowed {
				t.Errorf("%d: got %t, want %t (reason=%q)", i, allowed, wantAllowed, reason)
			}
		}
	}

	check(qps, "1.2.3.4", true) // the first qps requests pass
	check(1, "1.2.3.4", false)  // the next one is rejected
	check(1, "1.2.3.5", false)  // changing only the last byte hits the same quota
	check(qps, "1.2.4.1", true) // a different address has its own quota
	check(1, "1.2.4.9", false)  // which is also exhausted after qps requests
}
