blob: d7aa341f7e1abbcb9ce5355e6c6847d6b5ce22c0 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package middleware
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"go.opencensus.io/stats/view"
"golang.org/x/pkgsite/internal/config"
)
// TestQuota verifies that the Quota middleware rate-limits requests coming
// from the same client IP (taken from the X-Forwarded-For header), rejects
// over-quota requests with 429 Too Many Requests, and records each decision
// in the QuotaResultCount view.
func TestQuota(t *testing.T) {
	// One token per second with a burst of two. RecordOnly is false, so
	// over-quota requests are actually rejected.
	mw := Quota(config.QuotaSettings{QPS: 1, Burst: 2, MaxEntries: 1, RecordOnly: boolptr(false)})
	var npass int
	h := func(w http.ResponseWriter, r *http.Request) {
		npass++
	}
	ts := httptest.NewServer(mw(http.HandlerFunc(h)))
	defer ts.Close()
	c := ts.Client()

	// A failed registration would otherwise surface later as a confusing
	// metric mismatch; fail fast with the real cause instead.
	if err := view.Register(QuotaResultCount); err != nil {
		t.Fatalf("view.Register: %v", err)
	}
	defer view.Unregister(QuotaResultCount)

	// check sends five requests from the same client IP. It expects the
	// first nwant to be served (200) and the rest to be rejected (429), and
	// it expects exactly nwant of them to reach the wrapped handler.
	check := func(msg string, nwant int) {
		npass = 0
		for i := 0; i < 5; i++ {
			req, err := http.NewRequest("GET", ts.URL, nil)
			if err != nil {
				t.Fatal(err)
			}
			req.Header.Add("X-Forwarded-For", "1.2.3.4, and more")
			res, err := c.Do(req)
			if err != nil {
				t.Fatalf("%s: %v", msg, err)
			}
			res.Body.Close()
			want := http.StatusOK
			if i >= nwant {
				want = http.StatusTooManyRequests
			}
			if got := res.StatusCode; got != want {
				t.Errorf("%s, #%d: got %d, want %d", msg, i, got, want)
			}
		}
		if npass != nwant {
			t.Errorf("%s: got %d requests to pass, want %d", msg, npass, nwant)
		}
	}

	// When making multiple requests in quick succession from the same IP,
	// only the first two (the burst) get through; the rest are blocked.
	check("before", 2)

	// After a second (and a bit more), we should have one token back, meaning
	// we can serve one request.
	time.Sleep(1100 * time.Millisecond)
	check("after", 1)

	// Check the metric: only 3 of the 10 requests we sent got through.
	got := collectViewData(t)
	want := map[bool]int{true: 7, false: 3}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("mismatch (-want +got):\n%s", diff)
	}
}
// TestQuotaRecordOnly is like TestQuota, but in RecordOnly mode. No request
// is actually blocked, yet the QuotaResultCount metric still records which
// requests would have been blocked.
func TestQuotaRecordOnly(t *testing.T) {
	mw := Quota(config.QuotaSettings{QPS: 1, Burst: 2, MaxEntries: 1, RecordOnly: boolptr(true)})
	npass := 0
	h := func(w http.ResponseWriter, r *http.Request) {
		npass++
	}
	ts := httptest.NewServer(mw(http.HandlerFunc(h)))
	defer ts.Close()
	c := ts.Client()

	if err := view.Register(QuotaResultCount); err != nil {
		t.Fatalf("view.Register: %v", err)
	}
	defer view.Unregister(QuotaResultCount)

	const nreq = 100
	for i := 0; i < nreq; i++ {
		req, err := http.NewRequest("GET", ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Add("X-Forwarded-For", "1.2.3.4, and more")
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
		// In record-only mode the client must never see a rejection.
		if got := res.StatusCode; got != http.StatusOK {
			t.Errorf("#%d: got status %d, want %d", i, got, http.StatusOK)
		}
	}
	if npass != nreq {
		t.Errorf("%d passed, want %d", npass, nreq)
	}

	// Record as if blocking occurred. All requests share one IP; assuming the
	// loop finishes well within a second (no token refill), only the burst of
	// 2 is within quota and the remaining nreq-2 are counted as blocked.
	got := collectViewData(t)
	want := map[bool]int{true: nreq - 2, false: 2}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("mismatch (-want +got):\n%s", diff)
	}
}
// TestQuotaBadKey verifies that requests whose X-Forwarded-For header does
// not contain a valid IP address are never blocked, even when blocking is
// enabled, and that they are all recorded as not blocked.
func TestQuotaBadKey(t *testing.T) {
	// RecordOnly must be false here. In record-only mode nothing is ever
	// blocked, so the npass check below would pass vacuously.
	mw := Quota(config.QuotaSettings{QPS: 1, Burst: 2, MaxEntries: 1, RecordOnly: boolptr(false)})
	npass := 0
	h := func(w http.ResponseWriter, r *http.Request) {
		npass++
	}
	ts := httptest.NewServer(mw(http.HandlerFunc(h)))
	defer ts.Close()
	c := ts.Client()

	if err := view.Register(QuotaResultCount); err != nil {
		t.Fatalf("view.Register: %v", err)
	}
	defer view.Unregister(QuotaResultCount)

	const nreq = 100
	for i := 0; i < nreq; i++ {
		req, err := http.NewRequest("GET", ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Add("X-Forwarded-For", "not.a.valid.ip, and more")
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
		if got := res.StatusCode; got != http.StatusOK {
			t.Errorf("#%d: got status %d, want %d", i, got, http.StatusOK)
		}
	}
	if npass != nreq {
		t.Errorf("%d passed, want %d", npass, nreq)
	}

	got := collectViewData(t)
	want := map[bool]int{false: nreq} // no blocking occurred
	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("mismatch (-want +got):\n%s", diff)
	}
}
// collectViewData returns the current rows of the QuotaResultCount view as a
// map from the row's first tag value, parsed as a bool, to the row's count.
// NOTE(review): this assumes the view's first tag key is the "blocked"
// result; confirm against the QuotaResultCount definition.
// It fails the test instead of panicking if the data cannot be retrieved or
// a row has an unexpected shape.
func collectViewData(t *testing.T) map[bool]int {
	t.Helper()
	rows, err := view.RetrieveData(QuotaResultCount.Name)
	if err != nil {
		t.Fatal(err)
	}
	m := map[bool]int{}
	for _, row := range rows {
		if len(row.Tags) == 0 {
			t.Fatalf("collectViewData: row has no tags: %v", row)
		}
		blocked, err := strconv.ParseBool(row.Tags[0].Value)
		if err != nil {
			t.Fatalf("collectViewData: %v", err)
		}
		data, ok := row.Data.(*view.CountData)
		if !ok {
			t.Fatalf("collectViewData: got row data of type %T, want *view.CountData", row.Data)
		}
		m[blocked] = int(data.Value)
	}
	return m
}
// TestIPKey checks how ipKey turns an X-Forwarded-For header value into a
// rate-limiting key. The first comma-separated entry is used. An IPv4
// address keeps its first three octets and an IPv6 address its first 120
// bits; the rest is zeroed. Unparseable input yields "".
func TestIPKey(t *testing.T) {
	for _, test := range []struct {
		in, want string
	}{
		{"", ""},
		{"1.2.3", ""},
		{"128.197.17.3", "128.197.17.0"},
		{" 128.197.17.3, foo ", "128.197.17.0"},
		{"2001:db8::ff00:42:8329", "2001:db8::ff00:42:8300"},
	} {
		// %q makes empty and whitespace-padded values visible in failures.
		if got := ipKey(test.in); got != test.want {
			t.Errorf("ipKey(%q) = %q, want %q", test.in, got, test.want)
		}
	}
}
// boolptr returns a pointer to a newly allocated bool holding b. It is
// convenient for filling optional *bool fields such as
// config.QuotaSettings.RecordOnly.
func boolptr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}