vulndb/govulncheck: surface unreachable vulns

Surface vulnerabilities in imported packages which are not reachable in
the call graph. Additionally we pre-filter vulnerabilities which do not
apply to the versions used, skipping unnecessary analysis.

Change-Id: If845a376406cd079a5f96935f419e6af5eabd76c
Reviewed-on: https://go-review.googlesource.com/c/exp/+/335171
Trust: Roland Shoemaker <roland@golang.org>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
Run-TryBot: Roland Shoemaker <roland@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/vulndb/govulncheck/main.go b/vulndb/govulncheck/main.go
index 7232c28..7c9d185 100644
--- a/vulndb/govulncheck/main.go
+++ b/vulndb/govulncheck/main.go
@@ -30,6 +30,7 @@
 	"golang.org/x/exp/vulndb/internal/binscan"
 	"golang.org/x/tools/go/packages"
 	"golang.org/x/tools/go/ssa/ssautil"
+	"golang.org/x/vulndb/osv"
 )
 
 var (
@@ -66,6 +67,60 @@
 databases are merged.
 `
 
+type results struct {
+	ImportedPackages []string
+	Vulns            []*osv.Entry
+	Findings         []audit.Finding
+}
+
+func (r *results) unreachable() []*osv.Entry {
+	seen := map[string]bool{}
+	for _, f := range r.Findings {
+		for _, v := range f.Vulns {
+			seen[v.ID] = true
+		}
+	}
+	unseen := []*osv.Entry{}
+	for _, v := range r.Vulns {
+		if seen[v.ID] {
+			continue
+		}
+		unseen = append(unseen, v)
+	}
+	return unseen
+}
+
+// presentTo pretty-prints results to out.
+func (r *results) presentTo(out io.Writer) {
+	sort.Strings(r.ImportedPackages)
+	sort.Slice(r.Vulns, func(i, j int) bool { return r.Vulns[i].ID < r.Vulns[j].ID })
+	sort.SliceStable(r.Findings, func(i int, j int) bool { return audit.FindingCompare(r.Findings[i], r.Findings[j]) })
+	if !*jsonFlag {
+		for _, finding := range r.Findings {
+			finding.Write(out)
+			out.Write([]byte{'\n'})
+		}
+		if unreachable := r.unreachable(); len(unreachable) > 0 {
+			fmt.Fprintf(out, "The following %d vulnerabilities don't affect this project:\n", len(unreachable))
+			for _, u := range unreachable {
+				var aliases string
+				if len(u.Aliases) > 0 {
+					aliases = fmt.Sprintf(" (%s)", strings.Join(u.Aliases, ", "))
+				}
+				fmt.Fprintf(out, "- %s%s (package imported, but vulnerable symbol is not reachable)\n", u.ID, aliases)
+			}
+		}
+		return
+	}
+	b, err := json.MarshalIndent(r, "", "\t")
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "govulncheck: %s\n", err)
+		os.Exit(1)
+	}
+	out.Write(b)
+	out.Write([]byte{'\n'})
+}
+
 func main() {
 	flag.Usage = func() { fmt.Fprintln(os.Stderr, usage) }
 	flag.Parse()
@@ -84,32 +139,13 @@
 		Mode: packages.LoadAllSyntax | packages.NeedModule,
 	}
 
-	findings, err := run(cfg, flag.Args(), *importsFlag, dbs)
+	r, err := run(cfg, flag.Args(), *importsFlag, dbs)
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "govulncheck: %s\n", err)
 		os.Exit(1)
 	}
 
-	sort.SliceStable(findings, func(i int, j int) bool { return audit.FindingCompare(findings[i], findings[j]) })
-	presentTo(os.Stdout, findings)
-}
-
-// presentTo pretty-prints findings to out.
-func presentTo(out io.Writer, findings []audit.Finding) {
-	if !*jsonFlag {
-		for _, finding := range findings {
-			finding.Write(out)
-			out.Write([]byte{'\n'})
-		}
-		return
-	}
-	b, err := json.MarshalIndent(findings, "", "\t")
-	if err != nil {
-		fmt.Fprintf(os.Stderr, "govulncheck: %s\n", err)
-		os.Exit(1)
-	}
-	out.Write(b)
-	out.Write([]byte{'\n'})
+	r.presentTo(os.Stdout)
 }
 
 // allPkgPaths computes a list of all packages, in
@@ -145,7 +181,20 @@
 	return !s.IsDir()
 }
 
-func run(cfg *packages.Config, patterns []string, importsOnly bool, dbs []string) ([]audit.Finding, error) {
+func filterVulns(vulns []*osv.Entry, packageVersions map[string]string) []*osv.Entry {
+	filtered := []*osv.Entry{}
+	for _, v := range vulns {
+		version, ok := packageVersions[v.Package.Name]
+		if !ok || !v.Affects.AffectsSemver(version) {
+			continue
+		}
+		filtered = append(filtered, v)
+	}
+	return filtered
+}
+
+func run(cfg *packages.Config, patterns []string, importsOnly bool, dbs []string) (*results, error) {
+	r := &results{}
 	if len(patterns) == 1 && isFile(patterns[0]) {
 		packages, symbols, err := binscan.ExtractPackagesAndSymbols(patterns[0])
 		if err != nil {
@@ -156,14 +205,20 @@
 		for pkg := range packages {
 			paths = append(paths, pkg)
 		}
+		r.ImportedPackages = paths
 
 		vulns, err := audit.LoadVulnerabilities(dbs, paths)
 		if err != nil {
 			return nil, fmt.Errorf("failed to load vulnerability dbs: %v", err)
 		}
-		env := audit.Env{OS: runtime.GOOS, Arch: runtime.GOARCH, PkgVersions: packages, Vulns: vulns}
+		vulns = filterVulns(vulns, packages)
+		if len(vulns) == 0 {
+			return r, nil
+		}
+		r.Vulns = vulns
 
-		return audit.VulnerablePackageSymbols(symbols, env), nil
+		r.Findings = audit.VulnerablePackageSymbols(symbols, audit.Env{OS: runtime.GOOS, Arch: runtime.GOARCH, PkgVersions: packages, Vulns: vulns})
+		return r, nil
 	}
 
 	// Load packages.
@@ -181,22 +236,29 @@
 		log.Printf("\t%d loaded packages\n", len(pkgs))
 	}
 
+	// Load package versions.
+	pkgVersions := audit.PackageVersions(pkgs)
+
 	// Load database.
 	if *verboseFlag {
 		log.Println("loading database...")
 	}
-	vulns, err := audit.LoadVulnerabilities(dbs, allPkgPaths(pkgs))
+	importedPackages := allPkgPaths(pkgs)
+	r.ImportedPackages = importedPackages
+	vulns, err := audit.LoadVulnerabilities(dbs, importedPackages)
 	if err != nil {
 		return nil, fmt.Errorf("failed to load vulnerability dbs: %v", err)
 	}
+	vulns = filterVulns(vulns, pkgVersions)
+	if len(vulns) == 0 {
+		return r, nil
+	}
+	r.Vulns = vulns
 
 	if *verboseFlag {
 		log.Printf("\t%d known vulnerabilities.\n", len(vulns))
 	}
 
-	// Load package versions.
-	pkgVersions := audit.PackageVersions(pkgs)
-
 	// Load SSA.
 	if *verboseFlag {
 		log.Println("building ssa...")
@@ -214,12 +276,13 @@
 	var findings []audit.Finding
 	env := audit.Env{OS: runtime.GOOS, Arch: runtime.GOARCH, PkgVersions: pkgVersions, Vulns: vulns}
 	if importsOnly {
-		findings = audit.VulnerableImports(ssaPkgs, env)
+		r.Findings = audit.VulnerableImports(ssaPkgs, env)
 	} else {
-		findings = audit.VulnerableSymbols(ssaPkgs, env)
+		r.Findings = audit.VulnerableSymbols(ssaPkgs, env)
 	}
 	if *verboseFlag {
 		log.Printf("\t%d detected findings.\n", len(findings))
 	}
-	return findings, nil
+
+	return r, nil
 }
diff --git a/vulndb/govulncheck/main_test.go b/vulndb/govulncheck/main_test.go
index 6042b28..ad09378 100644
--- a/vulndb/govulncheck/main_test.go
+++ b/vulndb/govulncheck/main_test.go
@@ -15,6 +15,7 @@
 	"os/exec"
 	"path"
 	"path/filepath"
+	"reflect"
 	"runtime"
 	"sort"
 	"strings"
@@ -23,6 +24,7 @@
 	"golang.org/x/exp/vulndb/internal/audit"
 	"golang.org/x/tools/go/packages"
 	"golang.org/x/tools/go/packages/packagestest"
+	"golang.org/x/vulndb/osv"
 )
 
 // TODO(zpavlinovic): improve integration tests.
@@ -210,12 +212,12 @@
 			}
 		}
 
-		finds, err := run(cfg, []string{hashiVaultOkta}, false, []string{test.source})
+		r, err := run(cfg, []string{hashiVaultOkta}, false, []string{test.source})
 		if err != nil {
 			t.Fatal(err)
 		}
-		sort.SliceStable(finds, func(i int, j int) bool { return audit.FindingCompare(finds[i], finds[j]) })
-		if fs := testFindings(finds); !subset(test.want, fs) {
+		sort.SliceStable(r.Findings, func(i int, j int) bool { return audit.FindingCompare(r.Findings[i], r.Findings[j]) })
+		if fs := testFindings(r.Findings); !subset(test.want, fs) {
 			t.Errorf("want %v subset of findings; got %v", test.want, fs)
 		}
 	}
@@ -362,13 +364,64 @@
 			}
 		}
 
-		finds, err := run(cfg, []string{"./..."}, false, []string{test.source})
+		r, err := run(cfg, []string{"./..."}, false, []string{test.source})
 		if err != nil {
 			t.Fatal(err)
 		}
-		sort.SliceStable(finds, func(i int, j int) bool { return audit.FindingCompare(finds[i], finds[j]) })
-		if fs := testFindings(finds); !subset(test.want, fs) {
+		sort.SliceStable(r.Findings, func(i int, j int) bool { return audit.FindingCompare(r.Findings[i], r.Findings[j]) })
+		if fs := testFindings(r.Findings); !subset(test.want, fs) {
 			t.Errorf("want %v subset of findings; got %v", test.want, fs)
 		}
 	}
 }
+
+func vulnsToString(vulns []*osv.Entry) string {
+	var s string
+	for _, v := range vulns {
+		s += fmt.Sprintf("\t%v\n", v)
+	}
+	return s
+}
+
+func TestFilterVulsn(t *testing.T) {
+	vulns := []*osv.Entry{
+		{Package: osv.Package{Name: "example.com/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "1.0.0"}}}},
+		{Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
+		{Package: osv.Package{Name: "example.com/c"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "3.0.0"}}}},
+	}
+	pkgs := map[string]string{
+		"example.com/a": "v0.0.1",
+		"example.com/b": "v1.0.0",
+		"example.com/c": "v9.0.0",
+	}
+
+	filtered := filterVulns(vulns, pkgs)
+
+	expected := []*osv.Entry{
+		{Package: osv.Package{Name: "example.com/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "1.0.0"}}}},
+		{Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
+	}
+	if !reflect.DeepEqual(filtered, expected) {
+		t.Errorf("filterVulns returned unexpected results: got\n%swant\n%s", vulnsToString(filtered), vulnsToString(expected))
+	}
+}
+
+func TestUnreachable(t *testing.T) {
+	r := &results{
+		Vulns: []*osv.Entry{
+			{ID: "0", Package: osv.Package{Name: "example.com/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "1.0.0"}}}},
+			{ID: "1", Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
+		},
+		Findings: []audit.Finding{
+			{Vulns: []osv.Entry{{ID: "0"}}},
+		},
+	}
+
+	expected := []*osv.Entry{
+		{ID: "1", Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
+	}
+	unreachable := r.unreachable()
+	if !reflect.DeepEqual(unreachable, expected) {
+		t.Errorf("unreachable returned unexpected results: got\n%swant\n%s", vulnsToString(unreachable), vulnsToString(expected))
+	}
+}