vulncheck: replace packages.Package with vulncheck.Package
packages.Package can consume a lot of memory for some clients,
so we use its trimmed version vulncheck.Packages instead.
Cherry-picked: https://go-review.googlesource.com/c/exp/+/369135
Change-Id: I530174219824879ccc39ca0f3abb2c9d6c71a3d0
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/395045
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/vulncheck/binary.go b/vulncheck/binary.go
index 65ea431..66721c5 100644
--- a/vulncheck/binary.go
+++ b/vulncheck/binary.go
@@ -7,17 +7,18 @@
import (
"io"
+ "golang.org/x/tools/go/packages"
"golang.org/x/vuln/vulncheck/internal/binscan"
)
// Binary detects presence of vulnerable symbols in exe. The
// imports, require, and call graph are all unavailable (nil).
func Binary(exe io.ReaderAt, cfg *Config) (*Result, error) {
- modules, packageSymbols, err := binscan.ExtractPackagesAndSymbols(exe)
+ mods, packageSymbols, err := binscan.ExtractPackagesAndSymbols(exe)
if err != nil {
return nil, err
}
- modVulns, err := fetchVulnerabilities(cfg.Client, modules)
+ modVulns, err := fetchVulnerabilities(cfg.Client, convertModules(mods))
if err != nil {
return nil, err
}
@@ -45,3 +46,22 @@
}
return result, nil
}
+
+func convertModules(mods []*packages.Module) []*Module {
+ vmods := make([]*Module, len(mods))
+ // TODO(github.com/golang/go/issues/50030): should we share unique
+ // modules? Not needed nowas module info is not returned by Binary.
+ for i, mod := range mods {
+ vmods[i] = &Module{
+ Path: mod.Path,
+ Version: mod.Version,
+ }
+ if mod.Replace != nil {
+ vmods[i].Replace = &Module{
+ Path: mod.Replace.Path,
+ Version: mod.Replace.Version,
+ }
+ }
+ }
+ return vmods
+}
diff --git a/vulncheck/fetch.go b/vulncheck/fetch.go
index 5cef708..9cbb214 100644
--- a/vulncheck/fetch.go
+++ b/vulncheck/fetch.go
@@ -11,12 +11,11 @@
"path/filepath"
"strings"
- "golang.org/x/tools/go/packages"
"golang.org/x/vulndb/client"
)
// modKey creates a unique string identifier for mod.
-func modKey(mod *packages.Module) string {
+func modKey(mod *Module) string {
if mod == nil {
return ""
}
@@ -25,12 +24,12 @@
// extractModules collects modules in `pkgs` up to uniqueness of
// module path and version.
-func extractModules(pkgs []*packages.Package) []*packages.Module {
- modMap := map[string]*packages.Module{}
+func extractModules(pkgs []*Package) []*Module {
+ modMap := map[string]*Module{}
- seen := map[*packages.Package]bool{}
- var extract func(*packages.Package, map[string]*packages.Module)
- extract = func(pkg *packages.Package, modMap map[string]*packages.Module) {
+ seen := map[*Package]bool{}
+ var extract func(*Package, map[string]*Module)
+ extract = func(pkg *Package, modMap map[string]*Module) {
if pkg == nil || seen[pkg] {
return
}
@@ -50,7 +49,7 @@
extract(pkg, modMap)
}
- modules := []*packages.Module{}
+ modules := []*Module{}
for _, mod := range modMap {
modules = append(modules, mod)
}
@@ -58,7 +57,7 @@
}
// fetchVulnerabilities fetches vulnerabilities that affect the supplied modules.
-func fetchVulnerabilities(client client.Client, modules []*packages.Module) (moduleVulnerabilities, error) {
+func fetchVulnerabilities(client client.Client, modules []*Module) (moduleVulnerabilities, error) {
mv := moduleVulnerabilities{}
for _, mod := range modules {
modPath := mod.Path
@@ -91,7 +90,7 @@
// loading local vulnerabilities in testing.
var fetchingInTesting bool = false
-func isLocal(mod *packages.Module) bool {
+func isLocal(mod *Module) bool {
if fetchingInTesting {
return false
}
@@ -101,7 +100,6 @@
}
return !strings.HasPrefix(modDir, modCacheDirectory())
}
-
func modCacheDirectory() string {
var modCacheDir string
// TODO: define modCacheDir using something similar to cmd/go/internal/cfg.GOMODCACHE?
diff --git a/vulncheck/fetch_test.go b/vulncheck/fetch_test.go
index a60edee..252282f 100644
--- a/vulncheck/fetch_test.go
+++ b/vulncheck/fetch_test.go
@@ -8,7 +8,6 @@
"reflect"
"testing"
- "golang.org/x/tools/go/packages"
"golang.org/x/vulndb/osv"
)
@@ -22,11 +21,11 @@
},
}
- mv, err := fetchVulnerabilities(mc, []*packages.Module{
+ mv, err := fetchVulnerabilities(mc, []*Module{
{Path: "example.mod/a", Dir: modCacheDirectory(), Version: "v1.0.0"},
{Path: "example.mod/b", Dir: modCacheDirectory(), Version: "v1.0.4"},
- {Path: "example.mod/c", Replace: &packages.Module{Path: "example.mod/d", Dir: modCacheDirectory(), Version: "v1.0.0"}, Version: "v2.0.0"},
- {Path: "example.mod/e", Replace: &packages.Module{Path: "../local/example.mod/d", Dir: modCacheDirectory(), Version: "v1.0.1"}, Version: "v2.1.0"},
+ {Path: "example.mod/c", Replace: &Module{Path: "example.mod/d", Dir: modCacheDirectory(), Version: "v1.0.0"}, Version: "v2.0.0"},
+ {Path: "example.mod/e", Replace: &Module{Path: "../local/example.mod/d", Dir: modCacheDirectory(), Version: "v1.0.1"}, Version: "v2.1.0"},
})
if err != nil {
t.Fatalf("FetchVulnerabilities failed: %s", err)
@@ -34,19 +33,19 @@
expected := moduleVulnerabilities{
{
- mod: &packages.Module{Path: "example.mod/a", Dir: modCacheDirectory(), Version: "v1.0.0"},
+ mod: &Module{Path: "example.mod/a", Dir: modCacheDirectory(), Version: "v1.0.0"},
vulns: []*osv.Entry{
{ID: "a", Affected: []osv.Affected{{Package: osv.Package{Name: "example.mod/a"}, Ranges: osv.Affects{{Type: osv.TypeSemver, Events: []osv.RangeEvent{{Fixed: "2.0.0"}}}}}}},
},
},
{
- mod: &packages.Module{Path: "example.mod/b", Dir: modCacheDirectory(), Version: "v1.0.4"},
+ mod: &Module{Path: "example.mod/b", Dir: modCacheDirectory(), Version: "v1.0.4"},
vulns: []*osv.Entry{
{ID: "b", Affected: []osv.Affected{{Package: osv.Package{Name: "example.mod/b"}, Ranges: osv.Affects{{Type: osv.TypeSemver, Events: []osv.RangeEvent{{Fixed: "1.1.1"}}}}}}},
},
},
{
- mod: &packages.Module{Path: "example.mod/c", Replace: &packages.Module{Path: "example.mod/d", Dir: modCacheDirectory(), Version: "v1.0.0"}, Version: "v2.0.0"},
+ mod: &Module{Path: "example.mod/c", Replace: &Module{Path: "example.mod/d", Dir: modCacheDirectory(), Version: "v1.0.0"}, Version: "v2.0.0"},
vulns: []*osv.Entry{
{ID: "c", Affected: []osv.Affected{{Package: osv.Package{Name: "example.mod/d"}, Ranges: osv.Affects{{Type: osv.TypeSemver, Events: []osv.RangeEvent{{Fixed: "2.0.0"}}}}}}},
},
diff --git a/vulncheck/source.go b/vulncheck/source.go
index 4ec495c..cac1ac3 100644
--- a/vulncheck/source.go
+++ b/vulncheck/source.go
@@ -6,9 +6,7 @@
import (
"golang.org/x/tools/go/callgraph"
- "golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa"
- "golang.org/x/tools/go/ssa/ssautil"
"golang.org/x/vulndb/osv"
)
@@ -19,7 +17,7 @@
// package that has some known vulnerabilities
// - call graph leading to the use of a known vulnerable function
// or method
-func Source(pkgs []*packages.Package, cfg *Config) (*Result, error) {
+func Source(pkgs []*Package, cfg *Config) (*Result, error) {
modVulns, err := fetchVulnerabilities(cfg.Client, extractModules(pkgs))
if err != nil {
return nil, err
@@ -37,8 +35,7 @@
return result, nil
}
- prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
- prog.Build()
+ prog, ssaPkgs := buildSSA(pkgs)
entries := entryPoints(ssaPkgs)
cg := callGraph(prog, entries)
vulnCallGraphSlice(entries, modVulns, cg, result)
@@ -57,12 +54,12 @@
// vulnPkgModSlice computes the slice of pkgs imports and requires graph
// leading to imports/requires of vulnerable packages/modules in modVulns
// and stores the computed slices to result.
-func vulnPkgModSlice(pkgs []*packages.Package, modVulns moduleVulnerabilities, result *Result) {
+func vulnPkgModSlice(pkgs []*Package, modVulns moduleVulnerabilities, result *Result) {
// analyzedPkgs contains information on packages analyzed thus far.
// If a package is mapped to nil, this means it has been visited
// but it does not lead to a vulnerable imports. Otherwise, a
// visited package is mapped to Imports package node.
- analyzedPkgs := make(map[*packages.Package]*PkgNode)
+ analyzedPkgs := make(map[*Package]*PkgNode)
for _, pkg := range pkgs {
// Top level packages that lead to vulnerable imports are
// stored as result.Imports graph entry points.
@@ -80,7 +77,7 @@
// a package with known vulnerabilities. If that is the case, populates result.Imports
// graph with this reachability information and returns the result.Imports package
// node for pkg. Otherwise, returns nil.
-func vulnImportSlice(pkg *packages.Package, modVulns moduleVulnerabilities, result *Result, analyzed map[*packages.Package]*PkgNode) *PkgNode {
+func vulnImportSlice(pkg *Package, modVulns moduleVulnerabilities, result *Result, analyzed map[*Package]*PkgNode) *PkgNode {
if pn, ok := analyzed[pkg]; ok {
return pn
}
diff --git a/vulncheck/source_test.go b/vulncheck/source_test.go
index df0e272..2cf7cb1 100644
--- a/vulncheck/source_test.go
+++ b/vulncheck/source_test.go
@@ -109,7 +109,7 @@
Client: testClient,
ImportsOnly: true,
}
- result, err := Source(pkgs, cfg)
+ result, err := Source(Convert(pkgs), cfg)
if err != nil {
t.Fatal(err)
}
@@ -345,7 +345,7 @@
cfg := &Config{
Client: testClient,
}
- result, err := Source(pkgs, cfg)
+ result, err := Source(Convert(pkgs), cfg)
if err != nil {
t.Fatal(err)
}
diff --git a/vulncheck/utils.go b/vulncheck/utils.go
index eea2078..d120f56 100644
--- a/vulncheck/utils.go
+++ b/vulncheck/utils.go
@@ -19,6 +19,41 @@
"golang.org/x/tools/go/ssa"
)
+// buildSSA creates an ssa representation for pkgs. Returns
+// the ssa program encapsulating the packages and top level
+// ssa packages corresponding to pkgs.
+func buildSSA(pkgs []*Package) (*ssa.Program, []*ssa.Package) {
+ prog := ssa.NewProgram(token.NewFileSet(), ssa.BuilderMode(0))
+
+ imports := make(map[*Package]*ssa.Package)
+ var createImports func([]*Package)
+ createImports = func(pkgs []*Package) {
+ for _, p := range pkgs {
+ if _, ok := imports[p]; !ok {
+ i := prog.CreatePackage(p.Pkg, p.Syntax, p.TypesInfo, true)
+ imports[p] = i
+ createImports(p.Imports)
+ }
+ }
+ }
+
+ for _, tp := range pkgs {
+ createImports(tp.Imports)
+ }
+
+ var ssaPkgs []*ssa.Package
+ for _, tp := range pkgs {
+ if sp, ok := imports[tp]; ok {
+ ssaPkgs = append(ssaPkgs, sp)
+ } else {
+ sp := prog.CreatePackage(tp.Pkg, tp.Syntax, tp.TypesInfo, false)
+ ssaPkgs = append(ssaPkgs, sp)
+ }
+ }
+ prog.Build()
+ return prog, ssaPkgs
+}
+
// callGraph builds a call graph of prog based on VTA analysis.
func callGraph(prog *ssa.Program, entries []*ssa.Function) *callgraph.Graph {
entrySlice := make(map[*ssa.Function]bool)
diff --git a/vulncheck/vulncheck.go b/vulncheck/vulncheck.go
index 38b6c2c..86bd89b 100644
--- a/vulncheck/vulncheck.go
+++ b/vulncheck/vulncheck.go
@@ -236,7 +236,7 @@
ImportedBy []int
// pkg is used for connecting package node to module and call graph nodes.
- pkg *packages.Package
+ pkg *Package
}
// moduleVulnerabilities is an internal structure for
@@ -246,7 +246,7 @@
// modVulns groups vulnerabilities per module.
type modVulns struct {
- mod *packages.Module
+ mod *Module
vulns []*osv.Entry
}
diff --git a/vulncheck/vulncheck_test.go b/vulncheck/vulncheck_test.go
index d68a1ee..076c458 100644
--- a/vulncheck/vulncheck_test.go
+++ b/vulncheck/vulncheck_test.go
@@ -9,7 +9,6 @@
"reflect"
"testing"
- "golang.org/x/tools/go/packages"
"golang.org/x/tools/go/packages/packagestest"
"golang.org/x/vulndb/osv"
)
@@ -17,7 +16,7 @@
func TestFilterVulns(t *testing.T) {
mv := moduleVulnerabilities{
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
@@ -32,7 +31,7 @@
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/b",
Version: "v1.0.0",
},
@@ -44,7 +43,7 @@
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/c",
},
vulns: []*osv.Entry{
@@ -54,7 +53,7 @@
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/d",
Version: "v1.2.0",
},
@@ -69,7 +68,7 @@
expected := moduleVulnerabilities{
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
@@ -79,7 +78,7 @@
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/b",
Version: "v1.0.0",
},
@@ -89,12 +88,12 @@
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/c",
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/d",
Version: "v1.2.0",
},
@@ -113,7 +112,7 @@
func TestVulnsForPackage(t *testing.T) {
mv := moduleVulnerabilities{
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
@@ -122,7 +121,7 @@
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/a/b",
Version: "v1.0.0",
},
@@ -131,7 +130,7 @@
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/d",
Version: "v0.0.1",
},
@@ -154,7 +153,7 @@
func TestVulnsForPackageReplaced(t *testing.T) {
mv := moduleVulnerabilities{
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
@@ -163,9 +162,9 @@
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/a/b",
- Replace: &packages.Module{
+ Replace: &Module{
Path: "example.mod/b",
},
Version: "v1.0.0",
@@ -189,7 +188,7 @@
func TestVulnsForSymbol(t *testing.T) {
mv := moduleVulnerabilities{
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/a",
Version: "v1.0.0",
},
@@ -198,7 +197,7 @@
},
},
{
- mod: &packages.Module{
+ mod: &Module{
Path: "example.mod/a/b",
Version: "v1.0.0",
},