vulncheck: include more references to functions

Includes references to functions in operators of CallInstructions
when forward slicing for VTA.

Additionally avoids allocating slices of callees when forward
slicing. On net/http benchmark overall makes callgraph construction
5% faster and consume 2% less memory.

Change-Id: I52c775c397fb8ae06d6129957fd27d2516b8e740
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/460422
Run-TryBot: Tim King <taking@google.com>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/vulncheck/slicing.go b/vulncheck/slicing.go
index 7f76787..71fffc4 100644
--- a/vulncheck/slicing.go
+++ b/vulncheck/slicing.go
@@ -9,39 +9,40 @@
 	"golang.org/x/tools/go/ssa"
 )
 
-// forwardReachableFrom computes the set of functions forward reachable from `sources`.
-// A function f is reachable from a function g if f is an anonymous function defined
-// in g or a function called in g as given by the callgraph `cg`.
-func forwardReachableFrom(sources map[*ssa.Function]bool, cg *callgraph.Graph) map[*ssa.Function]bool {
-	m := make(map[*ssa.Function]bool)
-	for s := range sources {
-		forward(s, cg, m)
-	}
-	return m
-}
+// forwardSlice computes the transitive closure of functions forward reachable
+// via calls in cg or referred to in an instruction starting from `sources`.
+func forwardSlice(sources map[*ssa.Function]bool, cg *callgraph.Graph) map[*ssa.Function]bool {
+	seen := make(map[*ssa.Function]bool)
+	var visit func(f *ssa.Function)
+	visit = func(f *ssa.Function) {
+		if seen[f] {
+			return
+		}
+		seen[f] = true
 
-func forward(f *ssa.Function, cg *callgraph.Graph, seen map[*ssa.Function]bool) {
-	if seen[f] {
-		return
-	}
-	seen[f] = true
-	var buf [10]*ssa.Value // avoid alloc in common case
-	for _, b := range f.Blocks {
-		for _, instr := range b.Instrs {
-			switch i := instr.(type) {
-			case ssa.CallInstruction:
-				for _, c := range siteCallees(i, cg) {
-					forward(c, cg, seen)
+		if n := cg.Nodes[f]; n != nil {
+			for _, e := range n.Out {
+				if e.Site != nil {
+					visit(e.Callee.Func)
 				}
-			default:
-				for _, op := range i.Operands(buf[:0]) {
+			}
+		}
+
+		var buf [10]*ssa.Value // avoid alloc in common case
+		for _, b := range f.Blocks {
+			for _, instr := range b.Instrs {
+				for _, op := range instr.Operands(buf[:0]) {
 					if fn, ok := (*op).(*ssa.Function); ok {
-						forward(fn, cg, seen)
+						visit(fn)
 					}
 				}
 			}
 		}
 	}
+	for source := range sources {
+		visit(source)
+	}
+	return seen
 }
 
 // pruneSet removes functions in `set` that are in `toPrune`.
diff --git a/vulncheck/slicing_test.go b/vulncheck/slicing_test.go
index a7e1b0b..dfaa241 100644
--- a/vulncheck/slicing_test.go
+++ b/vulncheck/slicing_test.go
@@ -100,7 +100,7 @@
 
 	pkg := ssaPkgs[0]
 	sources := map[*ssa.Function]bool{pkg.Func("Apply"): true, pkg.Func("Do"): true}
-	fs := funcNames(forwardReachableFrom(sources, cha.CallGraph(prog)))
+	fs := funcNames(forwardSlice(sources, cha.CallGraph(prog)))
 	want := map[string]bool{
 		"Apply":   true,
 		"Apply$1": true,
diff --git a/vulncheck/utils.go b/vulncheck/utils.go
index ee7ab9b..15dce72 100644
--- a/vulncheck/utils.go
+++ b/vulncheck/utils.go
@@ -70,7 +70,7 @@
 	initial := cha.CallGraph(prog)
 	allFuncs := ssautil.AllFunctions(prog)
 
-	fslice := forwardReachableFrom(entrySlice, initial)
+	fslice := forwardSlice(entrySlice, initial)
 	// Keep only actually linked functions.
 	pruneSet(fslice, allFuncs)
 
@@ -81,7 +81,7 @@
 
 	// Repeat the process once more, this time using
 	// the produced VTA call graph as the base graph.
-	fslice = forwardReachableFrom(entrySlice, vtaCg)
+	fslice = forwardSlice(entrySlice, vtaCg)
 	pruneSet(fslice, allFuncs)
 
 	if err := ctx.Err(); err != nil { // cancelled?
@@ -92,25 +92,6 @@
 	return cg, nil
 }
 
-// siteCallees computes a set of callees for call site `call` given program `callgraph`.
-func siteCallees(call ssa.CallInstruction, callgraph *callgraph.Graph) []*ssa.Function {
-	var matches []*ssa.Function
-
-	node := callgraph.Nodes[call.Parent()]
-	if node == nil {
-		return nil
-	}
-
-	for _, edge := range node.Out {
-		// Some callgraph analyses, such as CHA, might return synthetic (interface)
-		// methods as well as the concrete methods. Skip such synthetic functions.
-		if edge.Site == call {
-			matches = append(matches, edge.Callee.Func)
-		}
-	}
-	return matches
-}
-
 // dbTypeFormat formats the name of t according how types
 // are encoded in vulnerability database:
 //   - pointer designation * is skipped