vulncheck: add witness search logic for import chains

Cherry-picked: https://go-review.googlesource.com/c/exp/+/379994

Change-Id: I3afa6b61b723fe461f8483bfaabcbc75af7c9104
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395057
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/vulncheck/witness.go b/vulncheck/witness.go
index 8e2d420..0ef7003 100644
--- a/vulncheck/witness.go
+++ b/vulncheck/witness.go
@@ -4,11 +4,114 @@
 
 package vulncheck
 
+import (
+	"container/list"
+	"sync"
+)
+
 // ImportChain is sequence of import paths starting with
 // a client package and ending with a package with some
 // known vulnerabilities.
 type ImportChain []*PkgNode
 
+// ImportChains performs a BFS search of res.RequireGraph for imports of vulnerable
+// packages. Search is performed for each vulnerable package in res.Vulns. The search
+// starts at a vulnerable package and goes up until reaching an entry package in
+// res.ImportGraph.Entries, hence producing an import chain. During the search, a
+// package is visited only once to avoid analyzing every possible import chain.
+// Hence, not all possible vulnerable import chains are reported.
+//
+// Note that the resulting map produces an import chain for each Vuln. Thus, a Vuln
+// with the same PkgPath will have the same list of identified import chains.
+//
+// The reported import chains are ordered by how seemingly easy is to understand
+// them. Shorter import chains appear earlier in the returned slices.
+func ImportChains(res *Result) map[*Vuln][]ImportChain {
+	// Group vulns per package.
+	vPerPkg := make(map[int][]*Vuln)
+	for _, v := range res.Vulns {
+		vPerPkg[v.ImportSink] = append(vPerPkg[v.ImportSink], v)
+	}
+
+	// Collect chains in parallel for every package path.
+	var wg sync.WaitGroup
+	var mu sync.Mutex
+	chains := make(map[*Vuln][]ImportChain)
+	for pkgID, vulns := range vPerPkg {
+		pID := pkgID
+		vs := vulns
+		wg.Add(1)
+		go func() {
+			pChains := importChains(pID, res)
+			mu.Lock()
+			for _, v := range vs {
+				chains[v] = pChains
+			}
+			mu.Unlock()
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+	return chains
+}
+
+// importChains finds representative chains of package imports
+// leading to vulnerable package identified with vulnSinkID.
+func importChains(vulnSinkID int, res *Result) []ImportChain {
+	if vulnSinkID == 0 {
+		return nil
+	}
+
+	// Entry packages, needed for finalizing chains.
+	entries := make(map[int]bool)
+	for _, e := range res.Imports.Entries {
+		entries[e] = true
+	}
+
+	var chains []ImportChain
+	seen := make(map[int]bool)
+
+	queue := list.New()
+	queue.PushBack(&importChain{pkg: res.Imports.Packages[vulnSinkID]})
+	for queue.Len() > 0 {
+		front := queue.Front()
+		c := front.Value.(*importChain)
+		queue.Remove(front)
+
+		pkg := c.pkg
+		if seen[pkg.ID] {
+			continue
+		}
+		seen[pkg.ID] = true
+
+		for _, impBy := range pkg.ImportedBy {
+			imp := res.Imports.Packages[impBy]
+			newC := &importChain{pkg: imp, child: c}
+			// If the next package is an entry, we have
+			// a chain to report.
+			if entries[imp.ID] {
+				chains = append(chains, newC.ImportChain())
+			}
+			queue.PushBack(newC)
+		}
+	}
+	return chains
+}
+
+// importChain models an chain of package imports.
+type importChain struct {
+	pkg   *PkgNode
+	child *importChain
+}
+
+// ImportChain converts importChain to ImportChain type.
+func (r *importChain) ImportChain() ImportChain {
+	if r == nil {
+		return nil
+	}
+	return append([]*PkgNode{r.pkg}, r.child.ImportChain()...)
+}
+
 // CallStack models a trace of function calls starting
 // with a client function or method and ending with a
 // call to a vulnerable symbol.
diff --git a/vulncheck/witness_test.go b/vulncheck/witness_test.go
new file mode 100644
index 0000000..eef6cf2
--- /dev/null
+++ b/vulncheck/witness_test.go
@@ -0,0 +1,63 @@
+package vulncheck
+
+import (
+	"reflect"
+	"strings"
+	"testing"
+)
+
+// chainsToString converts map Vuln:chains to Vuln.PkgPath:["pkg1->...->pkgN", ...]
+// string representation.
+func chainsToString(chains map[*Vuln][]ImportChain) map[string][]string {
+	m := make(map[string][]string)
+	for v, chs := range chains {
+		var chsStr []string
+		for _, ch := range chs {
+			var chStr []string
+			for _, imp := range ch {
+				chStr = append(chStr, imp.Path)
+			}
+			chsStr = append(chsStr, strings.Join(chStr, "->"))
+		}
+		m[v.PkgPath] = chsStr
+	}
+	return m
+}
+
+func TestImportChains(t *testing.T) {
+	// Package import structure for the test program
+	//    entry1  entry2
+	//      |       |
+	//    interm1   |
+	//      |    \  |
+	//      |   interm2
+	//      |   /     |
+	//     vuln1    vuln2
+	e1 := &PkgNode{ID: 1, Path: "entry1"}
+	e2 := &PkgNode{ID: 2, Path: "entry2"}
+	i1 := &PkgNode{ID: 3, Path: "interm1", ImportedBy: []int{1}}
+	i2 := &PkgNode{ID: 4, Path: "interm2", ImportedBy: []int{2, 3}}
+	v1 := &PkgNode{ID: 5, Path: "vuln1", ImportedBy: []int{3, 4}}
+	v2 := &PkgNode{ID: 6, Path: "vuln2", ImportedBy: []int{4}}
+
+	ig := &ImportGraph{
+		Packages: map[int]*PkgNode{1: e1, 2: e2, 3: i1, 4: i2, 5: v1, 6: v2},
+		Entries:  []int{1, 2},
+	}
+	vuln1 := &Vuln{ImportSink: 5, PkgPath: "vuln1"}
+	vuln2 := &Vuln{ImportSink: 6, PkgPath: "vuln2"}
+	res := &Result{Imports: ig, Vulns: []*Vuln{vuln1, vuln2}}
+
+	// The chain entry1->interm1->interm2->vuln1 is not reported
+	// as there exist a shorter trace going from entry1 to vuln1
+	// via interm1.
+	want := map[string][]string{
+		"vuln1": {"entry1->interm1->vuln1", "entry2->interm2->vuln1"},
+		"vuln2": {"entry2->interm2->vuln2", "entry1->interm1->interm2->vuln2"},
+	}
+
+	chains := ImportChains(res)
+	if got := chainsToString(chains); !reflect.DeepEqual(want, got) {
+		t.Errorf("want %v; got %v", want, got)
+	}
+}