// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package middleware

import (
	"context"
	"fmt"
	"hash/fnv"
	"net/http"
	"time"

	"cloud.google.com/go/errorreporting"
	"golang.org/x/pkgsite/internal"
	"golang.org/x/pkgsite/internal/derrors"
	"golang.org/x/pkgsite/internal/experiment"
	"golang.org/x/pkgsite/internal/log"
	"golang.org/x/pkgsite/internal/poller"
)

const (
	// experimentQueryParamKey names the URL query parameter whose values are
	// added to a request's experiments unconditionally, regardless of rollout.
	experimentQueryParamKey = "experiment"
)

// A Reporter sends errors to the Error-Reporting service.
//
// NewExperimenter uses a Reporter, when one is supplied, to report errors
// encountered while loading experiments.
type Reporter interface {
	// Report sends entry to the Error-Reporting service.
	Report(entry errorreporting.Entry)
}

// ExperimentGetter is the signature of a function that gets experiments.
// NewExperimenter calls it once to load the initial snapshot and then
// regularly in the background to refresh it.
type ExperimentGetter func(ctx context.Context) ([]*internal.Experiment, error)

// An Experimenter contains information about active experiments from the
// experiment source. Construct one with NewExperimenter, which creates and
// starts the underlying poller.
type Experimenter struct {
	// p holds the most recently loaded []*internal.Experiment snapshot.
	p *poller.Poller
}

// NewExperimenter returns an Experimenter for use in the middleware. The
// experimenter regularly polls for updates to the snapshot in the background.
//
// The initial set of experiments is loaded synchronously by calling getter;
// if that fails, NewExperimenter returns the error and no Experimenter.
// After that, the poller is started on ctx with interval pollEvery. Errors
// handed to the poller's error callback (presumably failed background
// refreshes — see package poller) are logged and, when rep is non-nil,
// reported to the Error-Reporting service.
func NewExperimenter(ctx context.Context, pollEvery time.Duration, getter ExperimentGetter, rep Reporter) (_ *Experimenter, err error) {
	defer derrors.Wrap(&err, "middleware.NewExperimenter")

	// If we can't load the initial state, then fail.
	initial, err := getter(ctx)
	if err != nil {
		return nil, err
	}

	// fetch adapts getter to the poller's untyped getter signature.
	fetch := func(ctx context.Context) (interface{}, error) {
		return getter(ctx)
	}
	// onError logs and reports the error.
	onError := func(err error) {
		log.Error(ctx, err)
		if rep != nil {
			rep.Report(errorreporting.Entry{
				// %w keeps the underlying error inspectable via errors.Is/As.
				Error: fmt.Errorf("loading experiments: %w", err),
			})
		}
	}

	e := &Experimenter{p: poller.New(initial, fetch, onError)}
	e.p.Start(ctx, pollEvery)
	return e, nil
}

// Experiment returns a new Middleware that sets active experiments for each
// incoming request.
//
// The wrapped handler is invoked with a copy of the request whose context
// carries the experiments that e selects for that request.
func Experiment(e *Experimenter) Middleware {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, e.setExperimentsForRequest(r))
		}
		return http.HandlerFunc(serve)
	}
}

// Experiments returns the experiments currently in use.
//
// The result is a newly allocated slice of newly allocated copies of the
// current snapshot, so callers may modify it without affecting the
// Experimenter's state.
func (e *Experimenter) Experiments() []*internal.Experiment {
	// No lock is needed: e.p.current is replaced rather than modified in
	// place, so the snapshot read here does not change underneath us.
	current := e.p.Current().([]*internal.Experiment)
	copies := make([]*internal.Experiment, len(current))
	for i := range current {
		// Assume internal.Experiment has no pointers to mutable data, so a
		// shallow copy is enough to isolate the caller.
		dup := *current[i]
		copies[i] = &dup
	}
	return copies
}

// setExperimentsForRequest returns a shallow copy of r whose context carries
// the experiments for this request. These are every experiment in the current
// snapshot that shouldSetExperiment selects, in snapshot order, followed by
// any names given in the experiment query parameter. Rollout-based enrollment
// is stable for a given IP address.
func (e *Experimenter) setExperimentsForRequest(r *http.Request) *http.Request {
	var names []string
	for _, candidate := range e.p.Current().([]*internal.Experiment) {
		if !shouldSetExperiment(r, candidate) {
			continue
		}
		names = append(names, candidate.Name)
	}
	// Experiments requested explicitly in the URL are always added.
	requested := r.URL.Query()[experimentQueryParamKey]
	names = append(names, requested...)
	ctx := experiment.NewContext(r.Context(), names...)
	return r.WithContext(ctx)
}

// shouldSetExperiment reports whether a given request should be enrolled in
// the experiment, based on the request's IP, e.Name, and e.Rollout.
//
// An experiment with a Rollout of 0 enrolls no requests, and one with a
// Rollout of 100 or more enrolls every request, including requests with no
// IP. For rollouts in between, requests whose IP key (derived by ipKey from
// the X-Forwarded-For header) is empty are never enrolled. Otherwise, the
// request is enrolled when the FNV-1a hash of "<ip> <name>", taken modulo
// 100, is less than e.Rollout. Because the hash is deterministic, all
// requests from the same IP are enrolled in the same set of experiments.
func shouldSetExperiment(r *http.Request, e *internal.Experiment) bool {
	if e.Rollout == 0 {
		return false
	}
	if e.Rollout >= 100 {
		return true
	}
	ip := ipKey(r.Header.Get("X-Forwarded-For"))
	if ip == "" {
		return false
	}
	// Including the experiment name makes each experiment bucket IPs
	// independently. These bytes are identical to what
	// fmt.Fprintf(h, "%s %s", ip, e.Name) would write, so bucketing is
	// unchanged, but this avoids fmt on a per-request hot path.
	// hash.Hash.Write never returns an error.
	h := fnv.New32a()
	h.Write([]byte(ip + " " + e.Name))
	return uint(h.Sum32())%100 < e.Rollout
}
