internal/frontend,middleware: add a handler for page statistics

If the GO_DISCOVERY_SERVE_STATS env var is true, add a /detail-stats
handler that handles any detail page, but instead of serving the page
it serves JSON with page statistics, like timings and size.

Implemented as middleware that runs the actual detail handler,
bypassing the cache.

Will be useful for performance work, like forthcoming doc rendering
changes.

Change-Id: I738f7413a33caaabc4a116235e68ac3932c18f96
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/257245
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/cmd/frontend/main.go b/cmd/frontend/main.go
index 2769d14..8e716dd 100644
--- a/cmd/frontend/main.go
+++ b/cmd/frontend/main.go
@@ -134,6 +134,7 @@
 		DevMode:              *devMode,
 		AppVersionLabel:      cfg.AppVersionLabel(),
 		GoogleTagManagerID:   cfg.GoogleTagManagerID,
+		ServeStats:           cfg.ServeStats,
 	})
 	if err != nil {
 		log.Fatalf(ctx, "frontend.NewServer: %v", err)
diff --git a/internal/config/config.go b/internal/config/config.go
index 061d153..2eee8f6 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -176,6 +176,10 @@
 	// DynamicConfigLocation is the location (either a file or gs://bucket/object) for
 	// dynamic configuration.
 	DynamicConfigLocation string
+
+	// ServeStats determines whether the server has an endpoint that serves statistics for
+	// benchmarking or other purposes.
+	ServeStats bool
 }
 
 // AppVersionLabel returns the version label for the current instance.  This is
@@ -402,7 +406,8 @@
 			MaxTimeout:       time.Duration(GetEnvInt("GO_DISCOVERY_TEEPROXY_MAX_TIMEOUT_SECONDS", 240)) * time.Second,
 			SuccsToGreen:     GetEnvInt("GO_DISCOVERY_TEEPROXY_SUCCS_TO_GREEN", 20),
 		},
-		LogLevel: os.Getenv("GO_DISCOVERY_LOG_LEVEL"),
+		LogLevel:   os.Getenv("GO_DISCOVERY_LOG_LEVEL"),
+		ServeStats: os.Getenv("GO_DISCOVERY_SERVE_STATS") == "true",
 	}
 	bucket := os.Getenv("GO_DISCOVERY_CONFIG_BUCKET")
 	object := os.Getenv("GO_DISCOVERY_CONFIG_DYNAMIC")
diff --git a/internal/frontend/server.go b/internal/frontend/server.go
index f75d9a9..519147d 100644
--- a/internal/frontend/server.go
+++ b/internal/frontend/server.go
@@ -43,6 +43,7 @@
 	errorPage            []byte
 	appVersionLabel      string
 	googleTagManagerID   string
+	serveStats           bool
 
 	mu        sync.Mutex // Protects all fields below
 	templates map[string]*template.Template
@@ -61,6 +62,7 @@
 	DevMode              bool
 	AppVersionLabel      string
 	GoogleTagManagerID   string
+	ServeStats           bool
 }
 
 // NewServer creates a new Server for the given database and template directory.
@@ -83,6 +85,7 @@
 		taskIDChangeInterval: scfg.TaskIDChangeInterval,
 		appVersionLabel:      scfg.AppVersionLabel,
 		googleTagManagerID:   scfg.GoogleTagManagerID,
