vulncheck: add support for Source imports only mode

Cherry-picked: https://go-review.googlesource.com/c/exp/+/359174

Change-Id: Ib05f3d7bbc6e32af7a1311ec5e10625a22f8809d
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395038
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/vulncheck/fetch.go b/vulncheck/fetch.go
index d028513..b065fe7 100644
--- a/vulncheck/fetch.go
+++ b/vulncheck/fetch.go
@@ -97,7 +97,14 @@
 	return mv, nil
 }
 
+// fetchingInTesting is a flag used to avoid skipping
+// loading local vulnerabilities in testing.
+var fetchingInTesting bool = false
+
 func isLocal(mod *packages.Module) bool {
+	if fetchingInTesting {
+		return false
+	}
 	modDir := mod.Dir
 	if mod.Replace != nil {
 		modDir = mod.Replace.Dir
diff --git a/vulncheck/helpers_test.go b/vulncheck/helpers_test.go
index 95eaa24..8aa6d18 100644
--- a/vulncheck/helpers_test.go
+++ b/vulncheck/helpers_test.go
@@ -7,6 +7,8 @@
 import (
 	"fmt"
 
+	"golang.org/x/tools/go/packages"
+	"golang.org/x/tools/go/packages/packagestest"
 	"golang.org/x/vulndb/osv"
 )
 
@@ -22,6 +24,34 @@
 	return nil, nil
 }
 
+// testClient contains the following test vulnerabilities
+//   golang.org/amod/avuln.{VulnData.Vuln1, vulnData.Vuln2}
+//   golang.org/bmod/bvuln.{Vuln}
+var testClient = &mockClient{
+	ret: map[string][]*osv.Entry{
+		"golang.org/amod": []*osv.Entry{
+			{
+				ID: "VA",
+				Affected: []osv.Affected{{
+					Package:           osv.Package{Name: "golang.org/amod/avuln"},
+					Ranges:            osv.Affects{{Type: osv.TypeSemver, Events: []osv.RangeEvent{{Introduced: "1.0.0"}, {Fixed: "1.0.4"}, {Introduced: "1.1.2"}}}},
+					EcosystemSpecific: osv.EcosystemSpecific{Symbols: []string{"VulnData.Vuln1", "VulnData.Vuln2"}},
+				}},
+			},
+		},
+		"golang.org/bmod": []*osv.Entry{
+			{
+				ID: "VB",
+				Affected: []osv.Affected{{
+					Package:           osv.Package{Name: "golang.org/bmod/bvuln"},
+					Ranges:            osv.Affects{{Type: osv.TypeSemver}},
+					EcosystemSpecific: osv.EcosystemSpecific{Symbols: []string{"Vuln"}},
+				}},
+			},
+		},
+	},
+}
+
 func moduleVulnerabilitiesToString(mv moduleVulnerabilities) string {
 	var s string
 	for _, m := range mv {
@@ -40,3 +70,21 @@
 	}
 	return s
 }
+
+func impGraphToStrMap(ig *ImportGraph) map[string][]string {
+	m := make(map[string][]string)
+	for _, n := range ig.Packages {
+		for _, predId := range n.ImportedBy {
+			pred := ig.Packages[predId]
+			m[pred.Path] = append(m[pred.Path], n.Path)
+		}
+	}
+	return m
+}
+
+func loadPackages(e *packagestest.Exported, patterns ...string) ([]*packages.Package, error) {
+	e.Config.Mode |= packages.NeedModule | packages.NeedName | packages.NeedFiles |
+		packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedTypes |
+		packages.NeedTypesSizes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedDeps
+	return packages.Load(e.Config, patterns...)
+}
diff --git a/vulncheck/source.go b/vulncheck/source.go
new file mode 100644
index 0000000..46eab37
--- /dev/null
+++ b/vulncheck/source.go
@@ -0,0 +1,123 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+	"golang.org/x/tools/go/packages"
+)
+
+// Source detects vulnerabilities in pkgs and computes slices of
+//  - imports graph related to an import of a package with some
+//    known vulnerabilities
+//  - requires graph related to a require of a module with a
+//    package that has some known vulnerabilities
+//  - call graph leading to the use of a known vulnerable function
+//    or method
+func Source(pkgs []*packages.Package, cfg *Config) (*Result, error) {
+	if !cfg.ImportsOnly {
+		panic("call graph feature is currently unsupported")
+	}
+
+	modVulns, err := fetchVulnerabilities(cfg.Client, extractModules(pkgs))
+	if err != nil {
+		return nil, err
+	}
+
+	result := &Result{
+		Imports:  &ImportGraph{Packages: make(map[int]*PkgNode)},
+		Requires: &RequireGraph{Modules: make(map[int]*ModNode)},
+	}
+	vulnPkgImportSlice(pkgs, modVulns, result)
+	// TODO(zpavlinovic): compute module and call graph slice.
+	return result, nil
+}
+
+// pkgId is an id counter for nodes of Imports graph.
+var pkgID int = 0
+
+func nextPkgID() int {
+	pkgID += 1
+	return pkgID
+}
+
+// vulnPkgImportSlice computes the slice of pkg imports graph leading to imports of vulnerable
+// packages in modVulns and stores the slice to result.
+func vulnPkgImportSlice(pkgs []*packages.Package, modVulns moduleVulnerabilities, result *Result) {
+	// analyzed contains information on packages analyzed thus far.
+	// If a package is mapped to nil, this means it has been visited
+	// but it does not lead to a vulnerable imports. Otherwise, a
+	// visited package is mapped to Imports package node.
+	analyzed := make(map[*packages.Package]*PkgNode)
+	for _, pkg := range pkgs {
+		// Top level packages that lead to vulnerable imports are
+		// stored as result.Imports graph entry points.
+		if e := vulnImportSlice(pkg, modVulns, result, analyzed); e != nil {
+			result.Imports.Entries = append(result.Imports.Entries, e)
+		}
+	}
+}
+
+// vulnImportSlice checks if pkg has some vulnerabilities or transitively imports
+// a package with known vulnerabilities. If that is the case, populates result.Imports
+// graph with this reachability information and returns the result.Imports package
+// node for pkg. Otherwise, returns nil.
+func vulnImportSlice(pkg *packages.Package, modVulns moduleVulnerabilities, result *Result, analyzed map[*packages.Package]*PkgNode) *PkgNode {
+	if pn, ok := analyzed[pkg]; ok {
+		return pn
+	}
+	analyzed[pkg] = nil
+	// Recursively compute which direct dependencies lead to an import of
+	// a vulnerable package and remember the nodes of such dependencies.
+	var onSlice []*PkgNode
+	for _, imp := range pkg.Imports {
+		if impNode := vulnImportSlice(imp, modVulns, result, analyzed); impNode != nil {
+			onSlice = append(onSlice, impNode)
+		}
+	}
+
+	// Check if pkg has known vulnerabilities.
+	vulns := modVulns.VulnsForPackage(pkg.PkgPath)
+
+	// If pkg is not vulnerable nor it transitively leads
+	// to vulnerabilities, jump out.
+	if len(onSlice) == 0 && len(vulns) == 0 {
+		return nil
+	}
+
+	// Module id gets populated later.
+	pkgNode := &PkgNode{
+		Name: pkg.Name,
+		Path: pkg.PkgPath,
+	}
+	analyzed[pkg] = pkgNode
+
+	id := nextPkgID()
+	result.Imports.Packages[id] = pkgNode
+
+	// Save node predecessor information.
+	for _, impSliceNode := range onSlice {
+		impSliceNode.ImportedBy = append(impSliceNode.ImportedBy, id)
+	}
+
+	// Create Vuln entry for each symbol of known OSV entries for pkg.
+	for _, osv := range vulns {
+		for _, affected := range osv.Affected {
+			if affected.Package.Name != pkgNode.Path {
+				continue
+			}
+			for _, symbol := range affected.EcosystemSpecific.Symbols {
+				vuln := &Vuln{
+					OSV:        osv,
+					Symbol:     symbol,
+					PkgPath:    pkgNode.Path,
+					ModPath:    modPath(pkg.Module),
+					ImportSink: id,
+				}
+				result.Vulns = append(result.Vulns, vuln)
+			}
+		}
+	}
+	return pkgNode
+}
diff --git a/vulncheck/source_test.go b/vulncheck/source_test.go
new file mode 100644
index 0000000..ef8a850
--- /dev/null
+++ b/vulncheck/source_test.go
@@ -0,0 +1,148 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package vulncheck
+
+import (
+	"path"
+	"reflect"
+	"testing"
+
+	"golang.org/x/tools/go/packages/packagestest"
+)
+
+// TestImportsOnly checks for module and imports graph correctness
+// for the Config.ImportsOnly=true mode. The inlined test code has
+// the following package (left) and module (right) imports graphs:
+//
+//       entry/x        entry/y                     entry
+//              \     /        \                   /     \
+//            amod/avuln      zmod/z           amod       zmod
+//                |                              |
+//              wmod/w                         wmod
+//                |                              |
+//            bmod/bvuln                       bmod
+//
+// Packages ending in "vuln" have some known vulnerabilities.
+func TestImportsOnly(t *testing.T) {
+	e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
+		{
+			Name: "golang.org/entry",
+			Files: map[string]interface{}{
+				"x/x.go": `
+			package x
+
+			import "golang.org/amod/avuln"
+
+			func X() {
+				avuln.VulnData{}.Vuln1()
+			}
+			`,
+				"y/y.go": `
+			package y
+
+			import (
+				"golang.org/amod/avuln"
+				"golang.org/zmod/z"
+			)
+
+			func Y() {
+				avuln.VulnData{}.Vuln2()
+				z.Z()
+			}
+		`}},
+		{
+			Name: "golang.org/zmod@v0.0.0",
+			Files: map[string]interface{}{"z/z.go": `
+			package z
+
+			func Z() {}
+			`},
+		},
+		{
+			Name: "golang.org/amod@v1.1.3",
+			Files: map[string]interface{}{"avuln/avuln.go": `
+			package avuln
+
+			import "golang.org/wmod/w"
+
+			type VulnData struct {}
+			func (v VulnData) Vuln1() { w.W() }
+			func (v VulnData) Vuln2() {}
+			`},
+		},
+		{
+			Name: "golang.org/bmod@v0.5.0",
+			Files: map[string]interface{}{"bvuln/bvuln.go": `
+			package bvuln
+
+			func Vuln() {}
+			`},
+		},
+		{
+			Name: "golang.org/wmod@v0.0.0",
+			Files: map[string]interface{}{"w/w.go": `
+			package w
+
+			import "golang.org/bmod/bvuln"
+
+			func W() { bvuln.Vuln() }
+			`},
+		},
+	})
+	defer e.Cleanup()
+
+	// Make sure local vulns can be loaded.
+	fetchingInTesting = true
+	// Load x and y as entry packages.
+	pkgs, err := loadPackages(e, path.Join(e.Temp(), "entry/x"), path.Join(e.Temp(), "entry/y"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(pkgs) != 2 {
+		t.Fatal("failed to load x and y test packages")
+	}
+
+	cfg := &Config{
+		Client:      testClient,
+		ImportsOnly: true,
+	}
+	result, err := Source(pkgs, cfg)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// TODO(zpavlinovic): add test below for module graph too.
+
+	// Check that we find the right number of vulnerabilities.
+	// There should be three entries as there are three vulnerable
+	// symbols in the two import-reachable OSVs.
+	if len(result.Vulns) != 3 {
+		t.Errorf("want 3 Vulns, got %d", len(result.Vulns))
+	}
+
+	// Check that vulnerabilities are connected to the imports graph.
+	for _, v := range result.Vulns {
+		if v.ImportSink == 0 {
+			t.Errorf("want ImportSink !=0 for vuln %v:%v; got 0", v.Symbol, v.PkgPath)
+		}
+	}
+
+	// The slice should include import chains:
+	//   x -> avuln -> w -> bvuln
+	//         |
+	//   y ---->
+	// That is, z package shoud not appear in the slice.
+	wantImports := map[string][]string{
+		"golang.org/entry/x":    {"golang.org/amod/avuln"},
+		"golang.org/entry/y":    {"golang.org/amod/avuln"},
+		"golang.org/amod/avuln": {"golang.org/wmod/w"},
+		"golang.org/wmod/w":     {"golang.org/bmod/bvuln"},
+	}
+
+	if igStrMap := impGraphToStrMap(result.Imports); !reflect.DeepEqual(wantImports, igStrMap) {
+		t.Errorf("want %v imports graph; got %v", wantImports, igStrMap)
+	}
+}