vulncheck: add witness search logic for import chains
Cherry-picked: https://go-review.googlesource.com/c/exp/+/379994
Change-Id: I3afa6b61b723fe461f8483bfaabcbc75af7c9104
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395057
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/vulncheck/witness.go b/vulncheck/witness.go
index 8e2d420..0ef7003 100644
--- a/vulncheck/witness.go
+++ b/vulncheck/witness.go
@@ -4,11 +4,114 @@
package vulncheck
+import (
+ "container/list"
+ "sync"
+)
+
// ImportChain is sequence of import paths starting with
// a client package and ending with a package with some
// known vulnerabilities.
type ImportChain []*PkgNode
+// ImportChains performs a BFS search of res.RequireGraph for imports of vulnerable
+// packages. Search is performed for each vulnerable package in res.Vulns. The search
+// starts at a vulnerable package and goes up until reaching an entry package in
+// res.ImportGraph.Entries, hence producing an import chain. During the search, a
+// package is visited only once to avoid analyzing every possible import chain.
+// Hence, not all possible vulnerable import chains are reported.
+//
+// Note that the resulting map produces an import chain for each Vuln. Thus, a Vuln
+// with the same PkgPath will have the same list of identified import chains.
+//
+// The reported import chains are ordered by how seemingly easy is to understand
+// them. Shorter import chains appear earlier in the returned slices.
+func ImportChains(res *Result) map[*Vuln][]ImportChain {
+ // Group vulns per package.
+ vPerPkg := make(map[int][]*Vuln)
+ for _, v := range res.Vulns {
+ vPerPkg[v.ImportSink] = append(vPerPkg[v.ImportSink], v)
+ }
+
+ // Collect chains in parallel for every package path.
+ var wg sync.WaitGroup
+ var mu sync.Mutex
+ chains := make(map[*Vuln][]ImportChain)
+ for pkgID, vulns := range vPerPkg {
+ pID := pkgID
+ vs := vulns
+ wg.Add(1)
+ go func() {
+ pChains := importChains(pID, res)
+ mu.Lock()
+ for _, v := range vs {
+ chains[v] = pChains
+ }
+ mu.Unlock()
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+ return chains
+}
+
+// importChains finds representative chains of package imports
+// leading to vulnerable package identified with vulnSinkID.
+func importChains(vulnSinkID int, res *Result) []ImportChain {
+ if vulnSinkID == 0 {
+ return nil
+ }
+
+ // Entry packages, needed for finalizing chains.
+ entries := make(map[int]bool)
+ for _, e := range res.Imports.Entries {
+ entries[e] = true
+ }
+
+ var chains []ImportChain
+ seen := make(map[int]bool)
+
+ queue := list.New()
+ queue.PushBack(&importChain{pkg: res.Imports.Packages[vulnSinkID]})
+ for queue.Len() > 0 {
+ front := queue.Front()
+ c := front.Value.(*importChain)
+ queue.Remove(front)
+
+ pkg := c.pkg
+ if seen[pkg.ID] {
+ continue
+ }
+ seen[pkg.ID] = true
+
+ for _, impBy := range pkg.ImportedBy {
+ imp := res.Imports.Packages[impBy]
+ newC := &importChain{pkg: imp, child: c}
+ // If the next package is an entry, we have
+ // a chain to report.
+ if entries[imp.ID] {
+ chains = append(chains, newC.ImportChain())
+ }
+ queue.PushBack(newC)
+ }
+ }
+ return chains
+}
+
+// importChain models an chain of package imports.
+type importChain struct {
+ pkg *PkgNode
+ child *importChain
+}
+
+// ImportChain converts importChain to ImportChain type.
+func (r *importChain) ImportChain() ImportChain {
+ if r == nil {
+ return nil
+ }
+ return append([]*PkgNode{r.pkg}, r.child.ImportChain()...)
+}
+
// CallStack models a trace of function calls starting
// with a client function or method and ending with a
// call to a vulnerable symbol.
diff --git a/vulncheck/witness_test.go b/vulncheck/witness_test.go
new file mode 100644
index 0000000..eef6cf2
--- /dev/null
+++ b/vulncheck/witness_test.go
@@ -0,0 +1,63 @@
+package vulncheck
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+)
+
+// chainsToString converts map Vuln:chains to Vuln.PkgPath:["pkg1->...->pkgN", ...]
+// string representation.
+func chainsToString(chains map[*Vuln][]ImportChain) map[string][]string {
+ m := make(map[string][]string)
+ for v, chs := range chains {
+ var chsStr []string
+ for _, ch := range chs {
+ var chStr []string
+ for _, imp := range ch {
+ chStr = append(chStr, imp.Path)
+ }
+ chsStr = append(chsStr, strings.Join(chStr, "->"))
+ }
+ m[v.PkgPath] = chsStr
+ }
+ return m
+}
+
+func TestImportChains(t *testing.T) {
+ // Package import structure for the test program
+ // entry1 entry2
+ // | |
+ // interm1 |
+ // | \ |
+ // | interm2
+ // | / |
+ // vuln1 vuln2
+ e1 := &PkgNode{ID: 1, Path: "entry1"}
+ e2 := &PkgNode{ID: 2, Path: "entry2"}
+ i1 := &PkgNode{ID: 3, Path: "interm1", ImportedBy: []int{1}}
+ i2 := &PkgNode{ID: 4, Path: "interm2", ImportedBy: []int{2, 3}}
+ v1 := &PkgNode{ID: 5, Path: "vuln1", ImportedBy: []int{3, 4}}
+ v2 := &PkgNode{ID: 6, Path: "vuln2", ImportedBy: []int{4}}
+
+ ig := &ImportGraph{
+ Packages: map[int]*PkgNode{1: e1, 2: e2, 3: i1, 4: i2, 5: v1, 6: v2},
+ Entries: []int{1, 2},
+ }
+ vuln1 := &Vuln{ImportSink: 5, PkgPath: "vuln1"}
+ vuln2 := &Vuln{ImportSink: 6, PkgPath: "vuln2"}
+ res := &Result{Imports: ig, Vulns: []*Vuln{vuln1, vuln2}}
+
+ // The chain entry1->interm1->interm2->vuln1 is not reported
+ // as there exist a shorter trace going from entry1 to vuln1
+ // via interm1.
+ want := map[string][]string{
+ "vuln1": {"entry1->interm1->vuln1", "entry2->interm2->vuln1"},
+ "vuln2": {"entry2->interm2->vuln2", "entry1->interm1->interm2->vuln2"},
+ }
+
+ chains := ImportChains(res)
+ if got := chainsToString(chains); !reflect.DeepEqual(want, got) {
+ t.Errorf("want %v; got %v", want, got)
+ }
+}