blob: 08559921ca1a25699d9bbd3d576d9cee6bdbca16 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package middleware
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/go-redis/redis/v7"
"go.opencensus.io/plugin/ochttp"
"go.opencensus.io/stats"
"go.opencensus.io/stats/view"
"go.opencensus.io/tag"
"golang.org/x/pkgsite/internal/config"
"golang.org/x/pkgsite/internal/log"
)
var (
	keyCacheHit       = tag.MustNewKey("cache.hit")
	keyCacheName      = tag.MustNewKey("cache.name")
	keyCacheOperation = tag.MustNewKey("cache.operation")

	cacheResults = stats.Int64(
		"go-discovery/cache/result_count",
		"The result of a cache request.",
		stats.UnitDimensionless,
	)
	cacheLatency = stats.Float64(
		"go-discovery/cache/result_latency",
		"Cache serving latency.",
		stats.UnitMilliseconds,
	)
	cacheErrors = stats.Int64(
		"go-discovery/cache/errors",
		"Errors retrieving from cache.",
		stats.UnitDimensionless,
	)

	// CacheResultCount counts cache results, tagged by cache name and
	// whether the lookup was a hit.
	CacheResultCount = &view.View{
		Name:        "go-discovery/cache/result_count",
		Measure:     cacheResults,
		Aggregation: view.Count(),
		Description: "cache results, by cache name and whether it was a hit",
		TagKeys:     []tag.Key{keyCacheName, keyCacheHit},
	}
	// CacheLatency is the distribution of cache lookup latency, tagged by
	// cache name and whether the lookup was a hit.
	CacheLatency = &view.View{
		Name:        "go-discovery/cache/result_latency",
		Measure:     cacheLatency,
		Description: "cache result latency, by cache name and whether it was a hit",
		Aggregation: ochttp.DefaultLatencyDistribution,
		TagKeys:     []tag.Key{keyCacheName, keyCacheHit},
	}
	// CacheErrorCount counts cache errors, tagged by cache name and the
	// operation that failed.
	CacheErrorCount = &view.View{
		Name:        "go-discovery/cache/errors",
		Measure:     cacheErrors,
		Aggregation: view.Count(),
		Description: "cache errors, by cache name and operation",
		TagKeys:     []tag.Key{keyCacheName, keyCacheOperation},
	}

	// To avoid test flakiness, when testMode is true, cache writes are
	// synchronous.
	testMode = false
)
func recordCacheResult(ctx context.Context, name string, hit bool, latency time.Duration) {
ms := float64(latency) / float64(time.Millisecond)
stats.RecordWithTags(ctx, []tag.Mutator{
tag.Upsert(keyCacheName, name),
tag.Upsert(keyCacheHit, strconv.FormatBool(hit)),
}, cacheResults.M(1), cacheLatency.M(ms))
}
func recordCacheError(ctx context.Context, name, operation string) {
stats.RecordWithTags(ctx, []tag.Mutator{
tag.Upsert(keyCacheName, name),
tag.Upsert(keyCacheOperation, operation),
}, cacheErrors.M(1))
}
// cache is the http.Handler installed by the Cache middleware. It answers
// requests from redis when possible and otherwise defers to delegate.
type cache struct {
	// name labels this cache in metrics.
	name string
	// authValues are the bypass-header values that skip the cache.
	authValues []string
	// client is the redis connection used for cache reads and writes.
	client *redis.Client
	// delegate is the wrapped handler that produces uncached responses.
	delegate http.Handler
	// expirer chooses the TTL for each response written to the cache.
	expirer Expirer
}
// An Expirer computes the TTL that should be used when caching the
// response to a request.
type Expirer func(*http.Request) time.Duration

