blob: bba178dc317c1a11fe47b9c881b82a92a0e540e9 [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package generator
import (
	"fmt"
	"math/rand"
	"os"
	"path/filepath"
	"runtime"
	"strings"
)
// Control flags for wraprand, combined into the ctl bitmask given to
// NewWrapRand. Bit 0 is unused: the flags occupy bits 1 through 3.
const (
	// RandCtlNochecks disables all consistency checking in Check.
	RandCtlNochecks = 0

	// RandCtlChecks requests consistency checking; Check runs whenever ctl
	// is non-zero.
	RandCtlChecks = 1 << 1

	// RandCtlCapture records every draw (value plus call stack) so Check can
	// dump the recorded calls to files when a mismatch is found.
	RandCtlCapture = 1 << 2

	// RandCtlPanic makes Check panic after reporting a mismatch.
	RandCtlPanic = 1 << 3
)
// NewWrapRand reseeds the package-level math/rand source with seed and
// returns a wrapper whose checking/capture behavior is selected by ctl, a
// bitmask of RandCtl* flags.
//
// NOTE(review): rand.Seed is deprecated as of Go 1.20. Because the wrapper
// draws from the global source, reseeding here also affects any other code
// that uses the package-level math/rand functions.
func NewWrapRand(seed int64, ctl int) *wraprand {
	w := new(wraprand)
	w.seed, w.ctl = seed, ctl
	rand.Seed(seed)
	return w
}
// wraprand wraps the package-level math/rand functions, counting how many
// values of each kind have been drawn so that two generator runs can be
// compared for consistency by Check.
type wraprand struct {
	// Per-method draw counters (Float32, NormFloat64, Intn); Equal compares them.
	f32calls, f64calls, intncalls int

	seed  int64    // seed passed to NewWrapRand
	tag   string   // optional label used by Check in its report and dump file names
	calls []string // call records appended when RandCtlCapture is set
	ctl   int      // bitmask of RandCtl* flags
}
// captureCall appends one record to w.calls describing a random draw. The
// record is a header line "tag: val" followed by one line per stack frame of
// the form "function file:line". Frames start at captureCall itself and move
// outward. The walk stops early at the first frame whose file name contains
// "testing.", presumably to omit the testing package's own frames. At most 10
// return PCs are collected.
func (w *wraprand) captureCall(tag string, val string) {
	// Build with a strings.Builder rather than repeated string += in the loop,
	// which re-copies the accumulated record for every frame.
	var b strings.Builder
	b.WriteString(tag)
	b.WriteString(": ")
	b.WriteString(val)
	b.WriteString("\n")

	pc := make([]uintptr, 10)
	// Skip 1 so that the first recorded frame is captureCall, not runtime.Callers.
	n := runtime.Callers(1, pc)
	if n == 0 {
		// Should be unreachable: captureCall's own frame is always present.
		panic("why?")
	}
	frames := runtime.CallersFrames(pc[:n]) // pass only valid pcs
	for {
		frame, more := frames.Next()
		if strings.Contains(frame.File, "testing.") {
			break
		}
		fmt.Fprintf(&b, "%s %s:%d\n", frame.Function, frame.File, frame.Line)
		if !more {
			break
		}
	}
	w.calls = append(w.calls, b.String())
}
// Intn returns a value in [0, n) from the global source via rand.Int63n
// (which panics if n <= 0). The Intn counter is bumped before the draw.
// When RandCtlCapture is set, the value is also recorded with its call stack.
func (w *wraprand) Intn(n int64) int64 {
	w.intncalls++
	v := rand.Int63n(n)
	if w.ctl&RandCtlCapture == 0 {
		return v
	}
	w.captureCall("Intn", fmt.Sprintf("%d", v))
	return v
}
// Float32 returns a float32 in [0.0, 1.0) from the global source via
// rand.Float32 and counts the draw. When RandCtlCapture is set, the value
// (formatted with %f) is also recorded together with its call stack.
func (w *wraprand) Float32() float32 {
	w.f32calls++
	x := rand.Float32()
	if w.ctl&RandCtlCapture == 0 {
		return x
	}
	w.captureCall("Float32", fmt.Sprintf("%f", x))
	return x
}
// NormFloat64 returns a standard-normal float64 from the global source via
// rand.NormFloat64 and counts the draw in the float64 counter. When
// RandCtlCapture is set, the value (formatted with %f) is also recorded
// together with its call stack.
func (w *wraprand) NormFloat64() float64 {
	w.f64calls++
	g := rand.NormFloat64()
	if w.ctl&RandCtlCapture == 0 {
		return g
	}
	w.captureCall("NormFloat64", fmt.Sprintf("%f", g))
	return g
}
// emitCalls writes every record in w.calls, in order, to the file fn. The
// file is created with mode 0666 (before umask) or truncated if it already
// exists. emitCalls panics if the file cannot be opened, written, or closed.
func (w *wraprand) emitCalls(fn string) {
	outf, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		panic(err)
	}
	// Write the whole log at once and check the result, so that a failed
	// write cannot leave a partial dump that looks complete.
	if _, err := outf.WriteString(strings.Join(w.calls, "")); err != nil {
		outf.Close()
		panic(err)
	}
	// Close may report deferred write failures, so its error is checked too.
	if err := outf.Close(); err != nil {
		panic(err)
	}
}
// Equal reports whether w and w2 have made the same number of Float32,
// NormFloat64 and Intn draws. Seed, tag, ctl and captured calls are not
// compared.
func (w *wraprand) Equal(w2 *wraprand) bool {
	if w.f32calls != w2.f32calls || w.f64calls != w2.f64calls {
		return false
	}
	return w.intncalls == w2.intncalls
}
// Check compares the draw counts of w and w2 (see Equal) and reports a
// mismatch. It does nothing unless w.ctl is non-zero; w2.ctl is not consulted.
//
// On a mismatch, Check does the following in order:
//   - Prints both sets of counters to stderr. Each is labelled with the
//     wrapper's tag, or "w"/"w2" when the tag is empty.
//   - If RandCtlCapture is set, dumps each wrapper's captured calls to
//     <label>.txt in the OS temporary directory. If both labels are the same,
//     the second file gets a ".2.txt" suffix instead.
//   - If RandCtlPanic is set, panics.
func (w *wraprand) Check(w2 *wraprand) {
	if w.ctl == 0 || w.Equal(w2) {
		return
	}
	label := func(r *wraprand, fallback string) string {
		if r.tag != "" {
			return r.tag
		}
		return fallback
	}
	t := label(w, "w")
	t2 := label(w2, "w2")

	fmt.Fprintf(os.Stderr, "wraprand consistency check failed:\n")
	fmt.Fprintf(os.Stderr, " %s: {f32:%d f64:%d i:%d}\n", t,
		w.f32calls, w.f64calls, w.intncalls)
	fmt.Fprintf(os.Stderr, " %s: {f32:%d f64:%d i:%d}\n", t2,
		w2.f32calls, w2.f64calls, w2.intncalls)

	if w.ctl&RandCtlCapture != 0 {
		// Use the platform temp dir instead of a hard-coded "/tmp". This works
		// where /tmp does not exist (e.g. Windows) and honors $TMPDIR on Unix.
		dir := os.TempDir()
		f := filepath.Join(dir, t+".txt")
		f2 := filepath.Join(dir, t2+".txt")
		if f2 == f {
			// Identical tags would otherwise make the second dump overwrite the first.
			f2 = filepath.Join(dir, t2+".2.txt")
		}
		w.emitCalls(f)
		w2.emitCalls(f2)
		fmt.Fprintf(os.Stderr, "=-= emitted calls to %s, %s\n", f, f2)
	}
	if w.ctl&RandCtlPanic != 0 {
		panic("bad")
	}
}
// Checkpoint appends a marker record ("=-=", tag, "=-=" on separate lines)
// to the captured call log that emitCalls writes out. It does nothing unless
// RandCtlCapture is set in w.ctl.
func (w *wraprand) Checkpoint(tag string) {
	if w.ctl&RandCtlCapture == 0 {
		return
	}
	marker := "=-=\n" + tag + "\n=-=\n"
	w.calls = append(w.calls, marker)
}