internal/memoize: add a final argument to Bind for cleaning up

With its memoization and refcounting, the cache is well suited to the
sharing of other expensive resources, specifically those that interact
with the file system. However, it provides no means to clean up those
resources when they are no longer needed.

Add an additional argument to Bind to clean up any values produced by
the bound function when they are no longer referenced.

For golang/go#41836

Change-Id: Icb2b12949de06f2ec7daf868f12a9c699540fa5b
Reviewed-on: https://go-review.googlesource.com/c/tools/+/263937
Run-TryBot: Robert Findley <rfindley@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
Trust: Robert Findley <rfindley@google.com>
diff --git a/internal/lsp/cache/analysis.go b/internal/lsp/cache/analysis.go
index 3257324..872f1d2 100644
--- a/internal/lsp/cache/analysis.go
+++ b/internal/lsp/cache/analysis.go
@@ -143,7 +143,7 @@
 			}
 		}
 		return runAnalysis(ctx, snapshot, a, pkg, results)
-	})
+	}, nil)
 	act.handle = h
 
 	act = s.addActionHandle(act)
diff --git a/internal/lsp/cache/check.go b/internal/lsp/cache/check.go
index 74a3ffb..187a391 100644
--- a/internal/lsp/cache/check.go
+++ b/internal/lsp/cache/check.go
@@ -99,7 +99,7 @@
 		wg.Wait()
 
 		return data
-	})
+	}, nil)
 	ph.handle = h
 
 	// Cache the handle in the snapshot. If a package handle has already
diff --git a/internal/lsp/cache/mod.go b/internal/lsp/cache/mod.go
index 0e4ed4e..95a9132 100644
--- a/internal/lsp/cache/mod.go
+++ b/internal/lsp/cache/mod.go
@@ -88,7 +88,7 @@
 			}
 		}
 		return data
-	})
+	}, nil)
 
 	pmh := &parseModHandle{handle: h}
 	s.mu.Lock()
@@ -252,7 +252,7 @@
 			why[req.Mod.Path] = whyList[i]
 		}
 		return &modWhyData{why: why}
-	})
+	}, nil)
 
 	mwh := &modWhyHandle{handle: h}
 	s.mu.Lock()
@@ -360,7 +360,7 @@
 		return &modUpgradeData{
 			upgrades: upgrades,
 		}
-	})
+	}, nil)
 	muh := &modUpgradeHandle{handle: h}
 	s.mu.Lock()
 	s.modUpgradeHandles[fh.URI()] = muh
diff --git a/internal/lsp/cache/mod_tidy.go b/internal/lsp/cache/mod_tidy.go
index 3cb55ab..5f735b9 100644
--- a/internal/lsp/cache/mod_tidy.go
+++ b/internal/lsp/cache/mod_tidy.go
@@ -150,7 +150,7 @@
 				TidiedContent: tempContents,
 			},
 		}
-	})
+	}, nil)
 
 	mth := &modTidyHandle{handle: h}
 	s.mu.Lock()
diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go
index 95e9de5..0a2c371 100644
--- a/internal/lsp/cache/parse.go
+++ b/internal/lsp/cache/parse.go
@@ -61,12 +61,12 @@
 	parseHandle := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
 		snapshot := arg.(*snapshot)
 		return parseGo(ctx, snapshot.view.session.cache.fset, fh, mode)
-	})
+	}, nil)
 
 	astHandle := s.generation.Bind(astCacheKey(key), func(ctx context.Context, arg memoize.Arg) interface{} {
 		snapshot := arg.(*snapshot)
 		return buildASTCache(ctx, snapshot, parseHandle)
-	})
+	}, nil)
 
 	pgh := &parseGoHandle{
 		handle:         parseHandle,
diff --git a/internal/lsp/cache/snapshot.go b/internal/lsp/cache/snapshot.go
index ff06817..e452df1 100644
--- a/internal/lsp/cache/snapshot.go
+++ b/internal/lsp/cache/snapshot.go
@@ -1399,7 +1399,7 @@
 				Package:    pkg,
 			},
 		}
