// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package generator

import (
	"fmt"
	"math/rand/v2"
	"os"
	"runtime"
	"strings"
)

// Control flags for a wraprand, combined with bitwise OR and passed as
// the ctl argument of NewWrapRand. Bit 0 is unused: the flag values
// start at 1<<1.
const (
	// RandCtlNochecks is the zero value; with ctl == 0, Check does nothing.
	RandCtlNochecks = 0
	// RandCtlChecks requests consistency checking in Check. (Check runs
	// whenever ctl is non-zero, so any flag turns checking on.)
	RandCtlChecks = 1 << 1
	// RandCtlCapture records every generated value together with a stack
	// trace, and makes a failing Check write the records to files.
	RandCtlCapture = 1 << 2
	// RandCtlPanic makes a failing Check panic after reporting.
	RandCtlPanic = 1 << 3
)

// NewWrapRand returns a wraprand configured with the given control flags
// (a bitwise OR of the RandCtl* constants). Its underlying generator is
// a PCG source whose two seed words are 0 and uint64(seed), so equal
// seeds give equal random sequences.
func NewWrapRand(seed int64, ctl int) *wraprand {
	src := rand.NewPCG(0, uint64(seed))
	w := new(wraprand)
	w.seed = seed
	w.ctl = ctl
	w.rand = rand.New(src)
	return w
}

type wraprand struct {
	f32calls  int
	f64calls  int
	intncalls int
	seed      int64
	tag       string
	calls     []string
	ctl       int
	rand      *rand.Rand
}

// captureCall appends one record to w.calls. The record starts with a
// "tag: val" line and is followed by one "function file:line" line per
// stack frame, beginning with captureCall's own frame. At most 10 frames
// are collected, and the listing stops early at the first frame whose
// file name contains "testing.". It panics if no frames can be obtained.
func (w *wraprand) captureCall(tag string, val string) {
	const maxDepth = 10
	pcs := make([]uintptr, maxDepth)
	depth := runtime.Callers(1, pcs)
	if depth == 0 {
		panic("why?")
	}

	var rec strings.Builder
	fmt.Fprintf(&rec, "%s: %s\n", tag, val)

	// Only the first depth entries of pcs were filled in by Callers.
	iter := runtime.CallersFrames(pcs[:depth])
	for more := true; more; {
		var fr runtime.Frame
		fr, more = iter.Next()
		if strings.Contains(fr.File, "testing.") {
			break
		}
		fmt.Fprintf(&rec, "%s %s:%d\n", fr.Function, fr.File, fr.Line)
	}
	w.calls = append(w.calls, rec.String())
}

// Intn returns the next value from the underlying generator's Int64N(n)
// and bumps the Intn call counter. When RandCtlCapture is set, the value
// is also recorded along with the current call stack.
func (w *wraprand) Intn(n int64) int64 {
	w.intncalls++
	v := w.rand.Int64N(n)
	if w.ctl&RandCtlCapture == 0 {
		return v
	}
	w.captureCall("Intn", fmt.Sprintf("%d", v))
	return v
}

// Float32 returns the next float32 from the underlying generator and
// bumps the Float32 call counter. When RandCtlCapture is set, the value
// is also recorded along with the current call stack.
func (w *wraprand) Float32() float32 {
	w.f32calls++
	v := w.rand.Float32()
	if capturing := w.ctl&RandCtlCapture != 0; capturing {
		w.captureCall("Float32", fmt.Sprintf("%f", v))
	}
	return v
}

// NormFloat64 returns the next NormFloat64 value from the underlying
// generator and bumps the float64 call counter. When RandCtlCapture is
// set, the value is also recorded along with the current call stack.
func (w *wraprand) NormFloat64() float64 {
	w.f64calls++
	v := w.rand.NormFloat64()
	if capturing := w.ctl&RandCtlCapture != 0; capturing {
		w.captureCall("NormFloat64", fmt.Sprintf("%f", v))
	}
	return v
}

// emitCalls writes all records in w.calls, in order and with nothing
// between them, to the file named fn. The file is created with mode
// 0o666 (before umask) if it does not exist and truncated if it does.
//
// It panics if the file cannot be opened, written, or closed. Write and
// close failures are reported rather than ignored so that a truncated
// dump is never mistaken for a complete one.
func (w *wraprand) emitCalls(fn string) {
	// Records appended by captureCall and Checkpoint already end in a
	// newline, so plain concatenation yields the intended file contents.
	data := strings.Join(w.calls, "")
	if err := os.WriteFile(fn, []byte(data), 0o666); err != nil {
		panic(err)
	}
}

// Equal reports whether w and w2 have seen the same number of Float32,
// NormFloat64 and Intn calls. Only the three call counters are compared.
func (w *wraprand) Equal(w2 *wraprand) bool {
	if w.f32calls != w2.f32calls {
		return false
	}
	if w.f64calls != w2.f64calls {
		return false
	}
	return w.intncalls == w2.intncalls
}

// Check compares the call counters of w and w2 and reports a mismatch.
// It does nothing when w.ctl is zero or when the counters agree.
//
// On a mismatch it prints both sets of counters to standard error, each
// labelled with the wrapper's tag (or "w" / "w2" when the tag is empty).
// If w.ctl has RandCtlCapture set, the captured calls of both wrappers
// are written to /tmp/<label>.txt. If w.ctl has RandCtlPanic set, Check
// then panics. Only w's flags are consulted; w2.ctl is ignored.
func (w *wraprand) Check(w2 *wraprand) {
	if w.ctl == 0 || w.Equal(w2) {
		return
	}

	// Pick a display label for each side, falling back to a fixed name.
	name, name2 := "w", "w2"
	if w.tag != "" {
		name = w.tag
	}
	if w2.tag != "" {
		name2 = w2.tag
	}

	fmt.Fprintf(os.Stderr, "wraprand consistency check failed:\n")
	sides := []struct {
		label string
		r     *wraprand
	}{
		{name, w},
		{name2, w2},
	}
	for _, s := range sides {
		fmt.Fprintf(os.Stderr, " %s: {f32:%d f64:%d i:%d}\n", s.label,
			s.r.f32calls, s.r.f64calls, s.r.intncalls)
	}

	if w.ctl&RandCtlCapture != 0 {
		path := fmt.Sprintf("/tmp/%s.txt", name)
		path2 := fmt.Sprintf("/tmp/%s.txt", name2)
		w.emitCalls(path)
		w2.emitCalls(path2)
		fmt.Fprintf(os.Stderr, "=-= emitted calls to %s, %s\n", path, path2)
	}

	if w.ctl&RandCtlPanic != 0 {
		panic("bad")
	}
}

// Checkpoint appends a marker containing tag, fenced by "=-=" lines, to
// the captured call list. It is a no-op unless RandCtlCapture is set.
func (w *wraprand) Checkpoint(tag string) {
	if w.ctl&RandCtlCapture == 0 {
		return
	}
	marker := "=-=\n" + tag + "\n=-=\n"
	w.calls = append(w.calls, marker)
}
