blob: f2b03a8774d6735fc340e94eca83941e728acb8f [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package middleware
import (
"context"
"fmt"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/go-redis/redis/v8"
)
// TestIPKey runs ipKey over a table of inputs and compares each result
// with the expected key. Per the table below: empty or malformed input
// yields "", an IPv4 address comes back with its last octet zeroed, an
// IPv6 address comes back with its last byte zeroed, and for a
// comma-separated value with surrounding spaces only the first entry counts.
func TestIPKey(t *testing.T) {
	cases := []struct {
		in   string
		want any
	}{
		{in: "", want: ""},
		{in: "1.2.3", want: ""},
		{in: "128.197.17.3", want: "128.197.17.0"},
		{in: " 128.197.17.3, foo ", want: "128.197.17.0"},
		{in: "2001:db8::ff00:42:8329", want: "2001:db8::ff00:42:8300"},
	}

	for _, tc := range cases {
		got := ipKey(tc.in)
		if got == tc.want {
			continue
		}
		t.Errorf("%q: got %v, want %v", tc.in, got, tc.want)
	}
}
// TestEnforceQuota drives enforceQuota against an in-process miniredis
// server and checks that, for a limit of qps, the first qps requests for
// an address are allowed and the next one is blocked, that addresses
// differing only in the low-order byte share a quota, and that a
// different address gets its own quota.
//
// The whole scenario is retried up to maxAttempts times; the test fails
// with the reason from the last attempt only if every attempt fails.
func TestEnforceQuota(t *testing.T) {
	// This test is inherently time-dependent, so inherently flaky, especially on CI.
	// So run it a few times before giving up.
	const (
		qps         = 5
		maxAttempts = 10
		retryDelay  = 2 * time.Second
	)

	ctx := context.Background()
	s, err := miniredis.Run()
	if err != nil {
		t.Fatal(err)
	}
	defer s.Close()
	c := redis.NewClient(&redis.Options{Addr: s.Addr()})
	defer c.Close()

	var failReason string
	for attempt := 0; attempt < maxAttempts; attempt++ {
		// Pause only between attempts, never after the last one, so a
		// persistent failure is reported without a pointless extra wait.
		// NOTE(review): presumably the delay lets the previous attempt's
		// quota window lapse — confirm against enforceQuota.
		if attempt > 0 {
			time.Sleep(retryDelay)
		}
		failReason = ""

		// check issues count requests for ip and records the first
		// mismatch between the observed and wanted "allowed" outcome.
		// Once a failure is recorded, later checks in this attempt are
		// skipped.
		check := func(count int, ip string, want bool) {
			if failReason != "" {
				return
			}
			for i := 0; i < count; i++ {
				// ",x" is appended to the address on every call.
				// NOTE(review): the meaning of the trailing byte slice
				// argument is not visible here — confirm against
				// enforceQuota.
				blocked, reason := enforceQuota(ctx, c, qps, ip+",x", []byte{1, 2, 3, 4})
				got := !blocked
				if got != want {
					failReason = fmt.Sprintf("%d: got %t, want %t (reason=%q)", i, got, want, reason)
					break
				}
			}
		}

		check(qps, "1.2.3.4", true) // first qps requests are allowed
		check(1, "1.2.3.4", false)  // anything after that fails
		check(1, "1.2.3.5", false)  // low-order byte doesn't matter
		check(qps, "1.2.4.1", true) // other IP is allowed
		check(1, "1.2.4.9", false)  // other IP blocked after qps requests

		if failReason == "" {
			return
		}
	}
	t.Error(failReason)
}