+		serveStats:           scfg.ServeStats,
 	}
 	errorPageBytes, err := s.renderErrorPage(context.Background(), http.StatusInternalServerError, "error.tmpl", nil)
 	if err != nil {
@@ -119,6 +122,10 @@
 	handle("/about", http.RedirectHandler("https://go.dev/about", http.StatusFound))
 	handle("/badge/", http.HandlerFunc(s.badgeHandler))
 	handle("/", detailHandler)
+	if s.serveStats {
+		handle("/detail-stats/",
+			middleware.Stats()(http.StripPrefix("/detail-stats", s.errorHandler(s.serveDetails))))
+	}
 	handle("/autocomplete", http.HandlerFunc(s.handleAutoCompletion))
 	handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
diff --git a/internal/middleware/stats.go b/internal/middleware/stats.go
new file mode 100644
index 0000000..c62cefe
--- /dev/null
+++ b/internal/middleware/stats.go
@@ -0,0 +1,85 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package middleware
+
+import (
+	"context"
+	"encoding/json"
+	"hash"
+	"hash/fnv"
+	"net/http"
+	"time"
+)
+
+// Stats returns a Middleware that, instead of serving the page,
+// serves statistics about the page.
+func Stats() Middleware {
+	return func(h http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			sw := newStatsResponseWriter()
+			h.ServeHTTP(sw, r)
+			sw.WriteStats(r.Context(), w)
+		})
+	}
+}
+
+// statsResponseWriter is an http.ResponseWriter that tracks statistics about
+// the page being written.
+type statsResponseWriter struct {
+	header http.Header // required for a ResponseWriter; ignored
+	start  time.Time   // start time of request
+	hasher hash.Hash64
+	stats  PageStats
+}
+
+type PageStats struct {
+	MillisToFirstByte int64
+	MillisToLastByte  int64
+	Hash              uint64 // hash of page contents
+	Size              int    // total size of data written
+	StatusCode        int    // HTTP status
+}
+
+func newStatsResponseWriter() *statsResponseWriter {
+	return &statsResponseWriter{
+		header: http.Header{},
+		start:  time.Now(),
+		hasher: fnv.New64a(),
+	}
+}
+
+// Header implements http.ResponseWriter.Header.
+func (s *statsResponseWriter) Header() http.Header { return s.header }
+
+// WriteHeader implements http.ResponseWriter.WriteHeader.
+func (s *statsResponseWriter) WriteHeader(statusCode int) {
+	s.stats.StatusCode = statusCode
+}
+
+// Write implements http.ResponseWriter.Write by
+// tracking statistics about the data being written.
+func (s *statsResponseWriter) Write(data []byte) (int, error) {
+	if s.stats.Size == 0 {
+		s.stats.MillisToFirstByte = time.Since(s.start).Milliseconds()
+	}
+	if s.stats.StatusCode == 0 {
+		s.WriteHeader(http.StatusOK)
+	}
+	s.stats.Size += len(data)
+	s.hasher.Write(data)
+	return len(data), nil
+}
+
+// WriteStats writes the statistics to w.
+func (s *statsResponseWriter) WriteStats(ctx context.Context, w http.ResponseWriter) {
+	s.stats.MillisToLastByte = time.Since(s.start).Milliseconds()
+	s.stats.Hash = s.hasher.Sum64()
+	data, err := json.Marshal(s.stats)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+	} else {
+		_, _ = w.Write(data)
+	}
+}
diff --git a/internal/middleware/stats_test.go b/internal/middleware/stats_test.go
new file mode 100644
index 0000000..d677ec8
--- /dev/null
+++ b/internal/middleware/stats_test.go
@@ -0,0 +1,73 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package middleware
+
+import (
+	"encoding/json"
+	"hash/fnv"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+)
+
+func TestStats(t *testing.T) {
+	data := []byte("this is the data we are going to serve")
+	const code = 218
+	ts := httptest.NewServer(Stats()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(code)
+		w.Write(data[:10])
+		time.Sleep(500 * time.Millisecond)
+		w.Write(data[10:])
+	})))
+	defer ts.Close()
+	res, err := ts.Client().Get(ts.URL)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer res.Body.Close()
+	if res.StatusCode != http.StatusOK {
+		t.Fatalf("failed with status %d", res.StatusCode)
+	}
+	gotData, err := ioutil.ReadAll(res.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var got PageStats
+	if err := json.Unmarshal(gotData, &got); err != nil {
+		t.Fatal(err)
+	}
+
+	h := fnv.New64a()
+	h.Write(data)
+	want := PageStats{
+		StatusCode: code,
+		Size:       len(data),
+		Hash:       h.Sum64(),
+	}
+	diff := cmp.Diff(want, got, cmpopts.IgnoreFields(PageStats{}, "MillisToFirstByte", "MillisToLastByte"))
+	if diff != "" {
+		t.Errorf("mismatch (-want, +got):\n%s", diff)
+	}
+	const tolerance = 50 // 50 ms of tolerance for time measurements
+	if g := got.MillisToFirstByte; !within(g, 0, tolerance) {
+		t.Errorf("MillisToFirstByte is %d, wanted 0 - %d", g, tolerance)
+	}
+	if g := got.MillisToLastByte; !within(g, 500, tolerance) {
+		t.Errorf("MillisToLastByte is %d, wanted 500 +/- %d", g, tolerance)
+	}
+}
+
+func within(got, want, tolerance int64) bool {
+	d := got - want
+	if d < 0 {
+		d = -d
+	}
+	return d <= tolerance
+}