internal/memoize: document the complicated parts of the memoize package

This change is more of an exercise for myself to better understand the
implementation of the memoize package. It adds detailed documentation
for the get function in particular.
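
For reference, here is a minimal usage sketch of the memoize API as it is
documented below. The key, result type, and values are made up for
illustration, and the import path only works from within x/tools since the
package is internal:

	package main

	import (
		"context"
		"fmt"
		"runtime"

		"golang.org/x/tools/internal/memoize"
	)

	type result struct{ answer int }

	func main() {
		ctx := context.Background()
		s := &memoize.Store{}

		// Bind returns a handle; the function is not invoked yet.
		h := s.Bind("answer", func(context.Context) interface{} {
			return &result{answer: 42} // expensive, idempotent work goes here
		})

		// The first Get invokes the function; later Gets reuse the cached value.
		v := h.Get(ctx).(*result)
		fmt.Println(v.answer, s.Cached("answer") != nil) // 42 true

		// The value stays cached only while a handle that accessed it is alive.
		runtime.KeepAlive(h)
	}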

I also modified the tests to use a table-driven format. I'm not certain
whether this was the right approach (we may want to add a different kind
of test case in the future), but for now it seems to work fine.

Change-Id: I191a3b65af230e0af54b221c9f671582adec6c79
Reviewed-on: https://go-review.googlesource.com/c/tools/+/181685
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/memoize/memoize.go b/internal/memoize/memoize.go
index 457f914..6f9c621 100644
--- a/internal/memoize/memoize.go
+++ b/internal/memoize/memoize.go
@@ -2,13 +2,16 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Package memoize supports functions with idempotent results that are expensive
-// to compute having their return value memorized and returned again the next
-// time they are invoked.
-// The return values are only remembered for as long as there is still a user
-// to prevent excessive memory use.
+// Package memoize supports memoizing the return values of functions with
+// idempotent results that are expensive to compute.
+//
+// The memoized result is returned again the next time the function is invoked.
+// To prevent excessive memory use, the return values are only remembered
+// for as long as they still have a user.
+//
 // To use this package, build a store and use it to aquire handles with the
 // Bind method.
+//
 package memoize
 
 import (
@@ -69,17 +72,17 @@
 }
 
 // Bind returns a handle for the given key and function.
-// Each call to bind will generate a new handle, but all the handles for a
-// single key will refer to the same value, and only the first handle to try to
-// get the value will cause the function to be invoked.
-// The results of the function are held for as long as there are handles through
-// which the result has been accessed.
+//
+// Each call to Bind will generate a new handle.
+// All of the handles for a single key will refer to the same value.
+// Only the first handle to get the value will cause the function to be invoked.
+// The value will be held for as long as there are handles through which it has been accessed.
 // Bind does not cause the value to be generated.
 func (s *Store) Bind(key interface{}, function Function) *Handle {
 	// panic early if the function is nil
 	// it would panic later anyway, but in a way that was much harder to debug
 	if function == nil {
-		panic("Function passed to bind must not be nil")
+		panic("the function passed to bind must not be nil")
 	}
 	// check if we already have the key
 	s.mu.Lock()
@@ -100,8 +103,9 @@
 }
 
 // Cached returns the value associated with a key.
-// It cannot cause the value to be generated, but will return the cached
-// value if present.
+//
+// It cannot cause the value to be generated.
+// It will return the cached value, if present.
 func (s *Store) Cached(key interface{}) interface{} {
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -115,8 +119,9 @@
 }
 
 // Cached returns the value associated with a handle.
-// It will never cause the value to be generated, it will return the cached
-// value if present.
+//
+// It will never cause the value to be generated.
+// It will return the cached value, if present.
 func (h *Handle) Cached() interface{} {
 	h.mu.Lock()
 	defer h.mu.Unlock()
@@ -129,10 +134,10 @@
 }
 
 // Get returns the value associated with a handle.
+//
 // If the value is not yet ready, the underlying function will be invoked.
-// This makes this handle active, it will remember the value for as long as
-// it exists, and cause any other handles for the same key to also return the
-// same value.
+// This activates the handle, which will then hold the value for as long as the handle exists.
+// It also causes any other handles for the same key to return the same value.
 func (h *Handle) Get(ctx context.Context) interface{} {
 	h.mu.Lock()
 	defer h.mu.Unlock()
@@ -149,51 +154,71 @@
 // get is the implementation of Get.
 func (e *entry) get(ctx context.Context, f Function) (interface{}, bool) {
 	e.mu.Lock()
+	// Note: The lock is released and re-acquired below, so this unlock does not pair with the Lock above.
 	defer e.mu.Unlock()
-	// fast path if we already have a value
+
+	// Fast path: If the entry is ready, it already has a value.
 	if e.ready {
 		return unref(e), true
 	}
-	// value is not ready, and we hold the lock
-	// see if the value is already being calculated
+	// Only begin evaluating the function if no other goroutine is doing so.
 	var value interface{}
 	if e.wait == nil {
 		e.wait = make(chan struct{})
 		go func() {
+			// Note: The entry's lock is not held when this goroutine starts.
+			//
+			// Immediately defer signaling on the wait channel that this attempt
+			// has finished, since we cannot guarantee that f will not panic.
 			defer func() {
+				// Note: Deferred calls run last-in, first-out, so the entry's lock is not held here.
 				close(e.wait)
 				e.wait = nil
 			}()
-			// e is not locked here
+
+			// Use the background context so that the function is not canceled
+			// when the caller's context is canceled: other goroutines may be
+			// waiting on the same result.
 			ctx := context.Background()
 			value = f(ctx)
-			// function is done, return to locked state so we can update the entry
+
+			// The function has completed. Update the value in the entry.
 			e.mu.Lock()
+
+			// Note: Deferred calls run last-in, first-out, so this Unlock
+			// executes before the deferred close of the wait channel above.
 			defer e.mu.Unlock()
 			setref(e, value)
 		}()
 	}
-	// get a local copy of wait while we still hold the lock
+
+	// Get a local copy of wait while we still hold the lock.
 	wait := e.wait
+
+	// Release the lock while we wait for the value.
 	e.mu.Unlock()
-	// release the lock while we wait
+
 	select {
 	case <-wait:
-		// we should now have a value
+		// We should now have a value. Re-lock the entry without deferring an
+		// unlock, since one was already deferred at the beginning of this function.
 		e.mu.Lock()
 		result := unref(e)
-		// the keep alive makes sure value is not garbage collected before unref
+
+		// This KeepAlive makes sure that value is not garbage collected before
+		// unref has returned a strong reference to it.
 		runtime.KeepAlive(value)
 		return result, true
 	case <-ctx.Done():
-		// our context was cancelled
+		// The context was canceled. Re-lock the entry, since the unlock
+		// deferred at the beginning of this function expects it to be held.
 		e.mu.Lock()
 		return nil, false
 	}
 }
 
-// setref is called to store a value into an entry
-// it must only be called when the lock is held
+// setref is called to store a weak reference to a value into an entry.
+// It assumes that the caller is holding the entry's lock.
 func setref(e *entry, value interface{}) interface{} {
 	// this is only called when the entry lock is already held
 	data := (*[2]uintptr)(unsafe.Pointer(&value))
@@ -201,22 +226,32 @@
 	e.typ, e.ptr = data[0], data[1]
 	e.ready = true
 	if e.ptr != 0 {
-		// and arrange to clear the weak reference if the object is collected
+		// Arrange to clear the weak reference when the object is garbage collected.
 		runtime.SetFinalizer(value, func(_ interface{}) {
-			// clear the now invalid non pointer
 			e.mu.Lock()
 			defer e.mu.Unlock()
+
+			// Clear the now-invalid non-pointer.
 			e.typ, e.ptr = 0, 0
+			// The value is no longer available.
 			e.ready = false
 		})
 	}
 	return value
 }
 
+// unref returns a strong reference to the value stored in the given entry.
+// It assumes that the caller is holding the entry's lock.
 func unref(e *entry) interface{} {
 	// this is only called when the entry lock is already held
 	var v interface{}
 	data := (*[2]uintptr)(unsafe.Pointer(&v))
+
+	// Note: This approach for computing weak references and converting between
+	// weak and strong references would be rendered invalid if Go's runtime
+	// changed to allow moving objects on the heap.
+	// If such a change were to occur, some modifications would need to be made
+	// to this library.
 	data[0], data[1] = e.typ, e.ptr
 	return v
 }
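
As an aside on the note above about weak and strong references: the trick
that setref and unref rely on can be seen in isolation in the standalone
sketch below (separate from this change). It assumes the current gc runtime,
where an interface value is two machine words and heap objects do not move:

	package main

	import (
		"fmt"
		"unsafe"
	)

	type payload struct{ s string }

	func main() {
		// Hold a value strongly in an interface, then copy out its two machine
		// words. Keeping only the uintptrs is a "weak" reference: the garbage
		// collector does not treat them as pointers.
		var strong interface{} = &payload{s: "hello"}
		words := *(*[2]uintptr)(unsafe.Pointer(&strong))

		// Rebuild a strong interface value from the saved words, as unref does.
		// This is only valid while objects do not move on the heap.
		var rebuilt interface{}
		*(*[2]uintptr)(unsafe.Pointer(&rebuilt)) = words
		fmt.Println(rebuilt.(*payload).s) // prints "hello"
	}
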
diff --git a/internal/memoize/memoize_test.go b/internal/memoize/memoize_test.go
index a1041c0..3591988 100644
--- a/internal/memoize/memoize_test.go
+++ b/internal/memoize/memoize_test.go
@@ -18,54 +18,89 @@
 )
 
 func TestStore(t *testing.T) {
-	pinned := []string{"b", "_1", "_3"}
-	unpinned := []string{"a", "c", "d", "_2", "_4"}
 	ctx := context.Background()
 	s := &memoize.Store{}
 	logBuffer := &bytes.Buffer{}
+
 	s.Bind("logger", func(context.Context) interface{} { return logBuffer }).Get(ctx)
-	verifyBuffer := func(name, expect string) {
-		got := logBuffer.String()
-		if got != expect {
-			t.Errorf("at %q expected:\n%v\ngot:\n%s", name, expect, got)
-		}
-		logBuffer.Reset()
-	}
-	verifyBuffer("nothing", ``)
-	s.Bind("_1", generate(s, "_1")).Get(ctx)
-	verifyBuffer("get 1", `
+
+	// These tests check the behavior of the Bind and Get functions.
+	// They confirm that a bound function is only invoked once while its value is still cached.
+	for _, test := range []struct {
+		name, key, want string
+	}{
+		{
+			name: "nothing",
+		},
+		{
+			name: "get 1",
+			key:  "_1",
+			want: `
 start @1
 simple a = A
 simple b = B
 simple c = C
 end @1 =  A B C
-`[1:])
-	s.Bind("_1", generate(s, "_1")).Get(ctx)
-	verifyBuffer("redo 1", ``)
-	s.Bind("_2", generate(s, "_2")).Get(ctx)
-	verifyBuffer("get 2", `
+`[1:],
+		},
+		{
+			name: "redo 1",
+			key:  "_1",
+			want: ``,
+		},
+		{
+			name: "get 2",
+			key:  "_2",
+			want: `
 start @2
 simple d = D
 simple e = E
 simple f = F
 end @2 =  D E F
-`[1:])
-	s.Bind("_2", generate(s, "_2")).Get(ctx)
-	verifyBuffer("redo 2", ``)
-	s.Bind("_3", generate(s, "_3")).Get(ctx)
-	verifyBuffer("get 3", `
+`[1:],
+		},
+		{
+			name: "redo 2",
+			key:  "_2",
+			want: ``,
+		},
+		{
+			name: "get 3",
+			key:  "_3",
+			want: `
 start @3
 end @3 =  @1[ A B C] @2[ D E F]
-`[1:])
-	s.Bind("_4", generate(s, "_4")).Get(ctx)
-	verifyBuffer("get 4", `
+`[1:],
+		},
+		{
+			name: "get 4",
+			key:  "_4",
+			want: `
 start @3
 simple g = G
 error ERR = fail
 simple h = H
 end @3 =  G !fail H
-`[1:])
+`[1:],
+		},
+	} {
+		s.Bind(test.key, generate(s, test.key)).Get(ctx)
+		got := logBuffer.String()
+		if got != test.want {
+			t.Errorf("at %q expected:\n%v\ngot:\n%s", test.name, test.want, got)
+		}
+		logBuffer.Reset()
+	}
 
+	// This test checks that values are garbage collected and removed from the
+	// store when they are no longer used.
+
+	pinned := []string{"b", "_1", "_3"}             // keys to pin in memory
+	unpinned := []string{"a", "c", "d", "_2", "_4"} // keys to garbage collect
+
+	// A handle that has accessed its value holds a strong reference to it.
+	// By getting handles for the pinned keys and keeping those handles alive,
+	// we ensure that the values for these keys stay cached.
 	var pins []*memoize.Handle
 	for _, key := range pinned {
 		h := s.Bind(key, generate(s, key))
@@ -73,42 +108,57 @@
 		pins = append(pins, h)
 	}
 
+	// Force the garbage collector to run.
 	runAllFinalizers(t)
 
+	// Confirm our expectation that pinned values should remain cached,
+	// and unpinned values should be garbage collected.
 	for _, k := range pinned {
 		if v := s.Cached(k); v == nil {
-			t.Errorf("Pinned value %q was nil", k)
+			t.Errorf("pinned value %q was nil", k)
 		}
 	}
 	for _, k := range unpinned {
 		if v := s.Cached(k); v != nil {
-			t.Errorf("Unpinned value %q was %q", k, v)
+			t.Errorf("unpinned value %q should be nil, was %q", k, v)
 		}
 	}
+
+	// This forces the pins to stay alive until this point in the function.
 	runtime.KeepAlive(pins)
 }
 
 func runAllFinalizers(t *testing.T) {
-	// the following is very tricky, be very careful changing it
-	// it relies on behavior of finalizers that is not guaranteed
-	// first run the GC to queue the finalizers
+	// The following is very tricky, so be very careful when changing it.
+	// It relies on behavior of finalizers that is not guaranteed.
+
+	// First, run the GC to queue the finalizers.
 	runtime.GC()
-	// wait is used to signal that the finalizers are all done
+
+	// wait is used to signal that the finalizers are all done.
 	wait := make(chan struct{})
-	// register a finalizer against an immediately collectible object
+
+	// Register a finalizer against an immediately collectible object.
+	//
+	// The finalizer signals on the wait channel when it executes.
+	// Because it is the most recently registered finalizer, the wait channel
+	// is only closed once all of the previously queued finalizers have run.
 	runtime.SetFinalizer(&struct{ s string }{"obj"}, func(_ interface{}) { close(wait) })
-	// now run the GC again to pick up the tracker
+
+	// Now, run the GC again to pick up the tracker object above.
 	runtime.GC()
-	// now wait for the finalizers to run
+
+	// Wait for the finalizers to run, or time out after one second.
 	select {
 	case <-wait:
 	case <-time.Tick(time.Second):
-		t.Fatalf("Finalizers had not run after a second")
+		t.Fatalf("finalizers had not run after 1 second")
 	}
 }
 
 type stringOrError struct {
 	memoize.NoCopy
+
 	value string
 	err   error
 }
@@ -131,6 +181,8 @@
 	return func(ctx context.Context) interface{} {
 		name := key.(string)
 		switch name {
+		case "":
+			return nil
 		case "err":
 			return logGenerator(ctx, s, "ERR", "", fmt.Errorf("fail"))
 		case "_1":
@@ -147,8 +199,11 @@
 	}
 }
 
+// logGenerator generates a *stringOrError value, while logging to the store's logger.
 func logGenerator(ctx context.Context, s *memoize.Store, name string, v string, err error) *stringOrError {
+	// Get the logger from the store.
 	w := s.Cached("logger").(io.Writer)
+
 	if err != nil {
 		fmt.Fprintf(w, "error %v = %v\n", name, err)
 	} else {
@@ -157,8 +212,11 @@
 	return &stringOrError{value: v, err: err}
 }
 
+// joinValues gets the values for a list of keys and joins them, while logging to the store's logger.
 func joinValues(ctx context.Context, s *memoize.Store, name string, keys ...string) *stringOrError {
+	// Get the logger from the store.
 	w := s.Cached("logger").(io.Writer)
+
 	fmt.Fprintf(w, "start %v\n", name)
 	value := ""
 	for _, key := range keys {