| // Copyright 2020 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package middleware |
| |
| import ( |
| "context" |
| "encoding/json" |
| "hash" |
| "hash/fnv" |
| "net/http" |
| "time" |
| ) |
| |
| // statsKey is the type of the context key for stats. |
| type statsKey struct{} |
| |
| // Stats returns a Middleware that, instead of serving the page, |
| // serves statistics about the page. |
| func Stats() Middleware { |
| return func(h http.Handler) http.Handler { |
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| sw := newStatsResponseWriter() |
| ctx := context.WithValue(r.Context(), statsKey{}, sw.stats.Other) |
| h.ServeHTTP(sw, r.WithContext(ctx)) |
| sw.WriteStats(ctx, w) |
| }) |
| } |
| } |
| |
| // SetStat sets a stat named key in the current context. If key already has a |
| // value, the old and new value are both stored in a slice. |
| func SetStat(ctx context.Context, key string, value interface{}) { |
| x := ctx.Value(statsKey{}) |
| if x == nil { |
| return |
| } |
| m := x.(map[string]interface{}) |
| v, ok := m[key] |
| if !ok { |
| m[key] = value |
| } else if s, ok := v.([]interface{}); ok { |
| m[key] = append(s, value) |
| } else { |
| m[key] = []interface{}{v, value} |
| } |
| } |
| |
| // ElapsedStat records as a stat the elapsed time for a |
| // function execution. Invoke like so: |
| // defer ElapsedStat(ctx, "FunctionName")() |
| // The resulting stat will be called "FunctionName ms" and will |
| // be the wall-clock execution time of the function in milliseconds. |
| func ElapsedStat(ctx context.Context, name string) func() { |
| start := time.Now() |
| return func() { |
| SetStat(ctx, name+" ms", time.Since(start).Milliseconds()) |
| } |
| } |
| |
| // statsResponseWriter is an http.ResponseWriter that tracks statistics about |
| // the page being written. |
| type statsResponseWriter struct { |
| header http.Header // required for a ResponseWriter; ignored |
| start time.Time // start time of request |
| hasher hash.Hash64 |
| stats PageStats |
| } |
| |
| type PageStats struct { |
| MillisToFirstByte int64 |
| MillisToLastByte int64 |
| Hash uint64 // hash of page contents |
| Size int // total size of data written |
| StatusCode int // HTTP status |
| Other map[string]interface{} |
| } |
| |
| func newStatsResponseWriter() *statsResponseWriter { |
| return &statsResponseWriter{ |
| header: http.Header{}, |
| start: time.Now(), |
| hasher: fnv.New64a(), |
| stats: PageStats{Other: map[string]interface{}{}}, |
| } |
| } |
| |
| // Header implements http.ResponseWriter.Header. |
| func (s *statsResponseWriter) Header() http.Header { return s.header } |
| |
| // WriteHeader implements http.ResponseWriter.WriteHeader. |
| func (s *statsResponseWriter) WriteHeader(statusCode int) { |
| s.stats.StatusCode = statusCode |
| } |
| |
| // Write implements http.ResponseWriter.Write by |
| // tracking statistics about the data being written. |
| func (s *statsResponseWriter) Write(data []byte) (int, error) { |
| if s.stats.Size == 0 { |
| s.stats.MillisToFirstByte = time.Since(s.start).Milliseconds() |
| } |
| if s.stats.StatusCode == 0 { |
| s.WriteHeader(http.StatusOK) |
| } |
| s.stats.Size += len(data) |
| s.hasher.Write(data) |
| return len(data), nil |
| } |
| |
| // WriteStats writes the statistics to w. |
| func (s *statsResponseWriter) WriteStats(ctx context.Context, w http.ResponseWriter) { |
| s.stats.MillisToLastByte = time.Since(s.start).Milliseconds() |
| s.stats.Hash = s.hasher.Sum64() |
| data, err := json.Marshal(s.stats) |
| if err != nil { |
| http.Error(w, err.Error(), http.StatusInternalServerError) |
| } else { |
| _, _ = w.Write(data) |
| } |
| } |