vulndb: move from package structured vulnerabilities to module
Adapts govulncheck to work with a database structured around per-module
vulnerabilities, rather than per-package vulnerabilities.
This requires a significant refactor of various aspects of the main
package and the internal/audit packages which, while large, I think
makes the overall program flow somewhat simpler to understand. Some
changes to tests are also required, although similarly I believe they
end up with easier to understand/modify tests.
This also paves the way for more comprehensive details around which
vulnerabilities are unreachable.
Change-Id: I3dd402db344849db6f1a118feee65734daf924cf
Reviewed-on: https://go-review.googlesource.com/c/exp/+/339191
Trust: Roland Shoemaker <roland@golang.org>
Run-TryBot: Roland Shoemaker <roland@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/vulndb/govulncheck/main.go b/vulndb/govulncheck/main.go
index 7c9d185..9248d6a 100644
--- a/vulndb/govulncheck/main.go
+++ b/vulndb/govulncheck/main.go
@@ -30,6 +30,7 @@
"golang.org/x/exp/vulndb/internal/binscan"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa/ssautil"
+ "golang.org/x/vulndb/client"
"golang.org/x/vulndb/osv"
)
@@ -68,9 +69,9 @@
`
type results struct {
- ImportedPackages []string
- Vulns []*osv.Entry
- Findings []audit.Finding
+ Modules []*packages.Module
+ Vulns []*osv.Entry
+ Findings []audit.Finding
}
func (r *results) unreachable() []*osv.Entry {
@@ -92,7 +93,6 @@
// presentTo pretty-prints results to out.
func (r *results) presentTo(out io.Writer) {
- sort.Strings(r.ImportedPackages)
sort.Slice(r.Vulns, func(i, j int) bool { return r.Vulns[i].ID < r.Vulns[j].ID })
sort.SliceStable(r.Findings, func(i int, j int) bool { return audit.FindingCompare(r.Findings[i], r.Findings[j]) })
if !*jsonFlag {
@@ -148,29 +148,30 @@
r.presentTo(os.Stdout)
}
-// allPkgPaths computes a list of all packages, in
-// the form of their paths, reachable from pkgs.
-func allPkgPaths(pkgs []*packages.Package) []string {
- paths := make(map[string]bool)
+func extractModules(pkgs []*packages.Package) []*packages.Module {
+ modMap := map[*packages.Module]bool{}
+ seen := map[*packages.Package]bool{}
+ var extract func(*packages.Package, map[*packages.Module]bool)
+ extract = func(pkg *packages.Package, modMap map[*packages.Module]bool) {
+ if pkg == nil || seen[pkg] {
+ return
+ }
+ if pkg.Module != nil {
+ modMap[pkg.Module] = true
+ }
+ seen[pkg] = true
+ for _, imp := range pkg.Imports {
+ extract(imp, modMap)
+ }
+ }
for _, pkg := range pkgs {
- pkgPaths(pkg, paths)
+ extract(pkg, modMap)
}
-
- var ps []string
- for p := range paths {
- ps = append(ps, p)
+ modules := []*packages.Module{}
+ for mod := range modMap {
+ modules = append(modules, mod)
}
- return ps
-}
-
-func pkgPaths(pkg *packages.Package, paths map[string]bool) {
- if _, ok := paths[pkg.PkgPath]; ok {
- return
- }
- paths[pkg.PkgPath] = true
- for _, imp := range pkg.Imports {
- pkgPaths(imp, paths)
- }
+ return modules
}
func isFile(path string) bool {
@@ -181,43 +182,25 @@
return !s.IsDir()
}
-func filterVulns(vulns []*osv.Entry, packageVersions map[string]string) []*osv.Entry {
- filtered := []*osv.Entry{}
- for _, v := range vulns {
- version, ok := packageVersions[v.Package.Name]
- if !ok || !v.Affects.AffectsSemver(version) {
- continue
- }
- filtered = append(filtered, v)
- }
- return filtered
-}
-
func run(cfg *packages.Config, patterns []string, importsOnly bool, dbs []string) (*results, error) {
r := &results{}
if len(patterns) == 1 && isFile(patterns[0]) {
- packages, symbols, err := binscan.ExtractPackagesAndSymbols(patterns[0])
+ modules, symbols, err := binscan.ExtractPackagesAndSymbols(patterns[0])
if err != nil {
return nil, err
}
- paths := make([]string, 0, len(packages))
- for pkg := range packages {
- paths = append(paths, pkg)
+ dbClient, err := client.NewClient(dbs, client.Options{})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create database client: %s", err)
}
- r.ImportedPackages = paths
-
- vulns, err := audit.LoadVulnerabilities(dbs, paths)
+ vulns, err := audit.FetchVulnerabilities(dbClient, modules)
if err != nil {
return nil, fmt.Errorf("failed to load vulnerability dbs: %v", err)
}
- vulns = filterVulns(vulns, packages)
- if len(vulns) == 0 {
- return r, nil
- }
- r.Vulns = vulns
+ vulns = vulns.Filter(runtime.GOOS, runtime.GOARCH)
- r.Findings = audit.VulnerablePackageSymbols(symbols, audit.Env{OS: runtime.GOOS, Arch: runtime.GOARCH, PkgVersions: packages, Vulns: vulns})
+ r.Findings = audit.VulnerablePackageSymbols(symbols, vulns)
return r, nil
}
@@ -236,27 +219,23 @@
log.Printf("\t%d loaded packages\n", len(pkgs))
}
- // Load package versions.
- pkgVersions := audit.PackageVersions(pkgs)
-
// Load database.
if *verboseFlag {
log.Println("loading database...")
}
- importedPackages := allPkgPaths(pkgs)
- r.ImportedPackages = importedPackages
- vulns, err := audit.LoadVulnerabilities(dbs, importedPackages)
+ r.Modules = extractModules(pkgs)
+ dbClient, err := client.NewClient(dbs, client.Options{})
if err != nil {
- return nil, fmt.Errorf("failed to load vulnerability dbs: %v", err)
+ return nil, fmt.Errorf("failed to create database client: %s", err)
}
- vulns = filterVulns(vulns, pkgVersions)
- if len(vulns) == 0 {
- return r, nil
+ modVulns, err := audit.FetchVulnerabilities(dbClient, r.Modules)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch vulnerabilities: %v", err)
}
- r.Vulns = vulns
+ modVulns = modVulns.Filter(runtime.GOOS, runtime.GOARCH)
if *verboseFlag {
- log.Printf("\t%d known vulnerabilities.\n", len(vulns))
+ log.Printf("\t%d known vulnerabilities.\n", modVulns.Num())
}
// Load SSA.
@@ -274,11 +253,10 @@
log.Println("detecting vulnerabilities...")
}
var findings []audit.Finding
- env := audit.Env{OS: runtime.GOOS, Arch: runtime.GOARCH, PkgVersions: pkgVersions, Vulns: vulns}
if importsOnly {
- r.Findings = audit.VulnerableImports(ssaPkgs, env)
+ r.Findings = audit.VulnerableImports(ssaPkgs, modVulns)
} else {
- r.Findings = audit.VulnerableSymbols(ssaPkgs, env)
+ r.Findings = audit.VulnerableSymbols(ssaPkgs, modVulns)
}
if *verboseFlag {
log.Printf("\t%d detected findings.\n", len(findings))
diff --git a/vulndb/govulncheck/main_test.go b/vulndb/govulncheck/main_test.go
index ad09378..e762cc6 100644
--- a/vulndb/govulncheck/main_test.go
+++ b/vulndb/govulncheck/main_test.go
@@ -47,9 +47,9 @@
}`
var vulns = map[string]string{
- "github.com/go-yaml/yaml.json": goYamlVuln,
- "golang.org/x/crypto/ssh.json": cryptoSSHVuln,
- "k8s.io/apiextensions-apiserver/pkg/apiserver.json": k8sAPIServerVuln,
+ "github.com/go-yaml/yaml.json": goYamlVuln,
+ "golang.org/x/crypto.json": cryptoSSHVuln,
+ "k8s.io/apiextensions-apiserver.json": k8sAPIServerVuln,
}
// addToLocalDb adds vuln for package p to local db at path db.
@@ -163,6 +163,10 @@
t.Logf("failed to get %s: %s", hashiVaultOkta+"@v1.6.3", out)
t.Fatal(err)
}
+ // if out, err := execCmd(e.Config.Dir, env, "go", "mod", "tidy"); err != nil {
+ // t.Logf("failed to mod tidy: %s", out)
+ // t.Fatal(err)
+ // }
// run goaudit.
cfg := &packages.Config{
@@ -189,7 +193,7 @@
}{
// test local db without yaml, which should result in no findings.
{source: "file://" + dbPath, want: nil,
- toAdd: []string{"golang.org/x/crypto/ssh.json", "k8s.io/apiextensions-apiserver/pkg/apiserver.json"}},
+ toAdd: []string{"golang.org/x/crypto.json", "k8s.io/apiextensions-apiserver.json"}},
// add yaml to the local db, which should produce 2 findings.
{source: "file://" + dbPath, toAdd: []string{"github.com/go-yaml/yaml.json"},
want: []finding{
@@ -197,8 +201,8 @@
{"github.com/go-yaml/yaml.yaml_parser_fetch_more_tokens", 12}},
},
// repeat the similar experiment with a server db.
- {source: "http://localhost:8080", toAdd: []string{"k8s.io/apiextensions-apiserver/pkg/apiserver.json"}, want: nil},
- {source: "http://localhost:8080", toAdd: []string{"golang.org/x/crypto/ssh.json", "github.com/go-yaml/yaml.json"},
+ {source: "http://localhost:8080", toAdd: []string{"k8s.io/apiextensions-apiserver.json"}, want: nil},
+ {source: "http://localhost:8080", toAdd: []string{"golang.org/x/crypto.json", "github.com/go-yaml/yaml.json"},
want: []finding{
{"github.com/go-yaml/yaml.decoder.unmarshal", 6},
{"github.com/go-yaml/yaml.yaml_parser_fetch_more_tokens", 12}},
@@ -383,29 +387,6 @@
return s
}
-func TestFilterVulsn(t *testing.T) {
- vulns := []*osv.Entry{
- {Package: osv.Package{Name: "example.com/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "1.0.0"}}}},
- {Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
- {Package: osv.Package{Name: "example.com/c"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "3.0.0"}}}},
- }
- pkgs := map[string]string{
- "example.com/a": "v0.0.1",
- "example.com/b": "v1.0.0",
- "example.com/c": "v9.0.0",
- }
-
- filtered := filterVulns(vulns, pkgs)
-
- expected := []*osv.Entry{
- {Package: osv.Package{Name: "example.com/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "1.0.0"}}}},
- {Package: osv.Package{Name: "example.com/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "2.0.0"}}}},
- }
- if !reflect.DeepEqual(filtered, expected) {
- t.Errorf("filterVulns returned unexpected results: got\n%swant\n%s", vulnsToString(filtered), vulnsToString(expected))
- }
-}
-
func TestUnreachable(t *testing.T) {
r := &results{
Vulns: []*osv.Entry{
diff --git a/vulndb/internal/audit/detect.go b/vulndb/internal/audit/detect.go
index 1e93412..3861792 100644
--- a/vulndb/internal/audit/detect.go
+++ b/vulndb/internal/audit/detect.go
@@ -9,7 +9,9 @@
"fmt"
"go/token"
"io"
+ "strings"
+ "golang.org/x/tools/go/packages"
"golang.org/x/vulndb/osv"
)
@@ -45,17 +47,6 @@
Position *token.Position `json:",omitempty"`
}
-// Env encapsulates information for querying if an imported symbol/package is vulnerable:
-// - platform info
-// - package versions
-// - vulnerability db
-type Env struct {
- OS string
- Arch string
- PkgVersions map[string]string
- Vulns []*osv.Entry
-}
-
// Write method for findings showing the trace and the associated vulnerabilities.
func (f Finding) Write(w io.Writer) {
var pos string
@@ -108,124 +99,104 @@
return []byte(name), nil
}
-func matchingVulns(os, arch, version string, vulns []*osv.Entry) []*osv.Entry {
- var matches []*osv.Entry
- for _, vuln := range vulns {
- if matchesPlatformAndVersion(os, arch, version, vuln) {
- matches = append(matches, vuln)
+type modVulns struct {
+ mod *packages.Module
+ vulns []*osv.Entry
+}
+
+type ModuleVulnerabilities []modVulns
+
+func matchesPlatform(os, arch string, e osv.GoSpecific) bool {
+ matchesOS := len(e.GOOS) == 0
+ matchesArch := len(e.GOARCH) == 0
+ for _, o := range e.GOOS {
+ if os == o {
+ matchesOS = true
+ break
}
}
- return matches
-}
-
-// matchesPlatformAndVersion checks if `os`, `arch`, and `version` match the vulnerability `vuln`.
-func matchesPlatformAndVersion(os, arch, version string, vuln *osv.Entry) bool {
- return matchesPlatform(os, vuln.EcosystemSpecific.GOOS) && matchesPlatform(arch, vuln.EcosystemSpecific.GOARCH) && vuln.Affects.AffectsSemver(version)
-}
-
-// matchesPlatform checks if `platform`, typically os or system architecture,
-// matches `platforms`. Empty `platforms` is also a match.
-func matchesPlatform(platform string, platforms []string) bool {
- if len(platforms) == 0 {
- return true
- }
-
- for _, p := range platforms {
- if platform == p {
- return true
+ for _, a := range e.GOARCH {
+ if arch == a {
+ matchesArch = true
+ break
}
}
- return false
+ return matchesOS && matchesArch
}
-// pkgVulnerabilities map for fast lookup on vulnerable packages.
-// Maps package paths to their vulnerabilities.
-type pkgVulnerabilities map[string][]*osv.Entry
-
-// createPkgVulns creates a fast package-vulnerability look-up map for `vulns`.
-func createPkgVulns(vulns []*osv.Entry) pkgVulnerabilities {
- pkgVulns := make(pkgVulnerabilities)
- for _, vuln := range vulns {
- pkgVulns[vuln.Package.Name] = append(pkgVulns[vuln.Package.Name], vuln)
+func (mv ModuleVulnerabilities) Filter(os, arch string) ModuleVulnerabilities {
+ var filteredMod ModuleVulnerabilities
+ for _, mod := range mv {
+ var filteredVulns []*osv.Entry
+ for _, v := range mod.vulns {
+ if matchesPlatform(os, arch, v.EcosystemSpecific) {
+ filteredVulns = append(filteredVulns, v)
+ }
+ }
+ filteredMod = append(filteredMod, modVulns{
+ mod: mod.mod,
+ vulns: filteredVulns,
+ })
}
- return pkgVulns
+ return filteredMod
}
-// vulnerabilities returns a list of vulnerabilities that deem `pkgPath` vulnerable at `version` as well
-// as `arch` architecture and `os` operating system. Assumes version strings in `pkgVulns` are well-formed;
-// otherwise, the correctness of the results is not guaranteed.
-func (pkgVulns pkgVulnerabilities) vulnerabilities(pkgPath, version, arch, os string) []*osv.Entry {
- vulns, ok := pkgVulns[pkgPath]
- if !ok {
- return nil
+func (mv ModuleVulnerabilities) Num() int {
+ var num int
+ for _, m := range mv {
+ num += len(m.vulns)
}
- return matchingVulns(os, arch, version, vulns)
+ return num
}
-func queryPkgVulns(pkgPath string, env Env, pkgVulns pkgVulnerabilities) []*osv.Entry {
- version, ok := env.PkgVersions[pkgPath]
- if !ok {
- return nil
- }
- return pkgVulns.vulnerabilities(pkgPath, version, env.Arch, env.OS)
-}
-
-// symVulnerabilities map for fast lookup on vulnerable symbols.
-// Maps package paths to symbols to their vulnerabilities.
-type symVulnerabilities map[string]map[string][]*osv.Entry
-
-// Represents any symbol. Used to model vulnerabilities in
-// symVulnerabilties that define every symbol as vulnerable.
-const symWildCard = "*"
-
-// createSymVulns creates a fast symbol-vulnerability look-up map for `vulns`.
-func createSymVulns(vulns []*osv.Entry) symVulnerabilities {
- symVulns := make(symVulnerabilities)
- for _, vuln := range vulns {
- if len(vuln.EcosystemSpecific.Symbols) == 0 {
- // If vuln.Symbols is empty, every symbol is vulnerable.
- symVulns.add(symWildCard, vuln)
- } else {
- for _, sym := range vuln.EcosystemSpecific.Symbols {
- symVulns.add(sym, vuln)
+// VulnsForPackage returns the vulnerabilities for the module which is the most
+// specific prefixof importPath, or nil if there is no matching module with
+// vulnerabilities.
+func (mv ModuleVulnerabilities) VulnsForPackage(importPath string) []*osv.Entry {
+ var mostSpecificMod *modVulns
+ for _, mod := range mv {
+ if strings.HasPrefix(importPath, mod.mod.Path) {
+ if mostSpecificMod == nil || len(mostSpecificMod.mod.Path) < len(mod.mod.Path) {
+ mostSpecificMod = &mod
}
}
}
- return symVulns
-}
-func (symVulns symVulnerabilities) add(symbol string, v *osv.Entry) {
- syms := symVulns[v.Package.Name]
- if syms == nil {
- syms = make(map[string][]*osv.Entry)
- symVulns[v.Package.Name] = syms
- }
- syms[symbol] = append(syms[symbol], v)
-}
-
-// vulnerabilities returns a list of vulnerabilities that deem `symbol` from package `pkgPath` vulnerable at
-// `version`, architecture `arch`, and operating system `os`. Assumes version strings in `symVulns` are well-formed;
-// otherwise, the correctness of the results is not guaranteed.
-func (symVulns symVulnerabilities) vulnerabilities(symbol, pkgPath, version, arch, os string) []*osv.Entry {
- pkgVulns, ok := symVulns[pkgPath]
- if !ok {
+ if mostSpecificMod == nil {
return nil
}
- var vulns []*osv.Entry
- vulns = append(vulns, pkgVulns[symbol]...)
- vulns = append(vulns, pkgVulns[symWildCard]...)
- if len(vulns) == 0 {
- return nil
+ if mostSpecificMod.mod.Replace != nil {
+ importPath = fmt.Sprintf("%s%s", mostSpecificMod.mod.Replace.Path, strings.TrimPrefix(importPath, mostSpecificMod.mod.Path))
}
-
- return matchingVulns(os, arch, version, vulns)
+ vulns := mostSpecificMod.vulns
+ packageVulns := []*osv.Entry{}
+ for _, v := range vulns {
+ if v.Package.Name == importPath {
+ packageVulns = append(packageVulns, v)
+ }
+ }
+ return packageVulns
}
-func querySymbolVulns(symbol, pkgPath string, symVulns symVulnerabilities, env Env) []*osv.Entry {
- version, ok := env.PkgVersions[pkgPath]
- if !ok {
+func (mv ModuleVulnerabilities) VulnsForSymbol(importPath, symbol string) []*osv.Entry {
+ vulns := mv.VulnsForPackage(importPath)
+ if vulns == nil {
return nil
}
- return symVulns.vulnerabilities(symbol, pkgPath, version, env.Arch, env.OS)
+
+ symbolVulns := []*osv.Entry{}
+ for _, v := range vulns {
+ if len(v.EcosystemSpecific.Symbols) == 0 {
+ symbolVulns = append(symbolVulns, v)
+ continue
+ }
+ for _, s := range v.EcosystemSpecific.Symbols {
+ if s == symbol {
+ symbolVulns = append(symbolVulns, v)
+ break
+ }
+ }
+ }
+ return symbolVulns
}
diff --git a/vulndb/internal/audit/detect_binary.go b/vulndb/internal/audit/detect_binary.go
index b5dd324..b9bf49d 100644
--- a/vulndb/internal/audit/detect_binary.go
+++ b/vulndb/internal/audit/detect_binary.go
@@ -12,13 +12,11 @@
// in packageSymbols, given the vulnerability and platform info captured in env.
//
// Returned Findings only have Symbol, Type, and Vulns fields set.
-func VulnerablePackageSymbols(packageSymbols map[string][]string, env Env) []Finding {
- symVulns := createSymVulns(env.Vulns)
-
+func VulnerablePackageSymbols(packageSymbols map[string][]string, modVulns ModuleVulnerabilities) []Finding {
var findings []Finding
for pkg, symbols := range packageSymbols {
for _, symbol := range symbols {
- if vulns := querySymbolVulns(symbol, pkg, symVulns, env); len(vulns) > 0 {
+ if vulns := modVulns.VulnsForSymbol(pkg, symbol); len(vulns) > 0 {
findings = append(findings,
Finding{
Symbol: fmt.Sprintf("%s.%s", pkg, symbol),
diff --git a/vulndb/internal/audit/detect_callgraph.go b/vulndb/internal/audit/detect_callgraph.go
index 5adb8f5..34e434d 100644
--- a/vulndb/internal/audit/detect_callgraph.go
+++ b/vulndb/internal/audit/detect_callgraph.go
@@ -34,7 +34,7 @@
// as traces of transitively using a vulnerable symbol V.
//
// Panics if packages in pkgs do not belong to the same program.
-func VulnerableSymbols(pkgs []*ssa.Package, env Env) []Finding {
+func VulnerableSymbols(pkgs []*ssa.Package, modVulns ModuleVulnerabilities) []Finding {
prog := pkgsProgram(pkgs)
if prog == nil {
panic("packages in pkgs must belong to a single common program")
@@ -47,7 +47,6 @@
queue.PushBack(&callChain{f: entry})
}
- symVulns := createSymVulns(env.Vulns)
var findings []Finding
seen := make(map[*ssa.Function]bool)
for queue.Len() > 0 {
@@ -60,7 +59,7 @@
}
seen[v.f] = true
- finds, calls := funcVulnsAndCalls(v, symVulns, env, callGraph)
+ finds, calls := funcVulnsAndCalls(v, modVulns, callGraph)
findings = append(findings, finds...)
for _, call := range calls {
queue.PushBack(call)
@@ -178,13 +177,13 @@
// funcVulnsAndCalls returns a list of symbol findings for function at the top
// of chain and next calls to analyze.
-func funcVulnsAndCalls(chain *callChain, symVulns symVulnerabilities, env Env, callGraph *callgraph.Graph) ([]Finding, []*callChain) {
+func funcVulnsAndCalls(chain *callChain, modVulns ModuleVulnerabilities, callGraph *callgraph.Graph) ([]Finding, []*callChain) {
var findings []Finding
var calls []*callChain
for _, b := range chain.f.Blocks {
for _, instr := range b.Instrs {
// First collect all findings for globals except callees in function call statements.
- findings = append(findings, globalFindings(globalUses(instr), chain, symVulns, env)...)
+ findings = append(findings, globalFindings(globalUses(instr), chain, modVulns)...)
// Callees are handled separately to produce call findings rather than global findings.
site, ok := instr.(ssa.CallInstruction)
@@ -197,7 +196,7 @@
c := &callChain{call: site, f: callee, parent: chain}
calls = append(calls, c)
- if f := callFinding(c, symVulns, env); f != nil {
+ if f := callFinding(c, modVulns); f != nil {
findings = append(findings, *f)
}
}
@@ -209,15 +208,15 @@
// globalFindings returns findings for vulnerable globals among globalUses.
// Assumes each use in globalUses is a use of a global variable. Can generate
// duplicates when globalUses contains duplicates.
-func globalFindings(globalUses []*ssa.Value, chain *callChain, symVulns symVulnerabilities, env Env) []Finding {
- if underRelatedVuln(chain, symVulns, env) {
+func globalFindings(globalUses []*ssa.Value, chain *callChain, modVulns ModuleVulnerabilities) []Finding {
+ if underRelatedVuln(chain, modVulns) {
return nil
}
var findings []Finding
for _, o := range globalUses {
g := (*o).(*ssa.Global)
- vulns := querySymbolVulns(g.Name(), g.Package().Pkg.Path(), symVulns, env)
+ vulns := modVulns.VulnsForSymbol(g.Package().Pkg.Path(), g.Name())
if len(vulns) > 0 {
findings = append(findings,
Finding{
@@ -235,8 +234,8 @@
// callFinding returns vulnerability finding for the call made at the top of the chain.
// If there is no vulnerability or no call information, then nil is returned.
// TODO(zpavlinovic): remove ssa info from higher-order calls.
-func callFinding(chain *callChain, symVulns symVulnerabilities, env Env) *Finding {
- if underRelatedVuln(chain, symVulns, env) {
+func callFinding(chain *callChain, modVulns ModuleVulnerabilities) *Finding {
+ if underRelatedVuln(chain, modVulns) {
return nil
}
@@ -246,7 +245,7 @@
return nil
}
- vulns := querySymbolVulns(dbFuncName(callee), callee.Package().Pkg.Path(), symVulns, env)
+ vulns := modVulns.VulnsForSymbol(callee.Package().Pkg.Path(), dbFuncName(callee))
if len(vulns) > 0 {
c := chain
if !unresolved(call) {
@@ -277,7 +276,7 @@
//
// Note that for P1:A -> P2:B -> P3:D -> P2:C the function returns false. This
// is because C is called from D that comes from a different package.
-func underRelatedVuln(chain *callChain, symVulns symVulnerabilities, env Env) bool {
+func underRelatedVuln(chain *callChain, modVulns ModuleVulnerabilities) bool {
pkg := pkgPath(chain.f)
c := chain
@@ -288,7 +287,7 @@
break
}
// TODO: can we optimize using the information on findings already reported?
- if len(querySymbolVulns(dbFuncName(c.f), c.f.Pkg.Pkg.Path(), symVulns, env)) > 0 {
+ if len(modVulns.VulnsForSymbol(c.f.Pkg.Pkg.Path(), dbFuncName(c.f))) > 0 {
return true
}
}
diff --git a/vulndb/internal/audit/detect_callgraph_test.go b/vulndb/internal/audit/detect_callgraph_test.go
index a2723f7..615e1fd 100644
--- a/vulndb/internal/audit/detect_callgraph_test.go
+++ b/vulndb/internal/audit/detect_callgraph_test.go
@@ -14,8 +14,8 @@
)
func TestSymbolVulnDetectionVTA(t *testing.T) {
- pkgs, env := testProgAndEnv(t)
- got := projectFindings(VulnerableSymbols(pkgs, env))
+ pkgs, modVulns := testContext(t)
+ got := projectFindings(VulnerableSymbols(pkgs, modVulns))
// There should be four call chains reported with VTA-VTA version, in the following order:
// T:T1() -> vuln.VG [use of global at line 4]
diff --git a/vulndb/internal/audit/detect_imports.go b/vulndb/internal/audit/detect_imports.go
index 40a1b1f..f10de30 100644
--- a/vulndb/internal/audit/detect_imports.go
+++ b/vulndb/internal/audit/detect_imports.go
@@ -24,9 +24,7 @@
// or
// D -> B -> V
// as traces of importing a vulnerable package V.
-func VulnerableImports(pkgs []*ssa.Package, env Env) []Finding {
- pkgVulns := createPkgVulns(env.Vulns)
-
+func VulnerableImports(pkgs []*ssa.Package, modVulns ModuleVulnerabilities) []Finding {
var findings []Finding
seen := make(map[string]bool)
queue := list.New()
@@ -50,7 +48,7 @@
seen[pkg.Path()] = true
for _, imp := range pkg.Imports() {
- vulns := queryPkgVulns(imp.Path(), env, pkgVulns)
+ vulns := modVulns.VulnsForPackage(imp.Path())
if len(vulns) > 0 {
findings = append(findings,
Finding{
diff --git a/vulndb/internal/audit/detect_imports_test.go b/vulndb/internal/audit/detect_imports_test.go
index 0e36097..e36f73a 100644
--- a/vulndb/internal/audit/detect_imports_test.go
+++ b/vulndb/internal/audit/detect_imports_test.go
@@ -13,8 +13,8 @@
)
func TestImportedPackageVulnDetection(t *testing.T) {
- pkgs, env := testProgAndEnv(t)
- got := projectFindings(VulnerableImports(pkgs, env))
+ pkgs, modVulns := testContext(t)
+ got := projectFindings(VulnerableImports(pkgs, modVulns))
// There should be two chains reported in the following order:
// T -> vuln
diff --git a/vulndb/internal/audit/detect_test.go b/vulndb/internal/audit/detect_test.go
index a90c7e3..b9a5ac6 100644
--- a/vulndb/internal/audit/detect_test.go
+++ b/vulndb/internal/audit/detect_test.go
@@ -5,118 +5,187 @@
package audit
import (
+ "fmt"
+ "reflect"
"testing"
+ "golang.org/x/tools/go/packages"
"golang.org/x/vulndb/osv"
)
-var testVulnerabilities = []*osv.Entry{
- {
- Package: osv.Package{
- Name: "xyz.org/vuln",
- },
- Affects: osv.Affects{
- Ranges: []osv.AffectsRange{
- {
- Type: osv.TypeSemver,
- Introduced: "v1.0.1",
- Fixed: "v3.2.6",
- },
- },
- },
- EcosystemSpecific: osv.GoSpecific{
- Symbols: []string{"foo", "bar"},
- GOOS: []string{"amd64"},
- GOARCH: []string{"linux"},
- },
- },
- {
- Package: osv.Package{
- Name: "xyz.org/vuln",
- },
- Affects: osv.Affects{
- Ranges: []osv.AffectsRange{
- {
- Type: osv.TypeSemver,
- Fixed: "v4.0.0",
- },
- },
- },
- EcosystemSpecific: osv.GoSpecific{
- Symbols: []string{"foo"},
- },
- },
- {
- Package: osv.Package{
- Name: "abc.org/morevuln",
- },
- },
+func moduleVulnerabilitiesToString(mv ModuleVulnerabilities) string {
+ var s string
+ for _, m := range mv {
+ s += fmt.Sprintf("mod: %v\n", m.mod)
+ for _, v := range m.vulns {
+ s += fmt.Sprintf("\t%v\n", v)
+ }
+ }
+ return s
}
-func TestPackageVulnCreationAndChecking(t *testing.T) {
- pkgVulns := createPkgVulns(testVulnerabilities)
- if len(pkgVulns) != 2 {
- t.Errorf("want 2 package paths; got %d", len(pkgVulns))
+func TestFilterVulns(t *testing.T) {
+ mv := ModuleVulnerabilities{
+ {
+ mod: &packages.Module{
+ Path: "example.mod/a",
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "a"},
+ {ID: "b", EcosystemSpecific: osv.GoSpecific{GOOS: []string{"windows", "linux"}}},
+ {ID: "c", EcosystemSpecific: osv.GoSpecific{GOARCH: []string{"arm64", "amd64"}}},
+ {ID: "d", EcosystemSpecific: osv.GoSpecific{GOOS: []string{"windows"}}},
+ },
+ },
+ {
+ mod: &packages.Module{
+ Path: "example.mod/b",
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "e", EcosystemSpecific: osv.GoSpecific{GOARCH: []string{"arm64"}}},
+ {ID: "f", EcosystemSpecific: osv.GoSpecific{GOOS: []string{"linux"}}},
+ {ID: "g", EcosystemSpecific: osv.GoSpecific{GOARCH: []string{"amd64"}}},
+ {ID: "h", EcosystemSpecific: osv.GoSpecific{GOOS: []string{"windows"}, GOARCH: []string{"amd64"}}},
+ },
+ },
}
- for _, test := range []struct {
- path string
- version string
- os string
- arch string
- noVulns int
- }{
- // xyz.org/vuln has foo and bar vulns for linux, and just foo for windows.
- {"xyz.org/vuln", "v1.0.1", "amd64", "linux", 2},
- {"xyz.org/vuln", "v1.0.1", "amd64", "windows", 1},
- {"xyz.org/vuln", "v2.4.5", "amd64", "linux", 2},
- {"xyz.org/vuln", "v3.2.7", "amd64", "linux", 1},
- // foo for linux must be at version before v4.0.0.
- {"xyz.org/vuln", "v5.4.5", "amd64", "linux", 0},
- // abc.org/morevuln has vulnerabilities for any symbol, platform, and version
- {"abc.org/morevuln", "v11.0.1", "amd64", "linux", 1},
- {"abc.org/morevuln", "v300.0.1", "i386", "windows", 1},
- } {
- if vulns := pkgVulns.vulnerabilities(test.path, test.version, test.arch, test.os); len(vulns) != test.noVulns {
- t.Errorf("want %d vulnerabilities for %s (v:%s, o:%s, a:%s); got %d",
- test.noVulns, test.path, test.version, test.os, test.path, len(vulns))
- }
+ filtered := mv.Filter("linux", "amd64")
+
+ expected := ModuleVulnerabilities{
+ {
+ mod: &packages.Module{
+ Path: "example.mod/a",
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "a"},
+ {ID: "b", EcosystemSpecific: osv.GoSpecific{GOOS: []string{"windows", "linux"}}},
+ {ID: "c", EcosystemSpecific: osv.GoSpecific{GOARCH: []string{"arm64", "amd64"}}},
+ },
+ },
+ {
+ mod: &packages.Module{
+ Path: "example.mod/b",
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "f", EcosystemSpecific: osv.GoSpecific{GOOS: []string{"linux"}}},
+ {ID: "g", EcosystemSpecific: osv.GoSpecific{GOARCH: []string{"amd64"}}},
+ },
+ },
+ }
+ if !reflect.DeepEqual(filtered, expected) {
+ t.Fatalf("Filter returned unexpected results, got:\n%s\nwant:\n%s", moduleVulnerabilitiesToString(filtered), moduleVulnerabilitiesToString(expected))
}
}
-func TestSymbolVulnCreationAndChecking(t *testing.T) {
- symVulns := createSymVulns(testVulnerabilities)
- if len(symVulns) != 2 {
- t.Errorf("want 2 package paths; got %d", len(symVulns))
+func vulnsToString(vulns []*osv.Entry) string {
+ var s string
+ for _, v := range vulns {
+ s += fmt.Sprintf("\t%v\n", v)
+ }
+ return s
+}
+
+func TestVulnsForPackage(t *testing.T) {
+ mv := ModuleVulnerabilities{
+ {
+ mod: &packages.Module{
+ Path: "example.mod/a",
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "a", Package: osv.Package{Name: "example.mod/a/b/c"}},
+ },
+ },
+ {
+ mod: &packages.Module{
+ Path: "example.mod/a/b",
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "b", Package: osv.Package{Name: "example.mod/a/b/c"}},
+ },
+ },
}
- for _, test := range []struct {
- symbol string
- path string
- version string
- os string
- arch string
- numVulns int
- }{
- // foo appears twice as a vulnerable symbol for "xyz.org/vuln" and bar once.
- {"foo", "xyz.org/vuln", "v1.0.1", "amd64", "linux", 2},
- {"bar", "xyz.org/vuln", "v1.0.1", "amd64", "linux", 1},
- // foo and bar detected vulns should go down by one for windows platform as well as i386 architecture.
- {"foo", "xyz.org/vuln", "v1.0.1", "amd64", "windows", 1},
- {"bar", "xyz.org/vuln", "v1.0.1", "i386", "linux", 0},
- // There should be no findings for foo and bar at module version v5.0.0.
- {"foo", "xyz.org/vuln", "v5.0.0", "amd64", "linux", 0},
- {"bar", "xyz.org/vuln", "v5.0.0", "amd64", "linux", 0},
- // symbol is not a vulnerable symbol for xyz.org/vuln and bogus package is not in the database.
- {"symbol", "xyz.org/vuln", "v1.0.1", "amd64", "linux", 0},
- {"foo", "bogus", "v1.0.1", "amd64", "linux", 0},
- // abc.org/morevuln has vulnerabilities for any symbol, platform, and version
- {"symbol", "abc.org/morevuln", "v2.0.1", "amd64", "linux", 1},
- {"lobmys", "abc.org/morevuln", "v300.0.1", "i386", "windows", 1},
- } {
- if vulns := symVulns.vulnerabilities(test.symbol, test.path, test.version, test.arch, test.os); len(vulns) != test.numVulns {
- t.Errorf("want %d vulnerabilities for %s (p:%s v:%s, o:%s, a:%s); got %d",
- test.numVulns, test.symbol, test.path, test.version, test.os, test.arch, len(vulns))
- }
+ filtered := mv.VulnsForPackage("example.mod/a/b/c")
+ expected := []*osv.Entry{
+ {ID: "b", Package: osv.Package{Name: "example.mod/a/b/c"}},
+ }
+
+ if !reflect.DeepEqual(filtered, expected) {
+ t.Fatalf("VulnsForPackage returned unexpected results, got:\n%s\nwant:\n%s", vulnsToString(filtered), vulnsToString(expected))
+ }
+}
+
+func TestVulnsForPackageReplaced(t *testing.T) {
+ mv := ModuleVulnerabilities{
+ {
+ mod: &packages.Module{
+ Path: "example.mod/a",
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "a", Package: osv.Package{Name: "example.mod/a/b/c"}},
+ },
+ },
+ {
+ mod: &packages.Module{
+ Path: "example.mod/a/b",
+ Replace: &packages.Module{
+ Path: "example.mod/b",
+ },
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "c", Package: osv.Package{Name: "example.mod/b/c"}},
+ },
+ },
+ }
+
+ filtered := mv.VulnsForPackage("example.mod/a/b/c")
+ expected := []*osv.Entry{
+ {ID: "c", Package: osv.Package{Name: "example.mod/b/c"}},
+ }
+
+ if !reflect.DeepEqual(filtered, expected) {
+ t.Fatalf("VulnsForPackage returned unexpected results, got:\n%s\nwant:\n%s", vulnsToString(filtered), vulnsToString(expected))
+ }
+}
+
+func TestVulnsForSymbol(t *testing.T) {
+ mv := ModuleVulnerabilities{
+ {
+ mod: &packages.Module{
+ Path: "example.mod/a",
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "a", Package: osv.Package{Name: "example.mod/a/b/c"}},
+ },
+ },
+ {
+ mod: &packages.Module{
+ Path: "example.mod/a/b",
+ Version: "v1.0.0",
+ },
+ vulns: []*osv.Entry{
+ {ID: "b", Package: osv.Package{Name: "example.mod/a/b/c"}, EcosystemSpecific: osv.GoSpecific{Symbols: []string{"a"}}},
+ {ID: "c", Package: osv.Package{Name: "example.mod/a/b/c"}, EcosystemSpecific: osv.GoSpecific{Symbols: []string{"b"}}},
+ },
+ },
+ }
+
+ filtered := mv.VulnsForSymbol("example.mod/a/b/c", "a")
+ expected := []*osv.Entry{
+ {ID: "b", Package: osv.Package{Name: "example.mod/a/b/c"}, EcosystemSpecific: osv.GoSpecific{Symbols: []string{"a"}}},
+ }
+
+ if !reflect.DeepEqual(filtered, expected) {
+ t.Fatalf("VulnsForPackage returned unexpected results, got:\n%s\nwant:\n%s", vulnsToString(filtered), vulnsToString(expected))
}
}
diff --git a/vulndb/internal/audit/helpers_test.go b/vulndb/internal/audit/helpers_test.go
index f00cd7e..a93086d 100644
--- a/vulndb/internal/audit/helpers_test.go
+++ b/vulndb/internal/audit/helpers_test.go
@@ -7,7 +7,6 @@
import (
"go/token"
"io/ioutil"
- "os"
"path"
"path/filepath"
"strings"
@@ -40,9 +39,7 @@
//
// The following vulnerability should not be reported as it is redundant:
// T:T1() -> A:A1() -> B:B1() -> vuln.VulnData.Vuln()
-//
-// The produced environment is based on testdata/dbs vulnerability databases.
-func testProgAndEnv(t *testing.T) ([]*ssa.Package, Env) {
+func testContext(t *testing.T) ([]*ssa.Package, ModuleVulnerabilities) {
e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
{
Name: "golang.org/vulntest",
@@ -63,7 +60,7 @@
})
defer e.Cleanup()
- _, ssaPkgs, pkgs, err := loadAndBuildPackages(e, "/vulntest/T/T.go")
+ _, ssaPkgs, _, err := loadAndBuildPackages(e, "/vulntest/T/T.go")
if err != nil {
t.Fatal(err)
}
@@ -71,14 +68,25 @@
t.Errorf("want 1 top level SSA package; got %d", len(ssaPkgs))
}
- vulnsToLoad := []string{"thirdparty.org/vulnerabilities", "bogus.org/module"}
- dbSources := []string{fileSource(t, "testdata/dbs/bogus.db.org"), fileSource(t, "testdata/dbs/golang.deepgo.org")}
- vulns, err := LoadVulnerabilities(dbSources, vulnsToLoad)
- if err != nil {
- t.Fatal(err)
+ modVulns := ModuleVulnerabilities{
+ {
+ mod: &packages.Module{Path: "thirdparty.org/vulnerabilities", Version: "v1.0.1"},
+ vulns: []*osv.Entry{
+ {
+ Package: osv.Package{Name: "thirdparty.org/vulnerabilities/vuln"},
+ Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Introduced: "1.0.0", Fixed: "1.0.4"}, {Type: osv.TypeSemver, Introduced: "1.1.2"}}},
+ EcosystemSpecific: osv.GoSpecific{Symbols: []string{"VulnData.Vuln", "VulnData.VulnOnPtr"}},
+ },
+ {
+ Package: osv.Package{Name: "thirdparty.org/vulnerabilities/vuln"},
+ Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Introduced: "1.0.1", Fixed: "1.0.2"}}},
+ EcosystemSpecific: osv.GoSpecific{Symbols: []string{"VG"}},
+ },
+ },
+ },
}
- return ssaPkgs, Env{OS: "linux", Arch: "amd64", Vulns: vulns, PkgVersions: PackageVersions(pkgs)}
+ return ssaPkgs, modVulns
}
func loadAndBuildPackages(e *packagestest.Exported, file string) (*ssa.Program, []*ssa.Package, []*packages.Package, error) {
@@ -146,16 +154,6 @@
return nfs
}
-// fileSource creates a file URI for a database path `db`. If `db` is
-// relative, the source is made absolute w.r.t. the current directory.
-func fileSource(t *testing.T, db string) string {
- cd, err := os.Getwd()
- if err != nil {
- t.Fatal(err)
- }
- return "file://" + path.Join(cd, db)
-}
-
func readFile(t *testing.T, path string) string {
content, err := ioutil.ReadFile(path)
if err != nil {
diff --git a/vulndb/internal/audit/testdata/dbs/bogus.db.org/bogus.org/module.json b/vulndb/internal/audit/testdata/dbs/bogus.db.org/bogus.org/module.json
deleted file mode 100644
index 222d03e..0000000
--- a/vulndb/internal/audit/testdata/dbs/bogus.db.org/bogus.org/module.json
+++ /dev/null
@@ -1,21 +0,0 @@
-[
- {
- "package": {
- "name": "bogus.org/module/vuln"
- },
- "affects": {
- "ranges": [
- {
- "type": "SEMVER",
- "fixed": "v2.0.0"
- }
- ]
- },
- "ecosystem_specific": {
- "symbols": [
- "Bogus"
- ],
- "url": "bogus.org/bogus/README.doc"
- }
- }
-]
diff --git a/vulndb/internal/audit/testdata/dbs/bogus.db.org/index.json b/vulndb/internal/audit/testdata/dbs/bogus.db.org/index.json
deleted file mode 100644
index b7fa1b0..0000000
--- a/vulndb/internal/audit/testdata/dbs/bogus.db.org/index.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "bogus.org/module/vuln": "2020-03-06T09:21:06.31369157-07:00"
-}
diff --git a/vulndb/internal/audit/testdata/dbs/golang.deepgo.org/index.json b/vulndb/internal/audit/testdata/dbs/golang.deepgo.org/index.json
deleted file mode 100644
index 9129812..0000000
--- a/vulndb/internal/audit/testdata/dbs/golang.deepgo.org/index.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "thirdparty.org/vulnerabilities/vuln": "2020-04-05T10:21:50.21362171-07:00"
-}
diff --git a/vulndb/internal/audit/testdata/dbs/golang.deepgo.org/thirdparty.org/README b/vulndb/internal/audit/testdata/dbs/golang.deepgo.org/thirdparty.org/README
deleted file mode 100644
index 37db9a9..0000000
--- a/vulndb/internal/audit/testdata/dbs/golang.deepgo.org/thirdparty.org/README
+++ /dev/null
@@ -1,4 +0,0 @@
-Contains json files modeling vulnerability info for the module
-`thirdparty.org/vulnerabilities`.
-
-Also used for testing the robustness of loading a local vulnerability database.
diff --git a/vulndb/internal/audit/testdata/dbs/golang.deepgo.org/thirdparty.org/vulnerabilities.json b/vulndb/internal/audit/testdata/dbs/golang.deepgo.org/thirdparty.org/vulnerabilities.json
deleted file mode 100644
index cfc4d70..0000000
--- a/vulndb/internal/audit/testdata/dbs/golang.deepgo.org/thirdparty.org/vulnerabilities.json
+++ /dev/null
@@ -1,70 +0,0 @@
-[
- {
- "package": {
- "name": "thirdparty.org/vulnerabilities/vuln"
- },
- "affects": {
- "ranges": [
- {
- "type": "SEMVER",
- "introduced": "v1.0.0",
- "fixed": "v1.0.4"
- },
- {
- "type": "SEMVER",
- "introduced": "v1.1.2"
- }
- ]
- },
- "ecosystem_specific": {
- "symbols": [
- "VulnData.Vuln",
- "VulnData.VulnOnPtr"
- ],
- "url": "thirdparty.org/vulnerabilities/README.doc"
- }
- },
- {
- "package": {
- "name": "thirdparty.org/vulnerabilities/vuln"
- },
- "affects": {
- "ranges": [
- {
- "type": "SEMVER",
- "introduced": "v1.2.0",
- "fixed": "v1.3.2"
- }
- ]
- },
- "ecosystem_specific": {
- "goarch": [
- "amd64"
- ],
- "goos": [
- "linux"
- ],
- "url": "thirdparty.org/vulnerabilities/README_amd64.doc"
- }
- },
- {
- "package": {
- "name": "thirdparty.org/vulnerabilities/vuln"
- },
- "affects": {
- "ranges": [
- {
- "type": "SEMVER",
- "introduced": "v1.0.1",
- "fixed": "v1.0.2"
- }
- ]
- },
- "ecosystem_specific": {
- "symbols": [
- "VG"
- ],
- "url": "thirdparty.org/vulnerabilities/README_global.doc"
- }
- }
-]
diff --git a/vulndb/internal/audit/version.go b/vulndb/internal/audit/version.go
deleted file mode 100644
index 47a38ab..0000000
--- a/vulndb/internal/audit/version.go
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package audit
-
-import (
- "golang.org/x/tools/go/packages"
-)
-
-// Returns module version of a package pkg. If the version is "" and the module is
-// replaced by another module with the same path, replaced module version is returned.
-// TODO(zpavlinovic): check if this is complete/correct.
-func version(pkg *packages.Package) string {
- module := pkg.Module
- if module == nil {
- return ""
- }
-
- if module.Version != "" {
- return module.Version
- }
-
- if module.Replace == nil || module.Replace.Path != module.Path {
- return ""
- }
- return module.Replace.Version
-}
-
-// populateVersionInfo recursively populates pkgVersions for the input package pkg and its transitive dependencies.
-func populatePkgVersions(pkg *packages.Package, pkgVersions map[string]string, seen map[string]bool) {
- if _, ok := seen[pkg.PkgPath]; ok {
- return
- }
- seen[pkg.PkgPath] = true
-
- version := version(pkg)
- if version != "" {
- pkgVersions[pkg.PkgPath] = version
- }
-
- for _, imp := range pkg.Imports {
- populatePkgVersions(imp, pkgVersions, seen)
- }
-}
-
-// PackageVersions computes a map from a path of every package in pkgs and
-// its transitive dependencies to their module version. If module or its
-// version are not present, the corresponding package is not in the map.
-//
-// Does not check for well-formedness of version strings. If such strings
-// exist, the produced map can lead to confusing results down the line.
-// (Well-formedness of version strings should be checked by external tools,
-// such as using golang.org/x/tools/go/packages.Load to construct pkgs.)
-func PackageVersions(pkgs []*packages.Package) map[string]string {
- pkgVersions := make(map[string]string)
- seen := make(map[string]bool)
- for _, pkg := range pkgs {
- populatePkgVersions(pkg, pkgVersions, seen)
- }
- return pkgVersions
-}
diff --git a/vulndb/internal/audit/version_test.go b/vulndb/internal/audit/version_test.go
deleted file mode 100644
index fb55cc5..0000000
--- a/vulndb/internal/audit/version_test.go
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package audit
-
-import (
- "testing"
-
- "golang.org/x/tools/go/packages/packagestest"
-)
-
-func TestPackageVersionInfo(t *testing.T) {
- // Export package testdata with a program depending on a vulnerability package
- // vuln with version "v1.0.1".
- e := packagestest.Export(t, packagestest.Modules, []packagestest.Module{
- {
- Name: "golang.org/vulntest",
- Files: map[string]interface{}{
- "testdata/testdata.go": `
- package testdata
-
- import (
- "thirdparty.org/vulnerabilities/vuln"
- )
-
- func Lib1() {
- vuln.Vuln()
- }
- `,
- },
- },
- {
- Name: "thirdparty.org/vulnerabilities@v1.0.1",
- Files: map[string]interface{}{
- "vuln/vuln.go": `
- package vuln
-
- import (
- "abc.org/xyz/foo"
- )
-
- func Vuln() { foo.Foo() }
- `,
- },
- },
- {
- Name: "abc.org/xyz@v0.0.0-20201002170205-7f63de1d35b0",
- Files: map[string]interface{}{
- "foo/foo.go": `
- package foo
-
- func Foo() { }
- `,
- },
- },
- })
- defer e.Cleanup()
-
- _, _, pkgs, err := loadAndBuildPackages(e, "/vulntest/testdata/testdata.go")
- if err != nil {
- t.Fatal(err)
- }
-
- v := PackageVersions(pkgs)
- for _, test := range []struct {
- path string
- version string
- in bool
- }{
- {"command-line-arguments", "", false},
- {"thirdparty.org/vulnerabilities/vuln", "v1.0.1", true},
- {"abc.org/xyz/foo", "v0.0.0-20201002170205-7f63de1d35b0", true},
- } {
- if version, ok := v[test.path]; ok != test.in || version != test.version {
- t.Errorf("want package %s at version %s in=%t package-version map; got %s and %t", test.path, test.version, test.in, version, ok)
- }
- }
-}
diff --git a/vulndb/internal/audit/vulnerability.go b/vulndb/internal/audit/vulnerability.go
index 4ba5937..142bb57 100644
--- a/vulndb/internal/audit/vulnerability.go
+++ b/vulndb/internal/audit/vulnerability.go
@@ -5,18 +5,43 @@
package audit
import (
+ "golang.org/x/tools/go/packages"
"golang.org/x/vulndb/osv"
-
- "golang.org/x/vulndb/client"
)
-// LoadVulnerabilities fetches vulnerabilities for pkgs in dbs. Currently,
-// no caching is enabled.
-// TODO: add cache support once it is amenable to side-effect free testing.
-func LoadVulnerabilities(dbs []string, pkgs []string) ([]*osv.Entry, error) {
- dbClient, err := client.NewClient(dbs, client.Options{})
- if err != nil {
- return nil, err
+type dbClient interface {
+ Get([]string) ([]*osv.Entry, error)
+}
+
+// FetchVulnerabilities fetches vulnerabilities that affect the supplied modules.
+func FetchVulnerabilities(client dbClient, modules []*packages.Module) (ModuleVulnerabilities, error) {
+ mv := ModuleVulnerabilities{}
+ for _, mod := range modules {
+ modPath := mod.Path
+ modVersion := mod.Version
+ if mod.Replace != nil {
+ modPath = mod.Replace.Path
+ modVersion = mod.Replace.Version
+ }
+ vulns, err := client.Get([]string{modPath})
+ if err != nil {
+ return nil, err
+ }
+ // TODO(rolandshoemaker): we may want to consider moving this functionality into
+ // ModuleVulnerabilities.Filter, consolidating the filtering logic in one place.
+ var filteredVulns []*osv.Entry
+ for _, v := range vulns {
+ if v.Affects.AffectsSemver(modVersion) {
+ filteredVulns = append(filteredVulns, v)
+ }
+ }
+ if len(filteredVulns) == 0 {
+ continue
+ }
+ mv = append(mv, modVulns{
+ mod: mod,
+ vulns: filteredVulns,
+ })
}
- return dbClient.Get(pkgs)
+ return mv, nil
}
diff --git a/vulndb/internal/audit/vulnerability_test.go b/vulndb/internal/audit/vulnerability_test.go
index 9b7f2de..79a1189 100644
--- a/vulndb/internal/audit/vulnerability_test.go
+++ b/vulndb/internal/audit/vulnerability_test.go
@@ -5,55 +5,57 @@
package audit
import (
- "os"
- "path"
"reflect"
"testing"
+ "golang.org/x/tools/go/packages"
"golang.org/x/vulndb/osv"
)
-// Testing utility function that simplifies vulns by projecting each vulnerability
-// to Path, and Symbol fields only.
-func vulnProject(vulns []*osv.Entry) map[string][]osv.Entry {
- projVulns := make(map[string][]osv.Entry)
- for _, vuln := range vulns {
- projVulns[vuln.Package.Name] = append(projVulns[vuln.Package.Name],
- osv.Entry{Package: osv.Package{Name: vuln.Package.Name}, EcosystemSpecific: osv.GoSpecific{Symbols: vuln.EcosystemSpecific.Symbols}})
- }
- return projVulns
+type mockClient struct {
+ ret map[string][]*osv.Entry
}
-func TestLoadVulnerabilities(t *testing.T) {
- cd, err := os.Getwd()
+func (mc *mockClient) Get(a []string) ([]*osv.Entry, error) {
+ return mc.ret[a[0]], nil
+}
+
+func TestFetchVulnerabilities(t *testing.T) {
+ mc := &mockClient{
+ ret: map[string][]*osv.Entry{
+ "example.mod/a": {
+ {ID: "a", Package: osv.Package{Name: "example.mod/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "v2.0.0"}}}},
+ {ID: "b", Package: osv.Package{Name: "example.mod/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "v1.0.0"}}}},
+ },
+ "example.mod/b": {{ID: "c", Package: osv.Package{Name: "example.mod/b"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "v1.0.0"}}}}},
+ "example.mod/d": {{ID: "c", Package: osv.Package{Name: "example.mod/d"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "v2.0.0"}}}}},
+ },
+ }
+
+ mv, err := FetchVulnerabilities(mc, []*packages.Module{
+ {Path: "example.mod/a", Version: "v1.0.0"},
+ {Path: "example.mod/b", Version: "v1.0.0"},
+ {Path: "example.mod/c", Replace: &packages.Module{Path: "example.mod/d", Version: "v1.0.0"}, Version: "v2.0.0"},
+ })
if err != nil {
- t.Fatal(err)
+ t.Fatalf("FetchVulnerabilities failed: %s", err)
}
- vulns, err := LoadVulnerabilities([]string{"file://" + path.Join(cd, "testdata/dbs/bogus.db.org"), "file://" + path.Join(cd, "testdata/dbs/golang.deepgo.org")},
- []string{"thirdparty.org/vulnerabilities", "bogus.org/module"})
- if err != nil {
- t.Fatal(err)
- }
-
- testVulnDb := make(map[string][]osv.Entry)
- testVulnDb["thirdparty.org/vulnerabilities/vuln"] = []osv.Entry{
- {Package: osv.Package{Name: "thirdparty.org/vulnerabilities/vuln"},
- EcosystemSpecific: osv.GoSpecific{Symbols: []string{"VulnData.Vuln", "VulnData.VulnOnPtr"}},
+ expected := ModuleVulnerabilities{
+ {
+ mod: &packages.Module{Path: "example.mod/a", Version: "v1.0.0"},
+ vulns: []*osv.Entry{
+ {ID: "a", Package: osv.Package{Name: "example.mod/a"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "v2.0.0"}}}},
+ },
},
- {Package: osv.Package{Name: "thirdparty.org/vulnerabilities/vuln"}},
- {Package: osv.Package{Name: "thirdparty.org/vulnerabilities/vuln"},
- EcosystemSpecific: osv.GoSpecific{Symbols: []string{"VG"}},
+ {
+ mod: &packages.Module{Path: "example.mod/c", Replace: &packages.Module{Path: "example.mod/d", Version: "v1.0.0"}, Version: "v2.0.0"},
+ vulns: []*osv.Entry{
+ {ID: "c", Package: osv.Package{Name: "example.mod/d"}, Affects: osv.Affects{Ranges: []osv.AffectsRange{{Type: osv.TypeSemver, Fixed: "v2.0.0"}}}},
+ },
},
}
- testVulnDb["bogus.org/module/vuln"] = []osv.Entry{
- {Package: osv.Package{Name: "bogus.org/module/vuln"},
- EcosystemSpecific: osv.GoSpecific{Symbols: []string{"Bogus"}},
- },
- }
-
- projVulnDb := vulnProject(vulns)
- if !reflect.DeepEqual(testVulnDb, projVulnDb) {
- t.Errorf("want %v vulnerability database; got (simplified) %v", testVulnDb, projVulnDb)
+ if !reflect.DeepEqual(mv, expected) {
+ t.Fatalf("FetchVulnerabilities returned unexpected results, got:\n%s\nwant:\n%s", moduleVulnerabilitiesToString(mv), moduleVulnerabilitiesToString(expected))
}
}
diff --git a/vulndb/internal/binscan/scan.go b/vulndb/internal/binscan/scan.go
index eb57c23..da4f11e 100644
--- a/vulndb/internal/binscan/scan.go
+++ b/vulndb/internal/binscan/scan.go
@@ -18,6 +18,8 @@
"net/url"
"runtime/debug"
"strings"
+
+ "golang.org/x/tools/go/packages"
)
// buildInfoMagic, findVers, and readString are copied from
@@ -167,33 +169,31 @@
return info, true
}
+func debugModulesToPackagesModules(debugModules []*debug.Module) []*packages.Module {
+ packagesModules := make([]*packages.Module, len(debugModules))
+ for i, mod := range debugModules {
+ packagesModules[i] = &packages.Module{
+ Path: mod.Path,
+ Version: mod.Version,
+ }
+ if mod.Replace != nil {
+ packagesModules[i].Replace = &packages.Module{
+ Path: mod.Replace.Path,
+ Version: mod.Replace.Version,
+ }
+ }
+ }
+ return packagesModules
+}
+
// ExtractPackagesAndSymbols extracts the symbols, packages, and their associated module versions
// from a Go binary. Stripped binaries are not supported.
-func ExtractPackagesAndSymbols(binPath string) (map[string]string, map[string][]string, error) {
+func ExtractPackagesAndSymbols(binPath string) ([]*packages.Module, map[string][]string, error) {
x, err := openExe(binPath)
if err != nil {
return nil, nil, err
}
- mod := findVers(x)
-
- bi, ok := readBuildInfo(mod)
- if !ok {
- return nil, nil, err
- }
-
- deps := map[string]string{}
- for _, dep := range bi.Deps {
- if dep == nil {
- continue
- }
- if dep.Replace != nil {
- deps[dep.Replace.Path] = dep.Replace.Version
- continue
- }
- deps[dep.Path] = dep.Version
- }
-
pclntab, textOffset := x.PCLNTab()
lineTab := gosym.NewLineTable(pclntab, textOffset)
if lineTab == nil {
@@ -229,14 +229,10 @@
packageSymbols[pkgName] = append(packageSymbols[pkgName], symName)
}
- versionedPackages := map[string]string{}
- // TODO: this is rather inefficient, but probably fine for most programs
- for pkg := range packageSymbols {
- for mod, version := range deps {
- if strings.HasPrefix(pkg, mod) {
- versionedPackages[pkg] = version
- }
- }
+ bi, ok := readBuildInfo(findVers(x))
+ if !ok {
+ return nil, nil, err
}
- return versionedPackages, packageSymbols, nil
+
+ return debugModulesToPackagesModules(bi.Deps), packageSymbols, nil
}