internal/memoize: a new library to memoize functions

This library holds onto results with a weak reference, and guarantees that for
as long as
a result has not been garbage collected it will return the same result for the
same key.

Change-Id: I4a4528f31bf8bbf18809cbffe95dc93e05d769fe
Reviewed-on: https://go-review.googlesource.com/c/tools/+/180845
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/memoize/memoize.go b/internal/memoize/memoize.go
new file mode 100644
index 0000000..457f914
--- /dev/null
+++ b/internal/memoize/memoize.go
@@ -0,0 +1,222 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package memoize supports functions with idempotent results that are expensive
+// to compute having their return value memorized and returned again the next
+// time they are invoked.
+// The return values are only remembered for as long as there is still a user
+// to prevent excessive memory use.
+// To use this package, build a store and use it to aquire handles with the
+// Bind method.
+package memoize
+
+import (
+	"context"
+	"runtime"
+	"sync"
+	"unsafe"
+)
+
+// Store binds keys to functions, returning handles that can be used to access
+// the functions results.
+type Store struct {
+	mu sync.Mutex
+	// entries is the set of values stored.
+	entries map[interface{}]*entry
+}
+
+// Function is the type for functions that can be memoized.
+// The result must be a pointer.
+type Function func(ctx context.Context) interface{}
+
+// Handle is returned from a store when a key is bound to a function.
+// It is then used to access the results of that function.
+type Handle struct {
+	mu       sync.Mutex
+	function Function
+	entry    *entry
+	value    interface{}
+}
+
+// entry holds the machinery to manage a function and its result such that
+// there is only one instance of the result live at any given time.
+type entry struct {
+	noCopy
+	// mu contols access to the typ and ptr fields
+	mu sync.Mutex
+	// the calculated value, as stored in an interface{}
+	typ, ptr uintptr
+	ready    bool
+	// wait is used to block until the value is ready
+	// will only be non nil if the generator is already running
+	wait chan struct{}
+}
+
+// Has returns true if they key is currently valid for this store.
+func (s *Store) Has(key interface{}) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	_, found := s.entries[key]
+	return found
+}
+
+// Delete removes a key from the store, if present.
+func (s *Store) Delete(key interface{}) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	delete(s.entries, key)
+}
+
+// Bind returns a handle for the given key and function.
+// Each call to bind will generate a new handle, but all the handles for a
+// single key will refer to the same value, and only the first handle to try to
+// get the value will cause the function to be invoked.
+// The results of the function are held for as long as there are handles through
+// which the result has been accessed.
+// Bind does not cause the value to be generated.
+func (s *Store) Bind(key interface{}, function Function) *Handle {
+	// panic early if the function is nil
+	// it would panic later anyway, but in a way that was much harder to debug
+	if function == nil {
+		panic("Function passed to bind must not be nil")
+	}
+	// check if we already have the key
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	e, found := s.entries[key]
+	if !found {
+		// we have not seen this key before, add a new entry
+		if s.entries == nil {
+			s.entries = make(map[interface{}]*entry)
+		}
+		e = &entry{}
+		s.entries[key] = e
+	}
+	return &Handle{
+		entry:    e,
+		function: function,
+	}
+}
+
+// Cached returns the value associated with a key.
+// It cannot cause the value to be generated, but will return the cached
+// value if present.
+func (s *Store) Cached(key interface{}) interface{} {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	e, found := s.entries[key]
+	if !found {
+		return nil
+	}
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	return unref(e)
+}
+
+// Cached returns the value associated with a handle.
+// It will never cause the value to be generated, it will return the cached
+// value if present.
+func (h *Handle) Cached() interface{} {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	if h.value == nil {
+		h.entry.mu.Lock()
+		defer h.entry.mu.Unlock()
+		h.value = unref(h.entry)
+	}
+	return h.value
+}
+
+// Get returns the value associated with a handle.
+// If the value is not yet ready, the underlying function will be invoked.
+// This makes this handle active, it will remember the value for as long as
+// it exists, and cause any other handles for the same key to also return the
+// same value.
+func (h *Handle) Get(ctx context.Context) interface{} {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	if h.function != nil {
+		if v, ok := h.entry.get(ctx, h.function); ok {
+			h.value = v
+			h.function = nil
+			h.entry = nil
+		}
+	}
+	return h.value
+}
+
+// get is the implementation of Get.
+func (e *entry) get(ctx context.Context, f Function) (interface{}, bool) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	// fast path if we already have a value
+	if e.ready {
+		return unref(e), true
+	}
+	// value is not ready, and we hold the lock
+	// see if the value is already being calculated
+	var value interface{}
+	if e.wait == nil {
+		e.wait = make(chan struct{})
+		go func() {
+			defer func() {
+				close(e.wait)
+				e.wait = nil
+			}()
+			// e is not locked here
+			ctx := context.Background()
+			value = f(ctx)
+			// function is done, return to locked state so we can update the entry
+			e.mu.Lock()
+			defer e.mu.Unlock()
+			setref(e, value)
+		}()
+	}
+	// get a local copy of wait while we still hold the lock
+	wait := e.wait
+	e.mu.Unlock()
+	// release the lock while we wait
+	select {
+	case <-wait:
+		// we should now have a value
+		e.mu.Lock()
+		result := unref(e)
+		// the keep alive makes sure value is not garbage collected before unref
+		runtime.KeepAlive(value)
+		return result, true
+	case <-ctx.Done():
+		// our context was cancelled
+		e.mu.Lock()
+		return nil, false
+	}
+}
+
+// setref is called to store a value into an entry
+// it must only be called when the lock is held
+func setref(e *entry, value interface{}) interface{} {
+	// this is only called when the entry lock is already held
+	data := (*[2]uintptr)(unsafe.Pointer(&value))
+	// store the value back to the entry as a weak reference
+	e.typ, e.ptr = data[0], data[1]
+	e.ready = true
+	if e.ptr != 0 {
+		// and arrange to clear the weak reference if the object is collected
+		runtime.SetFinalizer(value, func(_ interface{}) {
+			// clear the now invalid non pointer
+			e.mu.Lock()
+			defer e.mu.Unlock()
+			e.typ, e.ptr = 0, 0
+			e.ready = false
+		})
+	}
+	return value
+}
+
+func unref(e *entry) interface{} {
+	// this is only called when the entry lock is already held
+	var v interface{}
+	data := (*[2]uintptr)(unsafe.Pointer(&v))
+	data[0], data[1] = e.typ, e.ptr
+	return v
+}
diff --git a/internal/memoize/memoize_test.go b/internal/memoize/memoize_test.go
new file mode 100644
index 0000000..a1041c0
--- /dev/null
+++ b/internal/memoize/memoize_test.go
@@ -0,0 +1,176 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package memoize_test
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"runtime"
+	"strings"
+	"testing"
+	"time"
+
+	"golang.org/x/tools/internal/memoize"
+)
+
+func TestStore(t *testing.T) {
+	pinned := []string{"b", "_1", "_3"}
+	unpinned := []string{"a", "c", "d", "_2", "_4"}
+	ctx := context.Background()
+	s := &memoize.Store{}
+	logBuffer := &bytes.Buffer{}
+	s.Bind("logger", func(context.Context) interface{} { return logBuffer }).Get(ctx)
+	verifyBuffer := func(name, expect string) {
+		got := logBuffer.String()
+		if got != expect {
+			t.Errorf("at %q expected:\n%v\ngot:\n%s", name, expect, got)
+		}
+		logBuffer.Reset()
+	}
+	verifyBuffer("nothing", ``)
+	s.Bind("_1", generate(s, "_1")).Get(ctx)
+	verifyBuffer("get 1", `
+start @1
+simple a = A
+simple b = B
+simple c = C
+end @1 =  A B C
+`[1:])
+	s.Bind("_1", generate(s, "_1")).Get(ctx)
+	verifyBuffer("redo 1", ``)
+	s.Bind("_2", generate(s, "_2")).Get(ctx)
+	verifyBuffer("get 2", `
+start @2
+simple d = D
+simple e = E
+simple f = F
+end @2 =  D E F
+`[1:])
+	s.Bind("_2", generate(s, "_2")).Get(ctx)
+	verifyBuffer("redo 2", ``)
+	s.Bind("_3", generate(s, "_3")).Get(ctx)
+	verifyBuffer("get 3", `
+start @3
+end @3 =  @1[ A B C] @2[ D E F]
+`[1:])
+	s.Bind("_4", generate(s, "_4")).Get(ctx)
+	verifyBuffer("get 4", `
+start @3
+simple g = G
+error ERR = fail
+simple h = H
+end @3 =  G !fail H
+`[1:])
+
+	var pins []*memoize.Handle
+	for _, key := range pinned {
+		h := s.Bind(key, generate(s, key))
+		h.Get(ctx)
+		pins = append(pins, h)
+	}
+
+	runAllFinalizers(t)
+
+	for _, k := range pinned {
+		if v := s.Cached(k); v == nil {
+			t.Errorf("Pinned value %q was nil", k)
+		}
+	}
+	for _, k := range unpinned {
+		if v := s.Cached(k); v != nil {
+			t.Errorf("Unpinned value %q was %q", k, v)
+		}
+	}
+	runtime.KeepAlive(pins)
+}
+
+func runAllFinalizers(t *testing.T) {
+	// the following is very tricky, be very careful changing it
+	// it relies on behavior of finalizers that is not guaranteed
+	// first run the GC to queue the finalizers
+	runtime.GC()
+	// wait is used to signal that the finalizers are all done
+	wait := make(chan struct{})
+	// register a finalizer against an immediately collectible object
+	runtime.SetFinalizer(&struct{ s string }{"obj"}, func(_ interface{}) { close(wait) })
+	// now run the GC again to pick up the tracker
+	runtime.GC()
+	// now wait for the finalizers to run
+	select {
+	case <-wait:
+	case <-time.Tick(time.Second):
+		t.Fatalf("Finalizers had not run after a second")
+	}
+}
+
+type stringOrError struct {
+	memoize.NoCopy
+	value string
+	err   error
+}
+
+func (v *stringOrError) String() string {
+	if v.err != nil {
+		return v.err.Error()
+	}
+	return v.value
+}
+
+func asValue(v interface{}) *stringOrError {
+	if v == nil {
+		return nil
+	}
+	return v.(*stringOrError)
+}
+
+func generate(s *memoize.Store, key interface{}) memoize.Function {
+	return func(ctx context.Context) interface{} {
+		name := key.(string)
+		switch name {
+		case "err":
+			return logGenerator(ctx, s, "ERR", "", fmt.Errorf("fail"))
+		case "_1":
+			return joinValues(ctx, s, "@1", "a", "b", "c")
+		case "_2":
+			return joinValues(ctx, s, "@2", "d", "e", "f")
+		case "_3":
+			return joinValues(ctx, s, "@3", "_1", "_2")
+		case "_4":
+			return joinValues(ctx, s, "@3", "g", "err", "h")
+		default:
+			return logGenerator(ctx, s, name, strings.ToUpper(name), nil)
+		}
+	}
+}
+
+func logGenerator(ctx context.Context, s *memoize.Store, name string, v string, err error) *stringOrError {
+	w := s.Cached("logger").(io.Writer)
+	if err != nil {
+		fmt.Fprintf(w, "error %v = %v\n", name, err)
+	} else {
+		fmt.Fprintf(w, "simple %v = %v\n", name, v)
+	}
+	return &stringOrError{value: v, err: err}
+}
+
+func joinValues(ctx context.Context, s *memoize.Store, name string, keys ...string) *stringOrError {
+	w := s.Cached("logger").(io.Writer)
+	fmt.Fprintf(w, "start %v\n", name)
+	value := ""
+	for _, key := range keys {
+		v := asValue(s.Bind(key, generate(s, key)).Get(ctx))
+		if v == nil {
+			value = value + " <nil>"
+		} else if v.err != nil {
+			value = value + " !" + v.err.Error()
+		} else {
+			value = value + " " + v.value
+		}
+	}
+	fmt.Fprintf(w, "end %v = %v\n", name, value)
+	return &stringOrError{value: fmt.Sprintf("%s[%v]", name, value)}
+}
diff --git a/internal/memoize/nocopy.go b/internal/memoize/nocopy.go
new file mode 100644
index 0000000..8913225
--- /dev/null
+++ b/internal/memoize/nocopy.go
@@ -0,0 +1,24 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package memoize
+
+// NoCopy is a type with no public methods that will trigger a vet check if it
+// is ever copied.
+// You can embed this in any type intended to be used as a value. This helps
+// avoid accidentally holding a copy of a value instead of the value itself.
+type NoCopy struct {
+	noCopy noCopy
+}
+
+// noCopy may be embedded into structs which must not be copied
+// after the first use.
+//
+// See https://golang.org/issues/8005#issuecomment-190753527
+// for details.
+type noCopy struct{}
+
+// Lock is a no-op used by -copylocks checker from `go vet`.
+func (*noCopy) Lock()   {}
+func (*noCopy) Unlock() {}