go/callgraph/vta: allow nil initial call graph

When nil is passed as the initial call graph, vta will use a more
performant version of CHA. For this purpose, the lazyCallees function of
CHA is moved to an internal package (chautil) and exposed to VTA.

This change reduces the time and memory footprint by ~10%, measured on
several large real-world Go projects.

Updates golang/go#57357

Change-Id: Ib5c5edca0026e6902e453fa10fc14f2b763849db
Reviewed-on: https://go-review.googlesource.com/c/tools/+/609978
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/go/callgraph/cha/cha.go b/go/callgraph/cha/cha.go
index 3040f3d..67a0356 100644
--- a/go/callgraph/cha/cha.go
+++ b/go/callgraph/cha/cha.go
@@ -25,12 +25,10 @@
 // TODO(zpavlinovic): update CHA for how it handles generic function bodies.
 
 import (
-	"go/types"
-
 	"golang.org/x/tools/go/callgraph"
+	"golang.org/x/tools/go/callgraph/internal/chautil"
 	"golang.org/x/tools/go/ssa"
 	"golang.org/x/tools/go/ssa/ssautil"
-	"golang.org/x/tools/go/types/typeutil"
 )
 
 // CallGraph computes the call graph of the specified program using the
@@ -53,13 +51,6 @@
 		// (io.Writer).Write is assumed to call every concrete
 		// Write method in the program, the call graph can
 		// contain a lot of duplication.
-		//
-		// TODO(taking): opt: consider making lazyCallees public.
-		// Using the same benchmarks as callgraph_test.go, removing just
-		// the explicit callgraph.Graph construction is 4x less memory
-		// and is 37% faster.
-		// CHA			86 ms/op	16 MB/op
-		// lazyCallees	63 ms/op	 4 MB/op
 		for _, g := range callees {
 			addEdge(fnode, site, g)
 		}
@@ -83,82 +74,4 @@
 	return cg
 }
 