-	})
+	}, nil)
 	s.builtin = &builtinPackageHandle{handle: h}
 	return nil
 }
diff --git a/internal/memoize/memoize.go b/internal/memoize/memoize.go
index 3c7d3c1..d4b8773 100644
--- a/internal/memoize/memoize.go
+++ b/internal/memoize/memoize.go
@@ -82,6 +82,9 @@
 			if len(e.generations) == 0 {
 				delete(g.store.handles, k)
 				e.state = stateDestroyed
+				if e.cleanup != nil && e.value != nil {
+					e.cleanup(e.value)
+				}
 			}
 		}
 		e.mu.Unlock()
@@ -150,16 +153,22 @@
 	function Function
 	// value is set in completed state.
 	value interface{}
+	// cleanup, if non-nil, is used to perform any necessary clean-up on values
+	// produced by function.
+	cleanup func(interface{})
 }
 
 // Bind returns a handle for the given key and function.
 //
-// Each call to bind will return the same handle if it is already bound.
-// Bind will always return a valid handle, creating one if needed.
-// Each key can only have one handle at any given time.
-// The value will be held at least until the associated generation is destroyed.
-// Bind does not cause the value to be generated.
-func (g *Generation) Bind(key interface{}, function Function) *Handle {
+// Each call to bind will return the same handle if it is already bound. Bind
+// will always return a valid handle, creating one if needed. Each key can
+// only have one handle at any given time. The value will be held at least
+// until the associated generation is destroyed. Bind does not cause the value
+// to be generated.
+//
+// If cleanup is non-nil, it will be called on any non-nil values produced by
+// function when they are no longer referenced.
+func (g *Generation) Bind(key interface{}, function Function, cleanup func(interface{})) *Handle {
 	// panic early if the function is nil
 	// it would panic later anyway, but in a way that was much harder to debug
 	if function == nil {
@@ -176,6 +185,7 @@
 			key:         key,
 			function:    function,
 			generations: map[*Generation]struct{}{g: {}},
+			cleanup:     cleanup,
 		}
 		g.store.handles[key] = h
 		return h
@@ -220,17 +230,19 @@
 	}
 }
 
-func (g *Generation) Inherit(h *Handle) {
-	if atomic.LoadUint32(&g.destroyed) != 0 {
-		panic("inherit on destroyed generation " + g.name)
-	}
+func (g *Generation) Inherit(hs ...*Handle) {
+	for _, h := range hs {
+		if atomic.LoadUint32(&g.destroyed) != 0 {
+			panic("inherit on destroyed generation " + g.name)
+		}
 
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	if h.state == stateDestroyed {
-		panic(fmt.Sprintf("inheriting destroyed handle %#v (type %T) into generation %v", h.key, h.key, g.name))
+		h.mu.Lock()
+		defer h.mu.Unlock()
+		if h.state == stateDestroyed {
+			panic(fmt.Sprintf("inheriting destroyed handle %#v (type %T) into generation %v", h.key, h.key, g.name))
+		}
+		h.generations[g] = struct{}{}
 	}
-	h.generations[g] = struct{}{}
 }
 
 // Cached returns the value associated with a handle.
@@ -309,6 +321,11 @@
 		}
 		v := function(childCtx, arg)
 		if childCtx.Err() != nil {
+			// It's possible that v was computed despite the context cancellation. In
+			// this case we should ensure that it is cleaned up.
+			if h.cleanup != nil && v != nil {
+				h.cleanup(v)
+			}
 			return
 		}
 
@@ -319,8 +336,13 @@
 		// checked childCtx above. Even so, that should be harmless, since each
 		// run should produce the same results.
 		if h.state != stateRunning {
+			// v will never be used, so ensure that it is cleaned up.
+			if h.cleanup != nil && v != nil {
+				h.cleanup(v)
+			}
 			return
 		}
+		// At this point v will be cleaned up whenever h is destroyed.
 		h.value = v
 		h.function = nil
 		h.state = stateCompleted
diff --git a/internal/memoize/memoize_test.go b/internal/memoize/memoize_test.go
index e6e7b0b..41f20d0 100644
--- a/internal/memoize/memoize_test.go
+++ b/internal/memoize/memoize_test.go
@@ -21,7 +21,7 @@
 	h := g.Bind("key", func(context.Context, memoize.Arg) interface{} {
 		evaled++
 		return "res"
-	})
+	}, nil)
 	expectGet(t, h, g, "res")
 	expectGet(t, h, g, "res")
 	if evaled != 1 {
@@ -30,6 +30,7 @@
 }
 
 func expectGet(t *testing.T, h *memoize.Handle, g *memoize.Generation, wantV interface{}) {
+	t.Helper()
 	gotV, gotErr := h.Get(context.Background(), g, nil)
 	if gotV != wantV || gotErr != nil {
 		t.Fatalf("Get() = %v, %v, wanted %v, nil", gotV, gotErr, wantV)
@@ -42,11 +43,12 @@
 		t.Fatalf("Get() = %v, %v, wanted err %q", gotV, gotErr, substr)
 	}
 }
