// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package middleware

import (
	"context"
	"encoding/json"
	"hash"
	"hash/fnv"
	"net/http"
	"time"
)

// statsKey is the context key under which the Stats middleware stores the
// map that SetStat writes into.
type statsKey struct{}

// Stats returns a Middleware that runs the wrapped handler against a
// recording ResponseWriter and then, in place of the page, responds with a
// JSON summary of what the handler wrote. Values recorded with SetStat
// during the request are included in that summary.
func Stats() Middleware {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			recorder := newStatsResponseWriter()
			// Expose the recorder's "Other" map so SetStat calls made by the
			// handler land in the final report.
			statsCtx := context.WithValue(r.Context(), statsKey{}, recorder.stats.Other)
			next.ServeHTTP(recorder, r.WithContext(statsCtx))
			recorder.WriteStats(statsCtx, w)
		}
		return http.HandlerFunc(serve)
	}
}

// SetStat records value under key in the stats map carried by ctx. It does
// nothing when ctx carries no stats map (i.e. the request is not being served
// by the Stats middleware).
//
// The first value for a key is stored as-is. A second value turns the entry
// into a []interface{} holding the old and new values, and any further values
// are appended to that slice.
//
// NOTE: if the first value stored for a key is itself a []interface{}, later
// values are appended to it rather than nested alongside it.
func SetStat(ctx context.Context, key string, value interface{}) {
	raw := ctx.Value(statsKey{})
	if raw == nil {
		return
	}
	stats := raw.(map[string]interface{})

	prev, exists := stats[key]
	if !exists {
		stats[key] = value
		return
	}
	switch p := prev.(type) {
	case []interface{}:
		stats[key] = append(p, value)
	default:
		stats[key] = []interface{}{p, value}
	}
}

// ElapsedStat starts a wall-clock timer and returns a function that, when
// called, records the time elapsed since ElapsedStat was called. It is meant
// to be deferred at the top of the function being measured:
//
//	defer ElapsedStat(ctx, "FunctionName")()
//
// The value is stored via SetStat under the key "FunctionName ms" as an
// int64 count of whole milliseconds.
func ElapsedStat(ctx context.Context, name string) func() {
	began := time.Now()
	statName := name + " ms"
	return func() {
		elapsed := time.Since(began)
		SetStat(ctx, statName, elapsed.Milliseconds())
	}
}

// statsResponseWriter is an http.ResponseWriter that, instead of sending
// anything to the client, records statistics about the page being written.
type statsResponseWriter struct {
	header http.Header // required for a ResponseWriter; ignored
	start  time.Time   // start time of request
	hasher hash.Hash64 // running FNV-1a hash of the body bytes
	stats  PageStats   // statistics accumulated so far
}

// PageStats is the set of statistics the Stats middleware reports for a
// page. It is marshaled to JSON as the body of the stats response.
type PageStats struct {
	MillisToFirstByte int64                  // ms from request start to the first non-empty Write; 0 if none
	MillisToLastByte  int64                  // ms from request start until the stats were written
	Hash              uint64                 // hash of page contents
	Size              int                    // total size of data written
	StatusCode        int                    // HTTP status
	Other             map[string]interface{} // values recorded with SetStat
}

// newStatsResponseWriter returns a statsResponseWriter whose clock starts now
// and whose Other map is ready for SetStat.
func newStatsResponseWriter() *statsResponseWriter {
	return &statsResponseWriter{
		header: http.Header{},
		start:  time.Now(),
		hasher: fnv.New64a(),
		stats:  PageStats{Other: map[string]interface{}{}},
	}
}

// Header implements http.ResponseWriter.Header. Headers set by the wrapped
// handler are kept here but never sent.
func (s *statsResponseWriter) Header() http.Header { return s.header }

// WriteHeader implements http.ResponseWriter.WriteHeader by recording the
// status code. As with net/http, the status is committed by the first
// WriteHeader or Write; later calls (which net/http ignores as superfluous)
// do not change the recorded value.
// NOTE(review): 1xx informational statuses are not special-cased.
func (s *statsResponseWriter) WriteHeader(statusCode int) {
	if s.stats.StatusCode != 0 {
		return
	}
	s.stats.StatusCode = statusCode
}

// Write implements http.ResponseWriter.Write by tracking statistics about
// the data being written. The data itself is discarded; Write always
// reports success.
func (s *statsResponseWriter) Write(data []byte) (int, error) {
	if s.stats.Size == 0 {
		// Still nothing written: this call (if non-empty) carries the first byte.
		s.stats.MillisToFirstByte = time.Since(s.start).Milliseconds()
	}
	if s.stats.StatusCode == 0 {
		// A Write without a prior WriteHeader implies 200 OK, as in net/http.
		s.WriteHeader(http.StatusOK)
	}
	s.stats.Size += len(data)
	s.hasher.Write(data) // hash.Hash.Write never returns an error
	return len(data), nil
}

// WriteStats finalizes the statistics and writes them to w as JSON. If the
// statistics cannot be marshaled (for example, a SetStat value is not
// JSON-encodable), it responds with a 500 and the marshaling error instead.
// ctx is currently unused.
func (s *statsResponseWriter) WriteStats(ctx context.Context, w http.ResponseWriter) {
	s.stats.MillisToLastByte = time.Since(s.start).Milliseconds()
	s.stats.Hash = s.hasher.Sum64()
	if s.stats.StatusCode == 0 {
		// The handler never wrote a header or body; net/http would have
		// sent 200 OK in that case, so report that rather than 0.
		s.stats.StatusCode = http.StatusOK
	}
	data, err := json.Marshal(s.stats)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	// Best effort: a write failure means the client is gone, and there is
	// no one left to report it to.
	_, _ = w.Write(data)
}