-// lazyCallees returns a function that maps a call site (in a function in fns)
-// to its callees within fns.
-//
-// The resulting function is not concurrency safe.
-func lazyCallees(fns map[*ssa.Function]bool) func(site ssa.CallInstruction) []*ssa.Function {
-	// funcsBySig contains all functions, keyed by signature.  It is
-	// the effective set of address-taken functions used to resolve
-	// a dynamic call of a particular signature.
-	var funcsBySig typeutil.Map // value is []*ssa.Function
-
-	// methodsByID contains all methods, grouped by ID for efficient
-	// lookup.
-	//
-	// We must key by ID, not name, for correct resolution of interface
-	// calls to a type with two (unexported) methods spelled the same but
-	// from different packages. The fact that the concrete type implements
-	// the interface does not mean the call dispatches to both methods.
-	methodsByID := make(map[string][]*ssa.Function)
-
-	// An imethod represents an interface method I.m.
-	// (There's no go/types object for it;
-	// a *types.Func may be shared by many interfaces due to interface embedding.)
-	type imethod struct {
-		I  *types.Interface
-		id string
-	}
-	// methodsMemo records, for every abstract method call I.m on
-	// interface type I, the set of concrete methods C.m of all
-	// types C that satisfy interface I.
-	//
-	// Abstract methods may be shared by several interfaces,
-	// hence we must pass I explicitly, not guess from m.
-	//
-	// methodsMemo is just a cache, so it needn't be a typeutil.Map.
-	methodsMemo := make(map[imethod][]*ssa.Function)
-	lookupMethods := func(I *types.Interface, m *types.Func) []*ssa.Function {
-		id := m.Id()
-		methods, ok := methodsMemo[imethod{I, id}]
-		if !ok {
-			for _, f := range methodsByID[id] {
-				C := f.Signature.Recv().Type() // named or *named
-				if types.Implements(C, I) {
-					methods = append(methods, f)
-				}
-			}
-			methodsMemo[imethod{I, id}] = methods
-		}
-		return methods
-	}
-
-	for f := range fns {
-		if f.Signature.Recv() == nil {
-			// Package initializers can never be address-taken.
-			if f.Name() == "init" && f.Synthetic == "package initializer" {
-				continue
-			}
-			funcs, _ := funcsBySig.At(f.Signature).([]*ssa.Function)
-			funcs = append(funcs, f)
-			funcsBySig.Set(f.Signature, funcs)
-		} else if obj := f.Object(); obj != nil {
-			id := obj.(*types.Func).Id()
-			methodsByID[id] = append(methodsByID[id], f)
-		}
-	}
-
-	return func(site ssa.CallInstruction) []*ssa.Function {
-		call := site.Common()
-		if call.IsInvoke() {
-			tiface := call.Value.Type().Underlying().(*types.Interface)
-			return lookupMethods(tiface, call.Method)
-		} else if g := call.StaticCallee(); g != nil {
-			return []*ssa.Function{g}
-		} else if _, ok := call.Value.(*ssa.Builtin); !ok {
-			fns, _ := funcsBySig.At(call.Signature()).([]*ssa.Function)
-			return fns
-		}
-		return nil
-	}
-}
+var lazyCallees = chautil.LazyCallees
diff --git a/go/callgraph/internal/chautil/lazy.go b/go/callgraph/internal/chautil/lazy.go
new file mode 100644
index 0000000..430bfea
--- /dev/null
+++ b/go/callgraph/internal/chautil/lazy.go
@@ -0,0 +1,96 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package chautil provides helper functions related to
+// class hierarchy analysis (CHA) for use in x/tools.
+package chautil
+
+import (
+	"go/types"
+
+	"golang.org/x/tools/go/ssa"
+	"golang.org/x/tools/go/types/typeutil"
+)
+
+// LazyCallees returns a function that maps a call site (in a function in fns)
+// to its callees within fns. The set of callees is computed using the CHA algorithm,
+// i.e., on the entire implements relation between interfaces and concrete types
+// in fns. Please see golang.org/x/tools/go/callgraph/cha for more information.
+//
+// The resulting function is not concurrency safe.
+func LazyCallees(fns map[*ssa.Function]bool) func(site ssa.CallInstruction) []*ssa.Function {
+	// funcsBySig contains all functions, keyed by signature.  It is
+	// the effective set of address-taken functions used to resolve
+	// a dynamic call of a particular signature.
+	var funcsBySig typeutil.Map // value is []*ssa.Function
+
+	// methodsByID contains all methods, grouped by ID for efficient
+	// lookup.
+	//
+	// We must key by ID, not name, for correct resolution of interface
+	// calls to a type with two (unexported) methods spelled the same but
+	// from different packages. The fact that the concrete type implements
+	// the interface does not mean the call dispatches to both methods.
+	methodsByID := make(map[string][]*ssa.Function)
+
+	// An imethod represents an interface method I.m.
+	// (There's no go/types object for it;
+	// a *types.Func may be shared by many interfaces due to interface embedding.)
+	type imethod struct {
+		I  *types.Interface
+		id string
+	}
+	// methodsMemo records, for every abstract method call I.m on
+	// interface type I, the set of concrete methods C.m of all
+	// types C that satisfy interface I.
+	//
+	// Abstract methods may be shared by several interfaces,
+	// hence we must pass I explicitly, not guess from m.
+	//
+	// methodsMemo is just a cache, so it needn't be a typeutil.Map.
+	methodsMemo := make(map[imethod][]*ssa.Function)
+	lookupMethods := func(I *types.Interface, m *types.Func) []*ssa.Function {
+		id := m.Id()
+		methods, ok := methodsMemo[imethod{I, id}]
+		if !ok {
+			for _, f := range methodsByID[id] {
+				C := f.Signature.Recv().Type() // named or *named
+				if types.Implements(C, I) {
+					methods = append(methods, f)
+				}
+			}
+			methodsMemo[imethod{I, id}] = methods
+		}
+		return methods
+	}
+
+	for f := range fns {
+		if f.Signature.Recv() == nil {
+			// Package initializers can never be address-taken.
+			if f.Name() == "init" && f.Synthetic == "package initializer" {
+				continue
+			}
+			funcs, _ := funcsBySig.At(f.Signature).([]*ssa.Function)
+			funcs = append(funcs, f)
+			funcsBySig.Set(f.Signature, funcs)
+		} else if obj := f.Object(); obj != nil {
+			id := obj.(*types.Func).Id()
+			methodsByID[id] = append(methodsByID[id], f)
+		}
+	}
+
+	return func(site ssa.CallInstruction) []*ssa.Function {
+		call := site.Common()
+		if call.IsInvoke() {
+			tiface := call.Value.Type().Underlying().(*types.Interface)
+			return lookupMethods(tiface, call.Method)
+		} else if g := call.StaticCallee(); g != nil {
+			return []*ssa.Function{g}
+		} else if _, ok := call.Value.(*ssa.Builtin); !ok {
+			fns, _ := funcsBySig.At(call.Signature()).([]*ssa.Function)
+			return fns
+		}
+		return nil
+	}
+}
diff --git a/go/callgraph/vta/graph.go b/go/callgraph/vta/graph.go
index 1eea423..1a9ed7c 100644
--- a/go/callgraph/vta/graph.go
+++ b/go/callgraph/vta/graph.go
@@ -9,7 +9,6 @@
 	"go/token"
 	"go/types"
 
