blob: a1041c03fb365ec01b310483e948b9ffa0a77e91 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package memoize_test
import (
"bytes"
"context"
"fmt"
"io"
"runtime"
"strings"
"testing"
"time"
"golang.org/x/tools/internal/memoize"
)
// TestStore checks that memoize.Store generates each value once and then
// serves it from the cache, that composite values fetch their inputs through
// the same store, and that after the finalizers run only values whose handles
// are still reachable remain cached.
func TestStore(t *testing.T) {
	// Handles for pinned keys are held until the end of the test; handles for
	// unpinned keys are dropped, so those values should be evicted once the
	// finalizers have run.
	pinned := []string{"b", "_1", "_3"}
	unpinned := []string{"a", "c", "d", "_2", "_4"}
	ctx := context.Background()
	s := &memoize.Store{}

	// Every generator fetches this buffer with s.Cached("logger") and
	// type-asserts it to io.Writer, which panics if the entry has been
	// evicted. Keep the handle reachable for the whole test so that a GC
	// running mid-test cannot evict it the way it evicts unpinned values.
	logBuffer := &bytes.Buffer{}
	logger := s.Bind("logger", func(context.Context) interface{} { return logBuffer })
	logger.Get(ctx)

	// verifyBuffer reports an error unless exactly expect has been logged
	// since the previous call, then clears the log.
	verifyBuffer := func(name, expect string) {
		t.Helper()
		if got := logBuffer.String(); got != expect {
			t.Errorf("at %q expected:\n%v\ngot:\n%s", name, expect, got)
		}
		logBuffer.Reset()
	}
	verifyBuffer("nothing", ``)

	// joinValues logs "end <label> = <value>" where value begins with a
	// space, hence the double space in the "end" lines below.
	s.Bind("_1", generate(s, "_1")).Get(ctx)
	verifyBuffer("get 1", `
start @1
simple a = A
simple b = B
simple c = C
end @1 =  A B C
`[1:])
	// A second request for a cached key logs nothing.
	s.Bind("_1", generate(s, "_1")).Get(ctx)
	verifyBuffer("redo 1", ``)
	s.Bind("_2", generate(s, "_2")).Get(ctx)
	verifyBuffer("get 2", `
start @2
simple d = D
simple e = E
simple f = F
end @2 =  D E F
`[1:])
	s.Bind("_2", generate(s, "_2")).Get(ctx)
	verifyBuffer("redo 2", ``)
	// "_3" joins "_1" and "_2", which are already cached, so only "_3"
	// itself is logged.
	s.Bind("_3", generate(s, "_3")).Get(ctx)
	verifyBuffer("get 3", `
start @3
end @3 =  @1[ A B C] @2[ D E F]
`[1:])
	// generate labels "_4" as "@3", so its log lines say "@3" as well.
	s.Bind("_4", generate(s, "_4")).Get(ctx)
	verifyBuffer("get 4", `
start @3
simple g = G
error ERR = fail
simple h = H
end @3 =  G !fail H
`[1:])

	var pins []*memoize.Handle
	for _, key := range pinned {
		h := s.Bind(key, generate(s, key))
		h.Get(ctx)
		pins = append(pins, h)
	}
	runAllFinalizers(t)
	for _, k := range pinned {
		if v := s.Cached(k); v == nil {
			t.Errorf("Pinned value %q was nil", k)
		}
	}
	for _, k := range unpinned {
		if v := s.Cached(k); v != nil {
			t.Errorf("Unpinned value %q was %q", k, v)
		}
	}
	runtime.KeepAlive(pins)
	runtime.KeepAlive(logger)
}
// runAllFinalizers runs the garbage collector and waits for the finalizers it
// queued to finish, failing the test if that has not happened within a second.
//
// This is very tricky; be very careful changing it. It relies on finalizer
// behavior that the runtime does not guarantee, namely that a sentinel
// finalizer queued by a second GC does not run before the finalizers queued
// by the first one.
func runAllFinalizers(t *testing.T) {
	t.Helper()
	// First run the GC to queue the finalizers.
	runtime.GC()
	// wait is closed by the sentinel finalizer to signal that the finalizers
	// are all done.
	wait := make(chan struct{})
	// Register a finalizer against an immediately collectible object.
	runtime.SetFinalizer(&struct{ s string }{"obj"}, func(_ interface{}) { close(wait) })
	// Now run the GC again to pick up the sentinel.
	runtime.GC()
	// Use a one-shot timer that is stopped on return. A time.Tick ticker can
	// never be stopped.
	timeout := time.NewTimer(time.Second)
	defer timeout.Stop()
	// Now wait for the finalizers to run.
	select {
	case <-wait:
	case <-timeout.C:
		t.Fatalf("Finalizers had not run after a second")
	}
}
// stringOrError is the result type produced by the test generators
// logGenerator and joinValues. It carries a generated string and, when
// generation failed, the error.
type stringOrError struct {
	// NOTE(review): presumably embedded so that vet's copylocks check flags
	// copies of cached values; confirm against the memoize package.
	memoize.NoCopy
	value string // the generated string
	err   error  // the generation error, or nil on success
}
// String implements fmt.Stringer. It returns the stored value on success and
// the error's text when generation failed.
func (v *stringOrError) String() string {
	if v.err == nil {
		return v.value
	}
	return v.err.Error()
}
// asValue converts a value fetched from the store back to *stringOrError.
// A nil interface maps to a nil pointer; any other dynamic type panics.
func asValue(v interface{}) *stringOrError {
	if v != nil {
		return v.(*stringOrError)
	}
	return nil
}
// generate returns the memoize.Function that computes the test value for key,
// which must be a string.
//
// The keys "_1" to "_4" are composites built by joinValues from other keys.
// Note that "_4" reuses the "@3" label. The key "err" yields an error value,
// and every other key yields its own name upper-cased.
func generate(s *memoize.Store, key interface{}) memoize.Function {
	return func(ctx context.Context) interface{} {
		id := key.(string)
		// For composite keys: the label to log under, then the keys to join.
		var parts []string
		switch id {
		case "err":
			return logGenerator(ctx, s, "ERR", "", fmt.Errorf("fail"))
		case "_1":
			parts = []string{"@1", "a", "b", "c"}
		case "_2":
			parts = []string{"@2", "d", "e", "f"}
		case "_3":
			parts = []string{"@3", "_1", "_2"}
		case "_4":
			parts = []string{"@3", "g", "err", "h"}
		default:
			return logGenerator(ctx, s, id, strings.ToUpper(id), nil)
		}
		return joinValues(ctx, s, parts[0], parts[1:]...)
	}
}
// logGenerator writes a "simple" line (on success) or an "error" line (on
// failure) for name to the writer cached under "logger" in s. It then wraps
// v and err in a *stringOrError. ctx is not used.
func logGenerator(ctx context.Context, s *memoize.Store, name string, v string, err error) *stringOrError {
	out := s.Cached("logger").(io.Writer)
	if err == nil {
		fmt.Fprintf(out, "simple %v = %v\n", name, v)
	} else {
		fmt.Fprintf(out, "error %v = %v\n", name, err)
	}
	return &stringOrError{err: err, value: v}
}
// joinValues builds a composite value from keys, fetching each through s. It
// logs "start <name>" before the fetches and "end <name> = <joined>" after
// them, using the writer cached under "logger".
//
// Each fetched value contributes one of three parts:
//   - " " plus its value on success,
//   - " !" plus its error text on failure,
//   - " <nil>" when nothing was produced.
//
// The result's value is name followed by the joined parts in brackets,
// e.g. "@1[ A B C]".
func joinValues(ctx context.Context, s *memoize.Store, name string, keys ...string) *stringOrError {
	w := s.Cached("logger").(io.Writer)
	fmt.Fprintf(w, "start %v\n", name)
	// Accumulate in a Builder rather than by repeated string concatenation.
	var b strings.Builder
	for _, key := range keys {
		// Fetching through the store reuses inputs that are already cached
		// instead of regenerating them.
		v := asValue(s.Bind(key, generate(s, key)).Get(ctx))
		switch {
		case v == nil:
			b.WriteString(" <nil>")
		case v.err != nil:
			b.WriteString(" !")
			b.WriteString(v.err.Error())
		default:
			b.WriteString(" ")
			b.WriteString(v.value)
		}
	}
	value := b.String()
	fmt.Fprintf(w, "end %v = %v\n", name, value)
	return &stringOrError{value: fmt.Sprintf("%s[%v]", name, value)}
}