+
 func TestGenerations(t *testing.T) {
 	s := &memoize.Store{}
 	// Evaluate key in g1.
 	g1 := s.Generation("g1")
-	h1 := g1.Bind("key", func(context.Context, memoize.Arg) interface{} { return "res" })
+	h1 := g1.Bind("key", func(context.Context, memoize.Arg) interface{} { return "res" }, nil)
 	expectGet(t, h1, g1, "res")
 
 	// Get key in g2. It should inherit the value from g1.
@@ -54,7 +56,7 @@
 	h2 := g2.Bind("key", func(context.Context, memoize.Arg) interface{} {
 		t.Fatal("h2 should not need evaluation")
 		return "error"
-	})
+	}, nil)
 	expectGet(t, h2, g2, "res")
 
 	// With g1 destroyed, g2 should still work.
@@ -64,6 +66,42 @@
 	// With all generations destroyed, key should be re-evaluated.
 	g2.Destroy()
 	g3 := s.Generation("g3")
-	h3 := g3.Bind("key", func(context.Context, memoize.Arg) interface{} { return "new res" })
+	h3 := g3.Bind("key", func(context.Context, memoize.Arg) interface{} { return "new res" }, nil)
 	expectGet(t, h3, g3, "new res")
 }
+
+func TestCleanup(t *testing.T) {
+	s := &memoize.Store{}
+	g1 := s.Generation("g1")
+	v1 := false
+	v2 := false
+	cleanup := func(v interface{}) {
+		*(v.(*bool)) = true
+	}
+	h1 := g1.Bind("key1", func(context.Context, memoize.Arg) interface{} {
+		return &v1
+	}, nil)
+	h2 := g1.Bind("key2", func(context.Context, memoize.Arg) interface{} {
+		return &v2
+	}, cleanup)
+	expectGet(t, h1, g1, &v1)
+	expectGet(t, h2, g1, &v2)
+	g2 := s.Generation("g2")
+	g2.Inherit(h1, h2)
+
+	g1.Destroy()
+	expectGet(t, h1, g2, &v1)
+	expectGet(t, h2, g2, &v2)
+	for k, v := range map[string]*bool{"key1": &v1, "key2": &v2} {
+		if got, want := *v, false; got != want {
+			t.Errorf("after destroying g1, bound value %q is cleaned up", k)
+		}
+	}
+	g2.Destroy()
+	if got, want := v1, false; got != want {
+		t.Error("after destroying g2, v1 is cleaned up")
+	}
+	if got, want := v2, true; got != want {
+		t.Error("after destroying g2, v2 is not cleaned up")
+	}
+}