-	"golang.org/x/tools/go/callgraph"
 	"golang.org/x/tools/go/ssa"
 	"golang.org/x/tools/go/types/typeutil"
 	"golang.org/x/tools/internal/aliases"
@@ -274,8 +273,8 @@
 // typePropGraph builds a VTA graph for a set of `funcs` and initial
 // `callgraph` needed to establish interprocedural edges. Returns the
 // graph and a map for unique type representatives.
-func typePropGraph(funcs map[*ssa.Function]bool, callgraph *callgraph.Graph) (vtaGraph, *typeutil.Map) {
-	b := builder{graph: make(vtaGraph), callGraph: callgraph}
+func typePropGraph(funcs map[*ssa.Function]bool, callees calleesFunc) (vtaGraph, *typeutil.Map) {
+	b := builder{graph: make(vtaGraph), callees: callees}
 	b.visit(funcs)
 	return b.graph, &b.canon
 }
@@ -283,8 +282,8 @@
 // Data structure responsible for linearly traversing the
 // code and building a VTA graph.
 type builder struct {
-	graph     vtaGraph
-	callGraph *callgraph.Graph // initial call graph for creating flows at unresolved call sites.
+	graph   vtaGraph
+	callees calleesFunc // initial call graph for creating flows at unresolved call sites.
 
 	// Specialized type map for canonicalization of types.Type.
 	// Semantically equivalent types can have different implementations,
@@ -598,7 +597,7 @@
 		return
 	}
 
-	siteCallees(c, b.callGraph)(func(f *ssa.Function) bool {
+	siteCallees(c, b.callees)(func(f *ssa.Function) bool {
 		addArgumentFlows(b, c, f)
 
 		site, ok := c.(ssa.Value)
diff --git a/go/callgraph/vta/graph_test.go b/go/callgraph/vta/graph_test.go
index 8ce4079..b32da4f 100644
--- a/go/callgraph/vta/graph_test.go
+++ b/go/callgraph/vta/graph_test.go
@@ -205,11 +205,21 @@
 				t.Fatalf("couldn't find want in `%s`", file)
 			}
 
-			g, _ := typePropGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog))
+			fs := ssautil.AllFunctions(prog)
+
+			// First test propagation with lazy-CHA initial call graph.
+			g, _ := typePropGraph(fs, makeCalleesFunc(fs, nil))
 			got := vtaGraphStr(g)
 			if diff := setdiff(want, got); len(diff) > 0 {
 				t.Errorf("`%s`: want superset of %v;\n got %v\ndiff: %v", file, want, got, diff)
 			}
+
+			// Repeat the test with explicit CHA initial call graph.
+			g, _ = typePropGraph(fs, makeCalleesFunc(fs, cha.CallGraph(prog)))
+			got = vtaGraphStr(g)
+			if diff := setdiff(want, got); len(diff) > 0 {
+				t.Errorf("`%s`: want superset of %v;\n got %v\ndiff: %v", file, want, got, diff)
+			}
 		})
 	}
 }
