vulndb/govulncheck: surface unreachable vulns
Surface vulnerabilities in imported packages which are not reachable in
the call graph. Additionally we pre-filter vulnerabilities which do not
apply to the versions used, skipping unnecessary analysis.
Change-Id: If845a376406cd079a5f96935f419e6af5eabd76c
Reviewed-on: https://go-review.googlesource.com/c/exp/+/335171
Trust: Roland Shoemaker <roland@golang.org>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
Run-TryBot: Roland Shoemaker <roland@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/vulndb/govulncheck/main.go b/vulndb/govulncheck/main.go
index 7232c28..7c9d185 100644
--- a/vulndb/govulncheck/main.go
+++ b/vulndb/govulncheck/main.go
@@ -30,6 +30,7 @@
"golang.org/x/exp/vulndb/internal/binscan"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa/ssautil"
+ "golang.org/x/vulndb/osv"
)
var (
@@ -66,6 +67,60 @@
databases are merged.
`
+type results struct {
+ ImportedPackages []string
+ Vulns []*osv.Entry
+ Findings []audit.Finding
+}
+
+func (r *results) unreachable() []*osv.Entry {
+ seen := map[string]bool{}
+ for _, f := range r.Findings {
+ for _, v := range f.Vulns {
+ seen[v.ID] = true
+ }
+ }
+ unseen := []*osv.Entry{}
+ for _, v := range r.Vulns {
+ if seen[v.ID] {
+ continue
+ }
+ unseen = append(unseen, v)
+ }
+ return unseen
+}
+
+// presentTo pretty-prints results to out.
+func (r *results) presentTo(out io.Writer) {
+ sort.Strings(r.ImportedPackages)
+ sort.Slice(r.Vulns, func(i, j int) bool { return r.Vulns[i].ID < r.Vulns[j].ID })
+ sort.SliceStable(r.Findings, func(i int, j int) bool { return audit.FindingCompare(r.Findings[i], r.Findings[j]) })
+ if !*jsonFlag {
+ for _, finding := range r.Findings {
+ finding.Write(out)
+ out.Write([]byte{'\n'})
+ }
+ if unreachable := r.unreachable(); len(unreachable) > 0 {
+ fmt.Fprintf(out, "The following %d vulnerabilities don't affect this project:\n", len(unreachable))
+ for _, u := range unreachable {
+ var aliases string
+ if len(u.Aliases) > 0 {
+ aliases = fmt.Sprintf(" (%s)", strings.Join(u.Aliases, ", "))
+ }
+ fmt.Fprintf(out, "- %s%s (package imported, but vulnerable symbol is not reachable)\n", u.ID, aliases)
+ }
+ }
+ return
+ }
+ b, err := json.MarshalIndent(r, "", "\t")
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "govulncheck: %s\n", err)
+ os.Exit(1)
+ }
+ out.Write(b)
+ out.Write([]byte{'\n'})
+}
+
func main() {
flag.Usage = func() { fmt.Fprintln(os.Stderr, usage) }
flag.Parse()
@@ -84,32 +139,13 @@
Mode: packages.LoadAllSyntax | packages.NeedModule,
}
- findings, err := run(cfg, flag.Args(), *importsFlag, dbs)
+ r, err := run(cfg, flag.Args(), *importsFlag, dbs)
if err != nil {
fmt.Fprintf(os.Stderr, "govulncheck: %s\n", err)
os.Exit(1)
}
- sort.SliceStable(findings, func(i int, j int) bool { return audit.FindingCompare(findings[i], findings[j]) })
- presentTo(os.Stdout, findings)
-}
-
-// presentTo pretty-prints findings to out.
-func presentTo(out io.Writer, findings []audit.Finding) {
- if !*jsonFlag {
- for _, finding := range findings {
- finding.Write(out)
- out.Write([]byte{'\n'})
- }
- return
- }
- b, err := json.MarshalIndent(findings, "", "\t")
- if err != nil {
- fmt.Fprintf(os.Stderr, "govulncheck: %s\n", err)
- os.Exit(1)
- }
- out.Write(b)
- out.Write([]byte{'\n'})
+ r.presentTo(os.Stdout)
}
// allPkgPaths computes a list of all packages, in
@@ -145,7 +181,20 @@
return !s.IsDir()
}
-func run(cfg *packages.Config, patterns []string, importsOnly bool, dbs []string) ([]audit.Finding, error) {
+func filterVulns(vulns []*osv.Entry, packageVersions map[string]string) []*osv.Entry {
+ filtered := []*osv.Entry{}
+ for _, v := range vulns {
+ version, ok := packageVersions[v.Package.Name]
+ if !ok || !v.Affects.AffectsSemver(version) {
+ continue
+ }
+ filtered = append(filtered, v)
+ }
+ return filtered
+}
+
+func run(cfg *packages.Config, patterns []string, importsOnly bool, dbs []string) (*results, error) {
+ r := &results{}
if len(patterns) == 1 && isFile(patterns[0]) {
packages, symbols, err := binscan.ExtractPackagesAndSymbols(patterns[0])
if err != nil {
@@ -156,14 +205,20 @@
for pkg := range packages {
paths = append(paths, pkg)
}
+ r.ImportedPackages = paths
vulns, err := audit.LoadVulnerabilities(dbs, paths)
if err != nil {
return nil, fmt.Errorf("failed to load vulnerability dbs: %v", err)
}
- env := audit.Env{OS: runtime.GOOS, Arch: runtime.GOARCH, PkgVersions: packages, Vulns: vulns}
+ vulns = filterVulns(vulns, packages)
+ if len(vulns) == 0 {
+ return r, nil
+ }
+ r.Vulns = vulns
- return audit.VulnerablePackageSymbols(symbols, env), nil
+ r.Findings = audit.VulnerablePackageSymbols(symbols, audit.Env{OS: runtime.GOOS, Arch: runtime.GOARCH, PkgVersions: packages, Vulns: vulns})
+ return r, nil
}
// Load packages.
@@ -181,22 +236,29 @@
log.Printf("\t%d loaded packages\n", len(pkgs))
}
+ // Load package versions.
+ pkgVersions := audit.PackageVersions(pkgs)
+
// Load database.
if *verboseFlag {
log.Println("loading database...")
}
- vulns, err := audit.LoadVulnerabilities(dbs, allPkgPaths(pkgs))
+ importedPackages := allPkgPaths(pkgs)
+ r.ImportedPackages = importedPackages
+ vulns, err := audit.LoadVulnerabilities(dbs, importedPackages)
if err != nil {
return nil, fmt.Errorf("failed to load vulnerability dbs: %v", err)
}
+ vulns = filterVulns(vulns, pkgVersions)
+ if len(vulns) == 0 {
+ return r, nil
+ }
+ r.Vulns = vulns
if *verboseFlag {
log.Printf("\t%d known vulnerabilities.\n", len(vulns))
}
- // Load package versions.
- pkgVersions := audit.PackageVersions(pkgs)
-
// Load SSA.
if *verboseFlag {
log.Println("building ssa...")
@@ -214,12 +276,13 @@
var findings []audit.Finding
env := audit.Env{OS: runtime.GOOS, Arch: runtime.GOARCH, PkgVersions: pkgVersions, Vulns: vulns}
if importsOnly {
- findings = audit.VulnerableImports(ssaPkgs, env)
+ r.Findings = audit.VulnerableImports(ssaPkgs, env)
} else {
- findings = audit.VulnerableSymbols(ssaPkgs, env)
+ r.Findings = audit.VulnerableSymbols(ssaPkgs, env)
}
if *verboseFlag {
log.Printf("\t%d detected findings.\n", len(findings))
}
- return findings, nil
+
+ return r, nil
}
diff --git a/vulndb/govulncheck/main_test.go b/vulndb/govulncheck/main_test.go
index 6042b28..ad09378 100644
--- a/vulndb/govulncheck/main_test.go
+++ b/vulndb/govulncheck/main_test.go
@@ -15,6 +15,7 @@
"os/exec"
"path"
"path/filepath"
+ "reflect"
"runtime"
"sort"
"strings"
@@ -23,6 +24,7 @@
"golang.org/x/exp/vulndb/internal/audit"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/packages/packagestest"
+ "golang.org/x/vulndb/osv"
)
// TODO(zpavlinovic): improve integration tests.
@@ -210,12 +212,12 @@
}
}
- finds, err := run(cfg, []string{hashiVaultOkta}, false, []string{test.source})
+ r, err := run(cfg, []string{hashiVaultOkta}, false, []string{test.source})
if err != nil {
t.Fatal(err)
}
- sort.SliceStable(finds, func(i int, j int) bool { return audit.FindingCompare(finds[i], finds[j]) })
- if fs := testFindings(finds); !subset(test.want, fs) {
+ sort.SliceStable(r.Findings, func(i int, j int) bool { return audit.FindingCompare(r.Findings[i], r.Findings[j]) })
+ if fs := testFindings(r.Findings); !subset(test.want, fs) {
t.Errorf("want %v subset of findings; got %v", test.want, fs)
}
}
@@ -362,13 +364,64 @@
}
}
- finds, err := run(cfg, []string{"./..."}, false, []string{test.source})
+ r, err := run(cfg, []string{"./..."}, false, []string{test.source})
if err != nil {
t.Fatal(err)
}
- sort.SliceStable(finds, func(i int, j int) bool { return audit.FindingCompare(finds[i], finds[j]) })
- if fs := testFindings(finds); !subset(test.want, fs) {
+ sort.SliceStable(r.Findings, func(i int, j int) bool { return audit.FindingCompare(r.Findings[i], r.Findings[j]) })
+ if fs := testFindings(r.Findings); !subset(test.want, fs) {
t.Errorf("want %v subset of findings; got %v", test.want, fs)
}
}
}
+
+func vulnsToString(vulns []*osv.Entry) string {
+ var s string
+ for _, v := range vulns {
+ s += fmt.Sprintf("\t%v\n", v)
+ }
+ return s
+}
+
+func TestFilterVulsn(t *testing.T) {
+ vulns := []*osv.Entry{
+ {Package: osv.Package{Name: "example.com/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "1.0.0"}}}},
+ {Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
+ {Package: osv.Package{Name: "example.com/c"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "3.0.0"}}}},
+ }
+ pkgs := map[string]string{
+ "example.com/a": "v0.0.1",
+ "example.com/b": "v1.0.0",
+ "example.com/c": "v9.0.0",
+ }
+
+ filtered := filterVulns(vulns, pkgs)
+
+ expected := []*osv.Entry{
+ {Package: osv.Package{Name: "example.com/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "1.0.0"}}}},
+ {Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
+ }
+ if !reflect.DeepEqual(filtered, expected) {
+ t.Errorf("filterVulns returned unexpected results: got\n%swant\n%s", vulnsToString(filtered), vulnsToString(expected))
+ }
+}
+
+func TestUnreachable(t *testing.T) {
+ r := &results{
+ Vulns: []*osv.Entry{
+ {ID: "0", Package: osv.Package{Name: "example.com/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "1.0.0"}}}},
+ {ID: "1", Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
+ },
+ Findings: []audit.Finding{
+ {Vulns: []osv.Entry{{ID: "0"}}},
+ },
+ }
+
+ expected := []*osv.Entry{
+ {ID: "1", Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
+ }
+ unreachable := r.unreachable()
+ if !reflect.DeepEqual(unreachable, expected) {
+ t.Errorf("unreachable returned unexpected results: got\n%swant\n%s", vulnsToString(unreachable), vulnsToString(expected))
+ }
+}