vulncheck: add witness search logic for call stacks

The CL adds API for traversal of call graph in search of call stacks
that can serve as witnesses for uses of vulnerable symbols.

Cherry-picked: https://go-review.googlesource.com/c/exp/+/381777

Change-Id: Ib290716122b96f25e0f74a85f942dbde242ab04d
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395058
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/vulncheck/witness.go b/vulncheck/witness.go
index 0ef7003..57c74ca 100644
--- a/vulncheck/witness.go
+++ b/vulncheck/witness.go
@@ -6,6 +6,9 @@
 
 import (
 	"container/list"
+	"fmt"
+	"sort"
+	"strings"
 	"sync"
 )
 
@@ -14,18 +17,18 @@
 // known vulnerabilities.
 type ImportChain []*PkgNode
 
-// ImportChains performs a BFS search of res.RequireGraph for imports of vulnerable
-// packages. Search is performed for each vulnerable package in res.Vulns. The search
-// starts at a vulnerable package and goes up until reaching an entry package in
-// res.ImportGraph.Entries, hence producing an import chain. During the search, a
-// package is visited only once to avoid analyzing every possible import chain.
-// Hence, not all possible vulnerable import chains are reported.
-//
-// Note that the resulting map produces an import chain for each Vuln. Thus, a Vuln
-// with the same PkgPath will have the same list of identified import chains.
-//
-// The reported import chains are ordered by how seemingly easy is to understand
+// ImportChains lists import chains for each vulnerability in res. The
+// reported chains are ordered by how seemingly easy is to understand
 // them. Shorter import chains appear earlier in the returned slices.
+//
+// ImportChains does not list all import chains for a vulnerability.
+// It performs a BFS search of res.RequireGraph starting at a vulnerable
+// package and going up until reaching an entry package in res.ImportGraph.Entries.
+// During this search, a package is visited only once to avoid analyzing
+// every possible import chain.
+//
+// Note that the resulting map produces an import chain for each Vuln. Vulns
+// with the same PkgPath will have the same list of identified import chains.
 func ImportChains(res *Result) map[*Vuln][]ImportChain {
 	// Group vulns per package.
 	vPerPkg := make(map[int][]*Vuln)
@@ -126,3 +129,175 @@
 	// nil when the frame represents an entry point of the stack.
 	Call *CallSite
 }
+
+// CallStacks lists call stacks for each vulnerability in res. The listed call
+// stacks are ordered by how seemingly easy is to understand them. In general,
+// shorter call stacks with less dynamic call sites appear earlier in the returned
+// call stack slices.
+//
+// CallStacks does not report every possible call stack for a vulnerable symbol.
+// It performs a BFS search of res.CallGraph starting at the symbol and going up
+// until reaching an entry function or method in res.CallGraph.Entries. During
+// this search, each function is visited at most once to avoid potential
+// exponential explosion, thus skipping some call stacks.
+func CallStacks(res *Result) map[*Vuln][]CallStack {
+	var (
+		wg sync.WaitGroup
+		mu sync.Mutex
+	)
+	stacksPerVuln := make(map[*Vuln][]CallStack)
+	for _, vuln := range res.Vulns {
+		vuln := vuln
+		wg.Add(1)
+		go func() {
+			cs := callStacks(vuln.CallSink, res)
+			// sort call stacks by the estimated value to the user
+			sort.SliceStable(cs, func(i int, j int) bool { return stackLess(cs[i], cs[j]) })
+			mu.Lock()
+			stacksPerVuln[vuln] = cs
+			mu.Unlock()
+			wg.Done()
+		}()
+	}
+
+	wg.Wait()
+	return stacksPerVuln
+}
+
+// callStacks finds representative call stacks
+// for vulnerable symbol identified with vulnSinkID.
+func callStacks(vulnSinkID int, res *Result) []CallStack {
+	if vulnSinkID == 0 {
+		return nil
+	}
+
+	entries := make(map[int]bool)
+	for _, e := range res.Calls.Entries {
+		entries[e] = true
+	}
+
+	var stacks []CallStack
+	seen := make(map[int]bool)
+
+	queue := list.New()
+	queue.PushBack(&callChain{f: res.Calls.Functions[vulnSinkID]})
+
+	for queue.Len() > 0 {
+		front := queue.Front()
+		c := front.Value.(*callChain)
+		queue.Remove(front)
+
+		f := c.f
+		if seen[f.ID] {
+			continue
+		}
+		seen[f.ID] = true
+
+		for _, cs := range f.CallSites {
+			callee := res.Calls.Functions[cs.Parent]
+			nStack := &callChain{f: callee, call: cs, child: c}
+			if entries[callee.ID] {
+				stacks = append(stacks, nStack.CallStack())
+			}
+			queue.PushBack(nStack)
+		}
+	}
+	return stacks
+}
+
+// callChain models a chain of function calls.
+type callChain struct {
+	call  *CallSite // nil for entry points
+	f     *FuncNode
+	child *callChain
+}
+
+// CallStack converts callChain to CallStack type.
+func (c *callChain) CallStack() CallStack {
+	if c == nil {
+		return nil
+	}
+	return append(CallStack{StackEntry{Function: c.f, Call: c.call}}, c.child.CallStack()...)
+}
+
+// weight computes an approximate measure of how easy is to understand the call
+// stack when presented to the client as a witness. The smaller the value, the more
+// understandable the stack is. Currently defined as the number of unresolved
+// call sites in the stack.
+func weight(stack CallStack) int {
+	w := 0
+	for _, e := range stack {
+		if e.Call != nil && !e.Call.Resolved {
+			w += 1
+		}
+	}
+	return w
+}
+
+func isStdPackage(pkg string) bool {
+	if pkg == "" {
+		return false
+	}
+	// std packages do not have a "." in their path. For instance, see
+	// Contains in pkgsite/+/refs/heads/master/internal/stdlbib/stdlib.go.
+	if i := strings.IndexByte(pkg, '/'); i != -1 {
+		pkg = pkg[:i]
+	}
+	return !strings.Contains(pkg, ".")
+}
+
+// confidence computes an approximate measure of whether the stack
+// is realizeable in practice. Currently, it equals the number of call
+// sites in stack that go through standard libraries. Such call stacks
+// have been experimentally shown to often result in false positives.
+func confidence(stack CallStack) int {
+	c := 0
+	for _, e := range stack {
+		if isStdPackage(e.Function.PkgPath) {
+			c += 1
+		}
+	}
+	return c
+}
+
+// stackLess compares two call stacks in terms of their estimated
+// value to the user. Shorter stacks generally come earlier in the ordering.
+//
+// Two stacks are lexicographically ordered by:
+// 1) their estimated level of confidence in being a real call stack,
+// 2) their length, and 3) the number of dynamic call sites in the stack.
+func stackLess(s1, s2 CallStack) bool {
+	if c1, c2 := confidence(s1), confidence(s2); c1 != c2 {
+		return c1 < c2
+	}
+
+	if len(s1) != len(s2) {
+		return len(s1) < len(s2)
+	}
+
+	if w1, w2 := weight(s1), weight(s2); w1 != w2 {
+		return w1 < w2
+	}
+	// At this point we just need to make sure the ordering is deterministic.
+	// TODO(zpavlinovic): is there a more meaningful additional ordering?
+	return stackStrLess(s1, s2)
+}
+
+// stackStrLess compares string representation of stacks.
+func stackStrLess(s1, s2 CallStack) bool {
+	// Creates a unique string representation of a call stack
+	// for comparison purposes only.
+	stackStr := func(stack CallStack) string {
+		var stackStr []string
+		for _, cs := range stack {
+			s := cs.Function.String()
+			if cs.Call != nil && cs.Call.Pos != nil {
+				p := cs.Call.Pos
+				s = fmt.Sprintf("%s[%s:%d:%d:%d]", s, p.Filename, p.Line, p.Column, p.Offset)
+			}
+			stackStr = append(stackStr, s)
+		}
+		return strings.Join(stackStr, "->")
+	}
+	return strings.Compare(stackStr(s1), stackStr(s2)) <= 0
+}
diff --git a/vulncheck/witness_test.go b/vulncheck/witness_test.go
index eef6cf2..348b86d 100644
--- a/vulncheck/witness_test.go
+++ b/vulncheck/witness_test.go
@@ -1,3 +1,7 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
 package vulncheck
 
 import (
@@ -24,6 +28,24 @@
 	return m
 }
 
