vulncheck: make call stack search faster while preserving determinism

We sort edges while doing call stack search instead of sorting edges in
the vulnerability call graph (which is not promised to clients anyhow).
We can use the structure of call stack search to do sorting in a smart
way.

This reduces k8s and vault times by 7 and 5 seconds, resp.

Change-Id: I46b6623fd6543fdef898d991b7f29f228ca59d91
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/412194
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/vulncheck/source.go b/vulncheck/source.go
index 79e718e..157d4dd 100644
--- a/vulncheck/source.go
+++ b/vulncheck/source.go
@@ -427,9 +427,6 @@
 		}
 		visited[n] = true
 
-		// make the resulting graph deterministic
-		// in the ordering of call graph edges.
-		sortEdges(n.In)
 		for _, edge := range n.In {
 			nCallee := createNode(edge.Callee.Func)
 			nCaller := createNode(edge.Caller.Func)
@@ -453,15 +450,6 @@
 	}
 }
 
-// sortEdges sorts edges by their string representation that takes
-// into account caller, callee, and the call site.
-func sortEdges(edges []*callgraph.Edge) {
-	str := func(e *callgraph.Edge) string {
-		return fmt.Sprintf("%v[%v]%v", e.Caller, e.Site, e.Callee)
-	}
-	sort.SliceStable(edges, func(i, j int) bool { return str(edges[i]) < str(edges[j]) })
-}
-
 // vulnFuncs returns vulnerability information for vulnerable functions in cg.
 func vulnFuncs(cg *callgraph.Graph, modVulns moduleVulnerabilities) map[*callgraph.Node][]*osv.Entry {
 	m := make(map[*callgraph.Node][]*osv.Entry)
diff --git a/vulncheck/witness.go b/vulncheck/witness.go
index c46a8fd..75f1722 100644
--- a/vulncheck/witness.go
+++ b/vulncheck/witness.go
@@ -6,6 +6,8 @@
 
 import (
 	"container/list"
+	"fmt"
+	"go/token"
 	"sort"
 	"strings"
 	"sync"
@@ -194,7 +196,9 @@
 		}
 		seen[f.ID] = true
 
-		for _, cs := range f.CallSites {
+		// Pick a single call site for each function in determinstic order.
+		// A single call site is sufficient as we visit a function only once.
+		for _, cs := range callsites(f.CallSites, res, seen) {
 			callee := res.Calls.Functions[cs.Parent]
 			nStack := &callChain{f: callee, call: cs, child: c}
 			if entries[callee.ID] {
@@ -206,6 +210,34 @@
 	return stacks
 }
 
+// callsites picks a call site from sites for each non-visited function.
+// For each such function, the smallest (posLess) call site is chosen. The
+// returned slice is sorted by caller functions (funcLess). Assumes callee
+// of each call site is the same.
+func callsites(sites []*CallSite, result *Result, visited map[int]bool) []*CallSite {
+	minCs := make(map[int]*CallSite)
+	for _, cs := range sites {
+		if visited[cs.Parent] {
+			continue
+		}
+		if csLess(cs, minCs[cs.Parent]) {
+			minCs[cs.Parent] = cs
+		}
+	}
+
+	var fs []*FuncNode
+	for id := range minCs {
+		fs = append(fs, result.Calls.Functions[id])
+	}
+	sort.SliceStable(fs, func(i, j int) bool { return funcLess(fs[i], fs[j]) })
+
+	var css []*CallSite
+	for _, f := range fs {
+		css = append(css, minCs[f.ID])
+	}
+	return css
+}
+
 // callChain models a chain of function calls.
 type callChain struct {
 	call  *CallSite // nil for entry points
@@ -286,3 +318,77 @@
 	// search algorithm.
 	return true
 }
+
+// csLess compares two call sites by their locations and, if needed,
+// their string representation.
+func csLess(cs1, cs2 *CallSite) bool {
+	if cs2 == nil {
+		return true
+	}
+
+	// fast code path
+	if p1, p2 := cs1.Pos, cs2.Pos; p1 != nil && p2 != nil {
+		if posLess(*p1, *p2) {
+			return true
+		}
+		if posLess(*p2, *p1) {
+			return false
+		}
+		// for sanity, should not occur in practice
+		return fmt.Sprintf("%v.%v", cs1.RecvType, cs2.Name) < fmt.Sprintf("%v.%v", cs2.RecvType, cs2.Name)
+	}
+
+	// code path rarely exercised
+	if cs2.Pos == nil {
+		return true
+	}
+	if cs1.Pos == nil {
+		return false
+	}
+	// should very rarely occur in practice
+	return fmt.Sprintf("%v.%v", cs1.RecvType, cs2.Name) < fmt.Sprintf("%v.%v", cs2.RecvType, cs2.Name)
+}
+
+// posLess compares two positions by their line and column number,
+// and filename if needed.
+func posLess(p1, p2 token.Position) bool {
+	if p1.Line < p2.Line {
+		return true
+	}
+	if p2.Line < p1.Line {
+		return false
+	}
+
+	if p1.Column < p2.Column {
+		return true
+	}
+	if p2.Column < p1.Column {
+		return false
+	}
+
+	return strings.Compare(p1.Filename, p2.Filename) == -1
+}
+
+// funcLess compares two function nodes by locations of
+// corresponding functions and, if needed, their string representation.
+func funcLess(f1, f2 *FuncNode) bool {
+	if p1, p2 := f1.Pos, f2.Pos; p1 != nil && p2 != nil {
+		if posLess(*p1, *p2) {
+			return true
+		}
+		if posLess(*p2, *p1) {
+			return false
+		}
+		// for sanity, should not occur in practice
+		return f1.String() < f2.String()
+	}
+
+	if f2.Pos == nil {
+		return true
+	}
+	if f1.Pos == nil {
+		return false
+	}
+	// should happen only for inits
+	return f1.String() < f2.String()
+}