vulncheck: include call graph edges that could be missed due to recursion

Consider the call graph G <--> F -> V where F is an entry point function
and V is a vulnerable function. The current algorithm for creation of
vulnerability graphs might start with F, visit V, figure out that V is
vulnerable and hence add F -> V to the vulnerability graph. Then, since
F is in this graph, G will be added too as G calls F and vice versa.

But if we start by analyzing first F -> G, then G won't be added to the
vulnerability graph as F is not yet added to that graph (since we have
not yet analyzed V). Hence, we might miss adding some edges to the
vulnerability graph.

Note that this bug does *not* miss vulnerabilities, just some paths from
entry points to vulnerabilities. In the above example, the missed paths
can be characterized with (F -> G)+ -> V.

The fix is to compute the vulnerability graph by 1) computing a
backwards slice from vulnerabilities, 2) computing a forward slice from
affected entry points, and 3) creating a vulnerability graph from the
intersection of the two call graph slices.

Note that imports and module vulnerability graph creation is not
affected as there can be no recursion in such graphs.

Change-Id: I4509a639ab60a5441b1998d56420f6cc3c38f960
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/411354
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/cmd/govulncheck/testdata/json.ct b/cmd/govulncheck/testdata/json.ct
index b19d49c..512a4ef 100644
--- a/cmd/govulncheck/testdata/json.ct
+++ b/cmd/govulncheck/testdata/json.ct
@@ -39,6 +39,19 @@
 		"Functions": {
 			"1": {
 				"ID": 1,
+				"Name": "main",
+				"RecvType": "",
+				"PkgPath": "vuln",
+				"Pos": {
+					"Filename": ".../vuln.go",
+					"Offset": 69,
+					"Line": 9,
+					"Column": 6
+				},
+				"CallSites": null
+			},
+			"2": {
+				"ID": 2,
 				"Name": "Parse",
 				"RecvType": "",
 				"PkgPath": "golang.org/x/text/language",
@@ -50,7 +63,7 @@
 				},
 				"CallSites": [
 					{
-						"Parent": 2,
+						"Parent": 1,
 						"Name": "Parse",
 						"RecvType": "",
 						"Pos": {
@@ -62,23 +75,10 @@
 						"Resolved": true
 					}
 				]
-			},
-			"2": {
-				"ID": 2,
-				"Name": "main",
-				"RecvType": "",
-				"PkgPath": "vuln",
-				"Pos": {
-					"Filename": ".../vuln.go",
-					"Offset": 69,
-					"Line": 9,
-					"Column": 6
-				},
-				"CallSites": null
 			}
 		},
 		"Entries": [
-			2
+			1
 		]
 	},
 	"Imports": {
@@ -183,7 +183,7 @@
 			"Symbol": "Parse",
 			"PkgPath": "golang.org/x/text/language",
 			"ModPath": "golang.org/x/text",
-			"CallSink": 1,
+			"CallSink": 2,
 			"ImportSink": 1,
 			"RequireSink": 1
 		}
diff --git a/vulncheck/source.go b/vulncheck/source.go
index 5a23e5c..79e718e 100644
--- a/vulncheck/source.go
+++ b/vulncheck/source.go
@@ -302,19 +302,79 @@
 	return id
 }
 