// TTL returns an Expirer that ignores the request and always yields ttl,
// so every cached page expires after the same duration.
func TTL(ttl time.Duration) Expirer {
	constant := func(*http.Request) time.Duration { return ttl }
	return constant
}
// Cache returns a Middleware that caches responses in redis.
//
// name identifies the cache in metrics only. client is the redis client used
// for cache reads and writes. expirer maps each request to the TTL used when
// its response is cached. authValues lists the values of the
// config.BypassCacheAuthHeader request header that make a request bypass the
// cache and go straight to the wrapped handler.
func Cache(name string, client *redis.Client, expirer Expirer, authValues []string) Middleware {
	wrap := func(h http.Handler) http.Handler {
		c := &cache{
			name:       name,
			client:     client,
			expirer:    expirer,
			authValues: authValues,
			delegate:   h,
		}
		return c
	}
	return wrap
}
// ServeHTTP implements http.Handler.
//
// Bypass: a request whose config.BypassCacheAuthHeader header is non-empty
// and equals one of c.authValues goes straight to the delegate.
//
// Lookup: otherwise the request URL is the cache key, and a hit copies the
// cached (decompressed) body to w.
//
// Miss: the delegate serves the request through a cacheRecorder. The recorded
// body is stored with the TTL from c.expirer, but only if recording succeeded
// and the status was 200, either explicit or implicit.
func (c *cache) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Check auth header to see if request should bypass cache. An empty
	// header value (including an absent header) never matches. Otherwise an
	// empty string in authValues, e.g. from an unset config value, would make
	// every request that lacks the header bypass the cache.
	if authVal := r.Header.Get(config.BypassCacheAuthHeader); authVal != "" {
		for _, wantVal := range c.authValues {
			if authVal == wantVal {
				c.delegate.ServeHTTP(w, r)
				return
			}
		}
	}

	ctx := r.Context()
	key := r.URL.String()
	start := time.Now()
	reader, hit := c.get(ctx, key)
	recordCacheResult(ctx, c.name, hit, time.Since(start))
	if hit {
		if _, err := io.Copy(w, reader); err != nil {
			log.Errorf(ctx, "error copying zip bytes: %v", err)
		}
		return
	}

	rec := newRecorder(w)
	c.delegate.ServeHTTP(rec, r)
	// A statusCode of 0 means WriteHeader was never called, which net/http
	// treats as an implicit 200 OK.
	if rec.bufErr != nil || (rec.statusCode != 0 && rec.statusCode != http.StatusOK) {
		return
	}
	ttl := c.expirer(r)
	if testMode {
		c.put(ctx, key, rec, ttl)
		return
	}
	// Write to the cache asynchronously so the response is not held up by
	// redis; testMode makes this synchronous to keep tests deterministic.
	go c.put(ctx, key, rec, ttl)
}
// get looks key up in redis and, on success, returns a reader over the
// gunzipped cached body and true.
//
// It returns (nil, false) in three cases:
//   - a plain miss (redis.Nil);
//   - a redis error, which is logged and counted as a "GET" error;
//   - a payload that is not valid gzip, which is counted as an "UNZIP" error.
func (c *cache) get(ctx context.Context, key string) (io.Reader, bool) {
	// Keep redis lookups short so that an unavailable redis quickly degrades
	// to uncached serving instead of slowing every request down.
	getCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
	defer cancel()

	val, err := c.client.WithContext(getCtx).Get(key).Bytes()
	switch {
	case err == redis.Nil:
		return nil, false
	case err != nil:
		// getCtx is done either because the 100ms deadline passed or because
		// ctx itself was canceled.
		if getCtx.Err() != nil {
			log.Infof(ctx, "cache get(%q): context timed out", key)
		} else {
			log.Infof(ctx, "cache get(%q): %v", key, err)
		}
		recordCacheError(ctx, c.name, "GET")
		return nil, false
	}

	zr, zerr := gzip.NewReader(bytes.NewReader(val))
	if zerr != nil {
		log.Errorf(ctx, "cache: gzip.NewReader: %v", zerr)
		recordCacheError(ctx, c.name, "UNZIP")
		return nil, false
	}
	return zr, true
}
// put finalizes the gzip stream in rec and stores the compressed body in
// redis under key with the given ttl.
//
// Errors are logged rather than returned. A failed redis write is also
// counted as a "SET" cache error.
func (c *cache) put(ctx context.Context, key string, rec *cacheRecorder, ttl time.Duration) {
	// Closing the gzip writer flushes the trailer, so rec.buf holds a
	// complete gzip payload afterwards.
	if err := rec.zipWriter.Close(); err != nil {
		log.Errorf(ctx, "cache: error closing zip for %q: %v", key, err)
		return
	}
	payload := rec.buf.Bytes()
	log.Infof(ctx, "caching response of length %d for %s", len(payload), key)

	// The write uses its own context, not ctx. net/http cancels a request's
	// context once ServeHTTP returns, and this may run after that.
	setCtx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	if err := c.client.WithContext(setCtx).Set(key, payload, ttl).Err(); err != nil {
		recordCacheError(ctx, c.name, "SET")
		log.Warningf(ctx, "cache set %q: %v", key, err)
	}
}
// newRecorder returns a cacheRecorder that forwards everything to w. It also
// gzip-compresses a copy of the response body into an in-memory buffer.
func newRecorder(w http.ResponseWriter) *cacheRecorder {
	buf := &bytes.Buffer{}
	zw := gzip.NewWriter(buf)
	return &cacheRecorder{ResponseWriter: w, buf: buf, zipWriter: zw}
}

// cacheRecorder is an http.ResponseWriter that collects http bytes for later
// writing to the cache. Along the way it keeps the first recording error and
// the largest HTTP status code written. The caller uses these to decide
// whether the recorded body is cacheable.
type cacheRecorder struct {
	// ResponseWriter receives the real response.
	http.ResponseWriter
	// statusCode is the largest code passed to WriteHeader, or 0 if
	// WriteHeader was never called.
	statusCode int
	// bufErr is the first error seen while recording. Once it is set, nothing
	// more is written to buf.
	bufErr error
	// buf holds the gzip-compressed copy of the body.
	buf *bytes.Buffer
	// zipWriter compresses into buf. It must be closed before buf holds a
	// complete gzip stream.
	zipWriter *gzip.Writer
}

// Write writes b to the underlying ResponseWriter. If no recording error has
// occurred yet, it also writes b to the gzip buffer. The first recording
// error is kept in r.bufErr, and later errors never overwrite it. The return
// values are always those of the underlying ResponseWriter.
func (r *cacheRecorder) Write(b []byte) (int, error) {
	n, err := r.ResponseWriter.Write(b)
	// Only try writing to the buffer if we haven't yet encountered an error.
	if r.bufErr != nil {
		return n, err
	}
	if err != nil {
		r.bufErr = fmt.Errorf("ResponseWriter.Write failed: %w", err)
		return n, err
	}
	zn, zerr := r.zipWriter.Write(b)
	switch {
	case zerr != nil:
		// Keep the real cause. A failed write usually also reports a short
		// count, and the generic mismatch error below would otherwise mask it.
		r.bufErr = zerr
	case zn != n:
		r.bufErr = fmt.Errorf("wrote %d to zip, but wanted %d", zn, n)
	}
	return n, err
}

// WriteHeader records statusCode if it is larger than any seen so far, then
// forwards it to the underlying ResponseWriter.
func (r *cacheRecorder) WriteHeader(statusCode int) {
	if statusCode > r.statusCode {
		// Defensively take the largest status code that's written, so if any
		// middleware thinks the response is not OK, we will capture this.
		r.statusCode = statusCode
	}
	r.ResponseWriter.WriteHeader(statusCode)
}