blob: 4ab13f697401843154dfe11cfab0231a2c623798 [file] [log] [blame]
 // Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package stats import ( "fmt" "math" "sort" "strings" "testing" ) func testDiscreteCDF(t *testing.T, name string, dist DiscreteDist) { // Build the expected CDF out of the PMF. l, h := dist.Bounds() s := dist.Step() want := map[float64]float64{l - 0.1: 0, h: 1} sum := 0.0 for x := l; x < h; x += s { sum += dist.PMF(x) want[x] = sum want[x+s/2] = sum } testFunc(t, name, dist.CDF, want) } func testInvCDF(t *testing.T, dist Dist, bounded bool) { inv := InvCDF(dist) name := fmt.Sprintf("InvCDF(%+v)", dist) cdfName := fmt.Sprintf("CDF(%+v)", dist) // Test bounds. vals := map[float64]float64{-0.01: nan, 1.01: nan} if !bounded { vals[0] = -inf vals[1] = inf } testFunc(t, name, inv, vals) if bounded { lo, hi := inv(0), inv(1) vals := map[float64]float64{ lo - 0.01: 0, lo: 0, hi: 1, hi + 0.01: 1, } testFunc(t, cdfName, dist.CDF, vals) if got := dist.CDF(lo + 0.01); !(got > 0) { t.Errorf("%s(0)=%v, but %s(%v)=0", name, lo, cdfName, lo+0.01) } if got := dist.CDF(hi - 0.01); !(got < 1) { t.Errorf("%s(1)=%v, but %s(%v)=1", name, hi, cdfName, hi-0.01) } } // Test points between. vals = map[float64]float64{} for _, p := range vecLinspace(0, 1, 11) { if p == 0 || p == 1 { continue } x := inv(p) vals[x] = x } testFunc(t, fmt.Sprintf("InvCDF(CDF(%+v))", dist), func(x float64) float64 { return inv(dist.CDF(x)) }, vals) } // aeq returns true if expect and got are equal to 8 significant // figures (1 part in 100 million). 
func aeq(expect, got float64) bool {
	// When both values are negative, compare their magnitudes so the
	// scaled inequalities below keep their meaning.
	if expect < 0 && got < 0 {
		expect, got = -expect, -got
	}
	// Each value, shrunk by one part in 100 million, must not exceed
	// the other. Any comparison involving NaN is false, so NaN never
	// compares equal here.
	return expect*0.99999999 <= got && got*0.99999999 <= expect
}

// testFunc evaluates f at every key of vals, in ascending key order,
// and reports a test error for each key x where f(x) does not match
// vals[x]. Two results match when both are NaN or when aeq considers
// them equal.
//
// name labels failures: if it contains "%v" it is used as a format
// string applied to x, otherwise the label is "name(x)".
func testFunc(t *testing.T, name string, f func(float64) float64, vals map[float64]float64) {
	// Map iteration order is random; sort the inputs so failures are
	// reported deterministically.
	xs := make([]float64, 0, len(vals))
	for x := range vals {
		xs = append(xs, x)
	}
	sort.Float64s(xs)

	for _, x := range xs {
		want, got := vals[x], f(x)
		if math.IsNaN(want) && math.IsNaN(got) {
			continue
		}
		if aeq(want, got) {
			continue
		}

		var label string
		if strings.Contains(name, "%v") {
			label = fmt.Sprintf(name, x)
		} else {
			label = fmt.Sprintf("%s(%v)", name, x)
		}
		t.Errorf("want %s=%v, got %v", label, want, got)
	}
}

// vecLinspace returns num values spaced evenly between lo and hi,
// inclusive. If num is 1, this returns an array consisting of lo; if
// num is 0, it returns an empty slice.
//
// The first and last elements are exactly lo and hi. They are stored
// directly rather than computed, so they are unaffected by rounding
// or by overflow of hi-lo.
func vecLinspace(lo, hi float64, num int) []float64 {
	res := make([]float64, num)
	if num == 0 {
		return res
	}
	res[0] = lo
	if num == 1 {
		return res
	}

	last := num - 1
	span := hi - lo
	// Interior points only; the endpoints are pinned.
	for i := 1; i < last; i++ {
		res[i] = lo + float64(i)*span/float64(last)
	}
	res[last] = hi
	return res
}