-func vulnCallGraphSlice(entries []*ssa.Function, modVulns moduleVulnerabilities, cg *callgraph.Graph, result *Result) {
-	// analyzedFuncs contains information on functions analyzed thus far.
-	// If a function is mapped to nil, this means it has been visited
-	// but it does not lead to a vulnerable call. Otherwise, a visited
-	// function is mapped to Calls function node.
-	analyzedFuncs := make(map[*ssa.Function]*FuncNode)
-	for _, entry := range entries {
-		// Top level entries that lead to vulnerable calls
-		// are stored as result.Calls graph entry points.
-		if e := vulnCallSlice(entry, modVulns, cg, result, analyzedFuncs); e != nil {
-			result.Calls.Entries = append(result.Calls.Entries, e.ID)
+// vulnCallGraphSlice checks if known vulnerabilities are transitively reachable from sources
+// via call graph cg. If so, populates result.Calls graph with this reachability information.
+func vulnCallGraphSlice(sources []*ssa.Function, modVulns moduleVulnerabilities, cg *callgraph.Graph, result *Result) {
+	sinksWithVulns := vulnFuncs(cg, modVulns)
+
+	// Compute call graph backwards reachable
+	// from vulnerable functions and methods.
+	var sinks []*callgraph.Node
+	for n := range sinksWithVulns {
+		sinks = append(sinks, n)
+	}
+	bcg := callGraphSlice(sinks, false)
+
+	// Interesect backwards call graph with forward
+	// reachable graph to remove redundant edges.
+	var filteredSources []*callgraph.Node
+	for _, e := range sources {
+		if n, ok := bcg.Nodes[e]; ok {
+			filteredSources = append(filteredSources, n)
 		}
 	}
+	fcg := callGraphSlice(filteredSources, true)
+
+	// Get the sinks that are in fact reachable from entry points.
+	filteredSinks := make(map[*callgraph.Node][]*osv.Entry)
+	for n, vs := range sinksWithVulns {
+		if fn, ok := fcg.Nodes[n.Func]; ok {
+			filteredSinks[fn] = vs
+		}
+	}
+
+	// Transform the resulting call graph slice into
+	// vulncheck representation and store it to result.
+	vulnCallGraph(filteredSources, filteredSinks, result)
+}
+
+// callGraphSlice computes a slice of callgraph beginning at starts
+// in the direction (forward/backward) controlled by forward flag.
+func callGraphSlice(starts []*callgraph.Node, forward bool) *callgraph.Graph {
+	g := &callgraph.Graph{Nodes: make(map[*ssa.Function]*callgraph.Node)}
+
+	visited := make(map[*callgraph.Node]bool)
+	var visit func(*callgraph.Node)
+	visit = func(n *callgraph.Node) {
+		if visited[n] {
+			return
+		}
+		visited[n] = true
+
+		var edges []*callgraph.Edge
+		if forward {
+			edges = n.Out
+		} else {
+			edges = n.In
+		}
+
+		for _, edge := range edges {
+			nCallee := g.CreateNode(edge.Callee.Func)
+			nCaller := g.CreateNode(edge.Caller.Func)
+			callgraph.AddEdge(nCaller, edge.Site, nCallee)
+
+			if forward {
+				visit(edge.Callee)
+			} else {
+				visit(edge.Caller)
+			}
+		}
+	}
+
+	for _, s := range starts {
+		visit(s)
+	}
+	return g
 }
 
 // funID is an id counter for nodes of Calls graph.
@@ -325,82 +385,93 @@
 	return funID
 }
 
-// vulnCallSlice checks if f has some vulnerabilities or transitively calls
-// a function with known vulnerabilities. If so, populates result.Calls
-// graph with this reachability information and returns the result.Call
-// function node. Otherwise, returns nil.
-func vulnCallSlice(f *ssa.Function, modVulns moduleVulnerabilities, cg *callgraph.Graph, result *Result, analyzed map[*ssa.Function]*FuncNode) *FuncNode {
-	if fn, ok := analyzed[f]; ok {
+// vulnCallGraph creates vulnerability call graph from sources -> sinks reachability info.
+func vulnCallGraph(sources []*callgraph.Node, sinks map[*callgraph.Node][]*osv.Entry, result *Result) {
+	nodes := make(map[*ssa.Function]*FuncNode)
+	createNode := func(f *ssa.Function) *FuncNode {
+		if fn, ok := nodes[f]; ok {
+			return fn
+		}
+		fn := funcNode(f)
+		nodes[f] = fn
+		result.Calls.Functions[fn.ID] = fn
 		return fn
 	}
 
-	fn := cg.Nodes[f]
-	if fn == nil {
-		return nil
+	// First create entries and sinks and store relevant information.
+	for _, s := range sources {
+		fn := createNode(s.Func)
+		result.Calls.Entries = append(result.Calls.Entries, fn.ID)
 	}
 
-	// Check if f has known vulnerabilities.
-	vulns := modVulns.vulnsForSymbol(pkgPath(f), dbFuncName(f))
+	for s, vulns := range sinks {
+		f := s.Func
+		funNode := createNode(s.Func)
 
-	var funNode *FuncNode
-	// If there are vulnerabilities for f, create node for f and
-	// save it immediately. This allows us to include F in the
-	// slice when analyzing chain V -> F -> V where V is vulnerable.
-	if len(vulns) > 0 {
-		funNode = funcNode(f)
-	}
-	analyzed[f] = funNode
-
-	// Recursively compute which callees lead to a call of a
-	// vulnerable function. Remember the nodes of such callees.
-	type siteNode struct {
-		call ssa.CallInstruction
-		fn   *FuncNode
-	}
-	var onSlice []siteNode
-	for _, edge := range fn.Out {
-		if calleeNode := vulnCallSlice(edge.Callee.Func, modVulns, cg, result, analyzed); calleeNode != nil {
-			onSlice = append(onSlice, siteNode{call: edge.Site, fn: calleeNode})
-		}
-	}
-
-	// If f is not vulnerable nor it transitively leads
-	// to vulnerable calls, jump out.
-	if len(onSlice) == 0 && len(vulns) == 0 {
-		return nil
-	}
-
-	// If f is not vulnerable, then at this point it has
-	// to be on the path leading to a vulnerable call.
-	if funNode == nil {
-		funNode = funcNode(f)
-		analyzed[f] = funNode
-	}
-	result.Calls.Functions[funNode.ID] = funNode
-
-	// Save node predecessor information.
-	for _, calleeSliceInfo := range onSlice {
-		call, node := calleeSliceInfo.call, calleeSliceInfo.fn
-		cs := &CallSite{
-			Parent:   funNode.ID,
-			Name:     call.Common().Value.Name(),
-			RecvType: callRecvType(call),
-			Resolved: resolved(call),
-			Pos:      instrPosition(call),
-		}
-		node.CallSites = append(node.CallSites, cs)
-	}
-
-	// Populate CallSink field for each detected vuln symbol.
-	for _, osv := range vulns {
-		for _, affected := range osv.Affected {
-			if affected.Package.Name != funNode.PkgPath {
-				continue
+		// Populate CallSink field for each detected vuln symbol.
+		for _, osv := range vulns {
+			for _, affected := range osv.Affected {
+				if affected.Package.Name != funNode.PkgPath {
+					continue
+				}
+				addCallSinkForVuln(funNode.ID, osv, dbFuncName(f), funNode.PkgPath, result)
 			}
-			addCallSinkForVuln(funNode.ID, osv, dbFuncName(f), funNode.PkgPath, result)
 		}
 	}
-	return funNode
+
+	visited := make(map[*callgraph.Node]bool)
+	var visit func(*callgraph.Node)
+	visit = func(n *callgraph.Node) {
+		if visited[n] {
+			return
+		}
+		visited[n] = true
+
+		// make the resulting graph deterministic
+		// in the ordering of call graph edges.
+		sortEdges(n.In)
+		for _, edge := range n.In {
+			nCallee := createNode(edge.Callee.Func)
+			nCaller := createNode(edge.Caller.Func)
+
+			call := edge.Site
+			cs := &CallSite{
+				Parent:   nCaller.ID,
+				Name:     call.Common().Value.Name(),
+				RecvType: callRecvType(call),
+				Resolved: resolved(call),
+				Pos:      instrPosition(call),
+			}
+			nCallee.CallSites = append(nCallee.CallSites, cs)
+
+			visit(edge.Caller)
+		}
+	}
+
+	for s := range sinks {
+		visit(s)
+	}
+}
+
+// sortEdges sorts edges by their string representation that takes
+// into account caller, callee, and the call site.
+func sortEdges(edges []*callgraph.Edge) {
+	str := func(e *callgraph.Edge) string {
+		return fmt.Sprintf("%v[%v]%v", e.Caller, e.Site, e.Callee)
+	}
+	sort.SliceStable(edges, func(i, j int) bool { return str(edges[i]) < str(edges[j]) })
+}
+
+// vulnFuncs returns vulnerability information for vulnerable functions in cg.
+func vulnFuncs(cg *callgraph.Graph, modVulns moduleVulnerabilities) map[*callgraph.Node][]*osv.Entry {
+	m := make(map[*callgraph.Node][]*osv.Entry)
+	for f, n := range cg.Nodes {
+		vulns := modVulns.vulnsForSymbol(pkgPath(f), dbFuncName(f))
+		if len(vulns) > 0 {
+			m[n] = vulns
+		}
+	}
+	return m
 }
 
 // pkgPath returns the path of the f's enclosing package, if any.
diff --git a/vulncheck/source_test.go b/vulncheck/source_test.go
index 4687ac9..60bd3e9 100644
--- a/vulncheck/source_test.go
+++ b/vulncheck/source_test.go
@@ -649,3 +649,61 @@
 		t.Errorf("want stack of length 2; got stack of length %v", len(stack))
 	}
 }
+
+func TestRecursion(t *testing.T) {
+	e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
+		{
+			Name: "golang.org/entry",
+			Files: map[string]interface{}{
+				"x/x.go": `
+			package x
+
+			import "golang.org/bmod/bvuln"
+
+
+			func X() {
+				y()
+				bvuln.Vuln()
+				z()
+			}
+
+			func y() {
+				X()
+			}
+
+			func z() {}
+			`,
+			},
+		},
+		{
+			Name: "golang.org/bmod@v0.5.0",
+			Files: map[string]interface{}{"bvuln/bvuln.go": `
+			package bvuln
+
+			func Vuln() {}
+			`},
+		},
+	})
+	defer e.Cleanup()
+
+	// Load x as entry package.
+	pkgs, err := loadPackages(e, path.Join(e.Temp(), "entry/x"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(pkgs) != 1 {
+		t.Fatal("failed to load x test package")
+	}
+
+	cfg := &Config{
+		Client: testClient,
+	}
+	result, err := Source(context.Background(), Convert(pkgs), cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if l := len(result.Calls.Functions); l != 3 {
+		t.Errorf("want 3 functions (X, y, Vuln) in vulnerability graph; got %v", l)
+	}
+}