diff --git a/go/callgraph/vta/initial.go b/go/callgraph/vta/initial.go
new file mode 100644
index 0000000..4dddc4e
--- /dev/null
+++ b/go/callgraph/vta/initial.go
@@ -0,0 +1,37 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vta
+
+import (
+	"golang.org/x/tools/go/callgraph"
+	"golang.org/x/tools/go/callgraph/internal/chautil"
+	"golang.org/x/tools/go/ssa"
+)
+
+// calleesFunc abstracts a call graph in one direction:
+// from call sites to callees.
+type calleesFunc func(ssa.CallInstruction) []*ssa.Function
+
+// makeCalleesFunc returns an initial call graph for vta as a
+// calleesFunc. If c is not nil, it returns callees as given by c.
+// Otherwise, it returns chautil.LazyCallees over fs.
+func makeCalleesFunc(fs map[*ssa.Function]bool, c *callgraph.Graph) calleesFunc {
+	if c == nil {
+		return chautil.LazyCallees(fs)
+	}
+	return func(call ssa.CallInstruction) []*ssa.Function {
+		node := c.Nodes[call.Parent()]
+		if node == nil {
+			return nil
+		}
+		var cs []*ssa.Function
+		for _, edge := range node.Out {
+			if edge.Site == call {
+				cs = append(cs, edge.Callee.Func)
+			}
+		}
+		return cs
+	}
+}
diff --git a/go/callgraph/vta/utils.go b/go/callgraph/vta/utils.go
index 2792336..141eb07 100644
--- a/go/callgraph/vta/utils.go
+++ b/go/callgraph/vta/utils.go
@@ -7,7 +7,6 @@
 import (
 	"go/types"
 
-	"golang.org/x/tools/go/callgraph"
 	"golang.org/x/tools/go/ssa"
 	"golang.org/x/tools/internal/aliases"
 	"golang.org/x/tools/internal/typeparams"
@@ -149,22 +148,14 @@
 	}
 }
 
