cmd/govulncheck: add GetCallInfo to internal package
Move more functionality to the internal/govulncheck package so
gopls can share it.
Change-Id: Iabb88b6e5af71cf22e54f1264d8b307cc719b9e1
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/406936
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
diff --git a/cmd/govulncheck/html.go b/cmd/govulncheck/html.go
index 3f29f43..090a464 100644
--- a/cmd/govulncheck/html.go
+++ b/cmd/govulncheck/html.go
@@ -13,13 +13,14 @@
"html/template"
"io"
+ "golang.org/x/vuln/cmd/govulncheck/internal/govulncheck"
"golang.org/x/vuln/vulncheck"
)
//go:embed static/*
var staticContent embed.FS
-func html(w io.Writer, r *vulncheck.Result, callStacks map[*vulncheck.Vuln][]vulncheck.CallStack, moduleVersions map[string]string, topPackages map[string]bool, vulnGroups [][]*vulncheck.Vuln) error {
+func html(w io.Writer, r *vulncheck.Result, ci *govulncheck.CallInfo) error {
tmpl, err := template.New("govulncheck.tmpl").Funcs(template.FuncMap{
"funcName": funcName,
}).ParseFS(staticContent, "static/govulncheck.tmpl")
@@ -43,21 +44,21 @@
}
var vulns []*vuln
- for _, vg := range vulnGroups {
+ for _, vg := range ci.VulnGroups {
v0 := vg[0]
vn := &vuln{
ID: v0.OSV.ID,
PkgPath: v0.PkgPath,
- CurrentVersion: moduleVersions[v0.ModPath],
+ CurrentVersion: ci.ModuleVersions[v0.ModPath],
FixedVersion: "v" + latestFixed(v0.OSV.Affected),
Reference: fmt.Sprintf("https://pkg.go.dev/vuln/%s", v0.OSV.ID),
Details: v0.OSV.Details,
}
// Keep first call stack for each vuln.
for _, v := range vg {
- if css := callStacks[v]; len(css) > 0 {
+ if css := ci.CallStacks[v]; len(css) > 0 {
vn.Stacks = append(vn.Stacks, callstack{
- Summary: summarizeCallStack(css[0], topPackages, v.PkgPath),
+ Summary: summarizeCallStack(css[0], ci.TopPackages, v.PkgPath),
Stack: css[0],
})
}
diff --git a/cmd/govulncheck/internal/govulncheck/source.go b/cmd/govulncheck/internal/govulncheck/source.go
index b02fc47..9f91665 100644
--- a/cmd/govulncheck/internal/govulncheck/source.go
+++ b/cmd/govulncheck/internal/govulncheck/source.go
@@ -7,6 +7,7 @@
import (
"context"
"fmt"
+ "sort"
"strings"
"golang.org/x/tools/go/packages"
@@ -70,3 +71,56 @@
r.Vulns = vulns
return r, nil
}
+
+// CallInfo is information about calls to vulnerable functions.
+type CallInfo struct {
+ CallStacks map[*vulncheck.Vuln][]vulncheck.CallStack // all call stacks
+ VulnGroups [][]*vulncheck.Vuln // vulns grouped by ID and package
+ ModuleVersions map[string]string // map from module paths to versions
+ TopPackages map[string]bool // top-level packages
+}
+
+// GetCallInfo computes call stacks and related information from a vulncheck.Result.
+// I also makes a set of top-level packages from pkgs.
+func GetCallInfo(r *vulncheck.Result, pkgs []*vulncheck.Package) *CallInfo {
+ pset := map[string]bool{}
+ for _, p := range pkgs {
+ pset[p.PkgPath] = true
+ }
+ return &CallInfo{
+ CallStacks: vulncheck.CallStacks(r),
+ VulnGroups: groupByIDAndPackage(r.Vulns),
+ ModuleVersions: moduleVersionMap(r.Modules),
+ TopPackages: pset,
+ }
+}
+
+func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
+ groups := map[[2]string][]*vulncheck.Vuln{}
+ for _, v := range vs {
+ key := [2]string{v.OSV.ID, v.PkgPath}
+ groups[key] = append(groups[key], v)
+ }
+
+ var res [][]*vulncheck.Vuln
+ for _, g := range groups {
+ res = append(res, g)
+ }
+ sort.Slice(res, func(i, j int) bool {
+ return res[i][0].PkgPath < res[j][0].PkgPath
+ })
+ return res
+}
+
+// moduleVersionMap builds a map from module paths to versions.
+func moduleVersionMap(mods []*vulncheck.Module) map[string]string {
+ moduleVersions := map[string]string{}
+ for _, m := range mods {
+ v := m.Version
+ if m.Replace != nil {
+ v = m.Replace.Version
+ }
+ moduleVersions[m.Path] = v
+ }
+ return moduleVersions
+}
diff --git a/cmd/govulncheck/main.go b/cmd/govulncheck/main.go
index a7f8a82..050a142 100644
--- a/cmd/govulncheck/main.go
+++ b/cmd/govulncheck/main.go
@@ -128,20 +128,14 @@
if *jsonFlag {
writeJSON(r)
} else {
- callStacks := vulncheck.CallStacks(r)
- // Create set of top-level packages, used to find representative symbols
- topPackages := map[string]bool{}
- for _, p := range pkgs {
- topPackages[p.PkgPath] = true
- }
- vulnGroups := groupByIDAndPackage(r.Vulns)
- moduleVersions := moduleVersionMap(r.Modules)
+ // set of top-level packages, used to find representative symbols
+ ci := govulncheck.GetCallInfo(r, pkgs)
if *htmlFlag {
- if err := html(os.Stdout, r, callStacks, moduleVersions, topPackages, vulnGroups); err != nil {
+ if err := html(os.Stdout, r, ci); err != nil {
die("writing HTML: %v", err)
}
} else {
- writeText(r, callStacks, moduleVersions, topPackages, vulnGroups)
+ writeText(r, ci)
}
}
exitCode := 0
@@ -152,19 +146,6 @@
os.Exit(exitCode)
}
-// moduleVersionMap builds a map from module paths to versions.
-func moduleVersionMap(mods []*vulncheck.Module) map[string]string {
- moduleVersions := map[string]string{}
- for _, m := range mods {
- v := m.Version
- if m.Replace != nil {
- v = m.Replace.Version
- }
- moduleVersions[m.Path] = v
- }
- return moduleVersions
-}
-
func writeJSON(r *vulncheck.Result) {
b, err := json.MarshalIndent(r, "", "\t")
if err != nil {
@@ -174,23 +155,23 @@
fmt.Println()
}
-func writeText(r *vulncheck.Result, callStacks map[*vulncheck.Vuln][]vulncheck.CallStack, moduleVersions map[string]string, topPackages map[string]bool, vulnGroups [][]*vulncheck.Vuln) {
+func writeText(r *vulncheck.Result, ci *govulncheck.CallInfo) {
const labelWidth = 16
line := func(label, text string) {
fmt.Printf("%-*s%s\n", labelWidth, label, text)
}
- for _, vg := range vulnGroups {
+ for _, vg := range ci.VulnGroups {
// All the vulns in vg have the same PkgPath, ModPath and OSV.
// All have a non-zero CallSink.
v0 := vg[0]
line("package:", v0.PkgPath)
- line("your version:", moduleVersions[v0.ModPath])
+ line("your version:", ci.ModuleVersions[v0.ModPath])
line("fixed version:", "v"+latestFixed(v0.OSV.Affected))
var summaries []string
for _, v := range vg {
- if css := callStacks[v]; len(css) > 0 {
- if sum := summarizeCallStack(css[0], topPackages, v.PkgPath); sum != "" {
+ if css := ci.CallStacks[v]; len(css) > 0 {
+ if sum := summarizeCallStack(css[0], ci.TopPackages, v.PkgPath); sum != "" {
summaries = append(summaries, sum)
}
}
@@ -216,34 +197,6 @@
}
}
-func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
- groups := map[[2]string][]*vulncheck.Vuln{}
- for _, v := range vs {
- key := [2]string{v.OSV.ID, v.PkgPath}
- groups[key] = append(groups[key], v)
- }
-
- var res [][]*vulncheck.Vuln
- for _, g := range groups {
- res = append(res, g)
- }
- sort.Slice(res, func(i, j int) bool {
- return res[i][0].PkgPath < res[j][0].PkgPath
- })
- return res
-}
-
-func packageModule(p *packages.Package) *packages.Module {
- m := p.Module
- if m == nil {
- return nil
- }
- if r := m.Replace; r != nil {
- return r
- }
- return m
-}
-
func isFile(path string) bool {
s, err := os.Stat(path)
if err != nil {