+// stacksToString converts map *Vuln:stacks to Vuln.Symbol:["f1->...->fN", ...]
+// string representation.
+func stacksToString(stacks map[*Vuln][]CallStack) map[string][]string {
+	m := make(map[string][]string)
+	for v, sts := range stacks {
+		var stsStr []string
+		for _, st := range sts {
+			var stStr []string
+			for _, call := range st {
+				stStr = append(stStr, call.Function.Name)
+			}
+			stsStr = append(stsStr, strings.Join(stStr, "->"))
+		}
+		m[v.Symbol] = stsStr
+	}
+	return m
+}
+
 func TestImportChains(t *testing.T) {
 	// Package import structure for the test program
 	//    entry1  entry2
@@ -61,3 +83,38 @@
 		t.Errorf("want %v; got %v", want, got)
 	}
 }
+
+func TestCallStacks(t *testing.T) {
+	// Call graph structure for the test program
+	//    entry1      entry2
+	//      |           |
+	//    interm1(std)  |
+	//      |    \     /
+	//      |   interm2(interface)
+	//      |   /     |
+	//     vuln1    vuln2
+	e1 := &FuncNode{ID: 1, Name: "entry1"}
+	e2 := &FuncNode{ID: 2, Name: "entry2"}
+	i1 := &FuncNode{ID: 3, Name: "interm1", PkgPath: "net/http", CallSites: []*CallSite{&CallSite{Parent: 1, Resolved: true}}}
+	i2 := &FuncNode{ID: 4, Name: "interm2", CallSites: []*CallSite{&CallSite{Parent: 2, Resolved: true}, &CallSite{Parent: 3, Resolved: true}}}
+	v1 := &FuncNode{ID: 5, Name: "vuln1", CallSites: []*CallSite{&CallSite{Parent: 3, Resolved: true}, &CallSite{Parent: 4, Resolved: false}}}
+	v2 := &FuncNode{ID: 6, Name: "vuln2", CallSites: []*CallSite{&CallSite{Parent: 4, Resolved: false}}}
+
+	cg := &CallGraph{
+		Functions: map[int]*FuncNode{1: e1, 2: e2, 3: i1, 4: i2, 5: v1, 6: v2},
+		Entries:   []int{1, 2},
+	}
+	vuln1 := &Vuln{CallSink: 5, Symbol: "vuln1"}
+	vuln2 := &Vuln{CallSink: 6, Symbol: "vuln2"}
+	res := &Result{Calls: cg, Vulns: []*Vuln{vuln1, vuln2}}
+
+	want := map[string][]string{
+		"vuln1": {"entry2->interm2->vuln1", "entry1->interm1->vuln1"},
+		"vuln2": {"entry2->interm2->vuln2", "entry1->interm1->interm2->vuln2"},
+	}
+
+	stacks := CallStacks(res)
+	if got := stacksToString(stacks); !reflect.DeepEqual(want, got) {
+		t.Errorf("want %v; got %v", want, got)
+	}
+}