-// siteCallees returns a go1.23 iterator for the callees for call site `c`
-// given program `callgraph`.
-func siteCallees(c ssa.CallInstruction, callgraph *callgraph.Graph) func(yield func(*ssa.Function) bool) {
+// siteCallees returns a go1.23 iterator for the callees for call site `c`.
+func siteCallees(c ssa.CallInstruction, callees calleesFunc) func(yield func(*ssa.Function) bool) {
 	// TODO: when x/tools uses go1.23, change callers to use range-over-func
 	// (https://go.dev/issue/65237).
-	node := callgraph.Nodes[c.Parent()]
 	return func(yield func(*ssa.Function) bool) {
-		if node == nil {
-			return
-		}
-
-		for _, edge := range node.Out {
-			if edge.Site == c {
-				if !yield(edge.Callee.Func) {
-					return
-				}
+		for _, callee := range callees(c) {
+			if !yield(callee) {
+				return
 			}
 		}
 	}
diff --git a/go/callgraph/vta/vta.go b/go/callgraph/vta/vta.go
index 226f261..72bd4a4 100644
--- a/go/callgraph/vta/vta.go
+++ b/go/callgraph/vta/vta.go
@@ -65,17 +65,20 @@
 
 // CallGraph uses the VTA algorithm to compute call graph for all functions
 // f:true in funcs. VTA refines the results of initial call graph and uses it
-// to establish interprocedural type flow. The resulting graph does not have
-// a root node.
+// to establish interprocedural type flow. If initial is nil, VTA uses a more
+// efficient approach to construct a CHA call graph.
+//
+// The resulting graph does not have a root node.
 //
 // CallGraph does not make any assumptions on initial types global variables
 // and function/method inputs can have. CallGraph is then sound, modulo use of
 // reflection and unsafe, if the initial call graph is sound.
 func CallGraph(funcs map[*ssa.Function]bool, initial *callgraph.Graph) *callgraph.Graph {
-	vtaG, canon := typePropGraph(funcs, initial)
+	callees := makeCalleesFunc(funcs, initial)
+	vtaG, canon := typePropGraph(funcs, callees)
 	types := propagate(vtaG, canon)
 
-	c := &constructor{types: types, initial: initial, cache: make(methodCache)}
+	c := &constructor{types: types, callees: callees, cache: make(methodCache)}
 	return c.construct(funcs)
 }
 
@@ -85,7 +88,7 @@
 type constructor struct {
 	types   propTypeMap
 	cache   methodCache
-	initial *callgraph.Graph
+	callees calleesFunc
 }
 
 func (c *constructor) construct(funcs map[*ssa.Function]bool) *callgraph.Graph {
@@ -101,15 +104,15 @@
 func (c *constructor) constrct(g *callgraph.Graph, f *ssa.Function) {
 	caller := g.CreateNode(f)
 	for _, call := range calls(f) {
-		for _, c := range c.callees(call) {
+		for _, c := range c.resolves(call) {
 			callgraph.AddEdge(caller, call, g.CreateNode(c))
 		}
 	}
 }
 
-// callees computes the set of functions to which VTA resolves `c`. The resolved
-// functions are intersected with functions to which `initial` resolves `c`.
-func (c *constructor) callees(call ssa.CallInstruction) []*ssa.Function {
+// resolves computes the set of functions to which VTA resolves `call`. The resolved
+// functions are intersected with functions to which `c.callees` resolves `call`.
+func (c *constructor) resolves(call ssa.CallInstruction) []*ssa.Function {
 	cc := call.Common()
 	if cc.StaticCallee() != nil {
 		return []*ssa.Function{cc.StaticCallee()}
@@ -123,7 +126,7 @@
 	// Cover the case of dynamic higher-order and interface calls.
 	var res []*ssa.Function
 	resolved := resolve(call, c.types, c.cache)
-	siteCallees(call, c.initial)(func(f *ssa.Function) bool {
+	siteCallees(call, c.callees)(func(f *ssa.Function) bool {
 		if _, ok := resolved[f]; ok {
 			res = append(res, f)
 		}
diff --git a/go/callgraph/vta/vta_test.go b/go/callgraph/vta/vta_test.go
index 67db130..a6f2dcd 100644
--- a/go/callgraph/vta/vta_test.go
+++ b/go/callgraph/vta/vta_test.go
@@ -19,6 +19,14 @@
 )
 
 func TestVTACallGraph(t *testing.T) {
+	errDiff := func(want, got, missing []string) {
+		t.Errorf("got:\n%s\n\nwant:\n%s\n\nmissing:\n%s\n\ndiff:\n%s",
+			strings.Join(got, "\n"),
+			strings.Join(want, "\n"),
+			strings.Join(missing, "\n"),
+			cmp.Diff(got, want)) // to aid debugging
+	}
+
 	for _, file := range []string{
 		"testdata/src/callgraph_static.go",
 		"testdata/src/callgraph_ho.go",
@@ -46,14 +54,18 @@
 				t.Fatalf("couldn't find want in `%s`", file)
 			}
 
-			g := CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog))
+			// First test VTA with lazy-CHA initial call graph.
+			g := CallGraph(ssautil.AllFunctions(prog), nil)
 			got := callGraphStr(g)
 			if missing := setdiff(want, got); len(missing) > 0 {
-				t.Errorf("got:\n%s\n\nwant:\n%s\n\nmissing:\n%s\n\ndiff:\n%s",
-					strings.Join(got, "\n"),
-					strings.Join(want, "\n"),
-					strings.Join(missing, "\n"),
-					cmp.Diff(got, want)) // to aid debugging
+				errDiff(want, got, missing)
+			}
+
+			// Repeat the test with explicit CHA initial call graph.
+			g = CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog))
+			got = callGraphStr(g)
+			if missing := setdiff(want, got); len(missing) > 0 {
+				errDiff(want, got, missing)
 			}
 		})
 	}
@@ -168,7 +180,7 @@
 		t.Fatalf("couldn't find want in `%s`", file)
 	}
 
-	g, _ := typePropGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog))
+	g, _ := typePropGraph(ssautil.AllFunctions(prog), makeCalleesFunc(nil, cha.CallGraph(prog)))
 	got := vtaGraphStr(g)
 	if diff := setdiff(want, got); len(diff) != 0 {
 		t.Errorf("`%s`: want superset of %v;\n got %v", file, want, got)