blob: 70bde537575568a7a28a33b4d5920616ad6d2fdf [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package middleware
import (
"context"
"crypto/hmac"
"crypto/sha256"
"errors"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/go-redis/redis/v8"
rrate "github.com/go-redis/redis_rate/v9"
"go.opencensus.io/stats"
"go.opencensus.io/stats/view"
"go.opencensus.io/tag"
"golang.org/x/pkgsite/internal/config"
"golang.org/x/pkgsite/internal/log"
)
var (
keyQuotaBlocked = tag.MustNewKey("quota.blocked")
quotaResults = stats.Int64(
"go-discovery/quota_result_count",
"The result of a quota check.",
stats.UnitDimensionless,
)
// QuotaResultCount is a counter of quota results, by whether the request was blocked or not.
QuotaResultCount = &view.View{
Name: "go-discovery/quota/result_count",
Measure: quotaResults,
Aggregation: view.Count(),
Description: "quota results, by blocked or allowed",
TagKeys: []tag.Key{keyQuotaBlocked},
}
)
func recordQuotaMetric(ctx context.Context, blocked string) {
stats.RecordWithTags(ctx, []tag.Mutator{
tag.Upsert(keyQuotaBlocked, blocked),
}, quotaResults.M(1))
}
func ipKey(s string) string {
addr, _, _ := strings.Cut(s, ",")
// First field is the originating IP address.
origin := strings.TrimSpace(addr)
ip := net.ParseIP(origin)
if ip == nil {
return ""
}
// Zero out last byte, to cover ranges of IPv4 addresses. (It's less clear what the
// effect will be on IPv6 addresses: it will certainly cover a range of them, but we
// don't know if that range is likely to be allocated to a single organization.)
ip[len(ip)-1] = 0
return ip.String()
}
// Quota implements a simple IP-based rate limiter. Each set of incoming IP
// addresses with the same low-order byte gets settings.QPS requests per second.
//
// Information is kept in a redis instance.
//
// If a request is disallowed, a 429 (TooManyRequests) will be served.
func Quota(settings config.QuotaSettings, client *redis.Client) Middleware {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !settings.Enable {
recordQuotaMetric(ctx, "disabled")
h.ServeHTTP(w, r)
return
}
authVal := r.Header.Get(config.BypassQuotaAuthHeader)
for _, wantVal := range settings.AuthValues {
if authVal == wantVal {
recordQuotaMetric(ctx, "bypassed")
log.Infof(ctx, "Quota: accepting %q", authVal)
h.ServeHTTP(w, r)
return
}
}
header := r.Header.Get("X-Godoc-Forwarded-For")
if header == "" {
header = r.Header.Get("X-Forwarded-For")
}
blocked, reason := enforceQuota(ctx, client, settings.QPS, header, settings.HMACKey)
recordQuotaMetric(ctx, reason)
if blocked && settings.RecordOnly != nil && !*settings.RecordOnly {
const tmr = http.StatusTooManyRequests
http.Error(w, http.StatusText(tmr), tmr)
return
}
h.ServeHTTP(w, r)
})
}
}
func enforceQuota(ctx context.Context, client *redis.Client, qps int, header string, hmacKey []byte) (blocked bool, reason string) {
// Fail open if header is missing or can't be parsed.
if header == "" {
return false, "no header"
}
key := ipKey(header)
if key == "" {
return false, "bad header"
}
mac := hmac.New(sha256.New, hmacKey)
io.WriteString(mac, key)
rrateKey := string(mac.Sum(nil))
res, err := rrate.NewLimiter(client.WithTimeout(15*time.Millisecond)).Allow(ctx, rrateKey, rrate.PerSecond(qps))
if err != nil {
var nerr *net.OpError
if errors.Is(err, context.DeadlineExceeded) || (errors.As(err, &nerr) && nerr.Timeout()) {
log.Warningf(ctx, "quota: redis limiter: %v", err)
return false, "timeout"
}
log.Errorf(ctx, "quota: redis limiter: %v", err)
return false, "error"
}
if res.Allowed > 0 {
return false, "allowed"
}
return true, "blocked"
}