gopls/internal/vulncheck: use internal/govulncheck
Copy the x/vuln/cmd/govulncheck/internal/govulncheck package
and use it in internal/vulncheck.
Fixes golang/go#52985.
Change-Id: I3fb16b3d486ac462fca36aa53fd46e576041102d
Reviewed-on: https://go-review.googlesource.com/c/tools/+/407114
Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/gopls/internal/govulncheck/source.go b/gopls/internal/govulncheck/source.go
new file mode 100644
index 0000000..752a831
--- /dev/null
+++ b/gopls/internal/govulncheck/source.go
@@ -0,0 +1,129 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.18
+// +build go1.18
+
+package govulncheck
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "strings"
+
+ "golang.org/x/tools/go/packages"
+ "golang.org/x/vuln/client"
+ "golang.org/x/vuln/vulncheck"
+)
+
+// A PackageError contains errors from loading a set of packages.
+type PackageError struct {
+ Errors []packages.Error
+}
+
+func (e *PackageError) Error() string {
+ var b strings.Builder
+ fmt.Fprintln(&b, "Packages contain errors:")
+ for _, e := range e.Errors {
+ fmt.Println(&b, e)
+ }
+ return b.String()
+}
+
+// LoadPackages loads the packages matching patterns using cfg, after setting
+// the cfg mode flags that vulncheck needs for analysis.
+// If the packages contain errors, a PackageError is returned containing a list of the errors,
+// along with the packages themselves.
+func LoadPackages(cfg *packages.Config, patterns ...string) ([]*vulncheck.Package, error) {
+ cfg.Mode |= packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles |
+ packages.NeedImports | packages.NeedTypes | packages.NeedTypesSizes |
+ packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedDeps |
+ packages.NeedModule
+
+ pkgs, err := packages.Load(cfg, patterns...)
+ vpkgs := vulncheck.Convert(pkgs)
+ if err != nil {
+ return nil, err
+ }
+ var perrs []packages.Error
+ packages.Visit(pkgs, nil, func(p *packages.Package) {
+ perrs = append(perrs, p.Errors...)
+ })
+ if len(perrs) > 0 {
+ err = &PackageError{perrs}
+ }
+ return vpkgs, err
+}
+
+// Source calls vulncheck.Source on the Go source in pkgs. It returns the result
+// with Vulns trimmed to those that are actually called.
+func Source(ctx context.Context, pkgs []*vulncheck.Package, c client.Client) (*vulncheck.Result, error) {
+ r, err := vulncheck.Source(ctx, pkgs, &vulncheck.Config{Client: c})
+ if err != nil {
+ return nil, err
+ }
+ // Keep only the vulns that are called.
+ var vulns []*vulncheck.Vuln
+ for _, v := range r.Vulns {
+ if v.CallSink != 0 {
+ vulns = append(vulns, v)
+ }
+ }
+ r.Vulns = vulns
+ return r, nil
+}
+
+// CallInfo is information about calls to vulnerable functions.
+type CallInfo struct {
+ CallStacks map[*vulncheck.Vuln][]vulncheck.CallStack // all call stacks
+ VulnGroups [][]*vulncheck.Vuln // vulns grouped by ID and package
+ ModuleVersions map[string]string // map from module paths to versions
+ TopPackages map[string]bool // top-level packages
+}
+
+// GetCallInfo computes call stacks and related information from a vulncheck.Result.
+// I also makes a set of top-level packages from pkgs.
+func GetCallInfo(r *vulncheck.Result, pkgs []*vulncheck.Package) *CallInfo {
+ pset := map[string]bool{}
+ for _, p := range pkgs {
+ pset[p.PkgPath] = true
+ }
+ return &CallInfo{
+ CallStacks: vulncheck.CallStacks(r),
+ VulnGroups: groupByIDAndPackage(r.Vulns),
+ ModuleVersions: moduleVersionMap(r.Modules),
+ TopPackages: pset,
+ }
+}
+
+func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
+ groups := map[[2]string][]*vulncheck.Vuln{}
+ for _, v := range vs {
+ key := [2]string{v.OSV.ID, v.PkgPath}
+ groups[key] = append(groups[key], v)
+ }
+
+ var res [][]*vulncheck.Vuln
+ for _, g := range groups {
+ res = append(res, g)
+ }
+ sort.Slice(res, func(i, j int) bool {
+ return res[i][0].PkgPath < res[j][0].PkgPath
+ })
+ return res
+}
+
+// moduleVersionMap builds a map from module paths to versions.
+func moduleVersionMap(mods []*vulncheck.Module) map[string]string {
+ moduleVersions := map[string]string{}
+ for _, m := range mods {
+ v := m.Version
+ if m.Replace != nil {
+ v = m.Replace.Version
+ }
+ moduleVersions[m.Path] = v
+ }
+ return moduleVersions
+}
diff --git a/gopls/internal/govulncheck/util.go b/gopls/internal/govulncheck/util.go
new file mode 100644
index 0000000..baa2d96
--- /dev/null
+++ b/gopls/internal/govulncheck/util.go
@@ -0,0 +1,109 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.18
+// +build go1.18
+
+package govulncheck
+
+import (
+ "fmt"
+ "strings"
+
+ "golang.org/x/mod/semver"
+ "golang.org/x/vuln/osv"
+ "golang.org/x/vuln/vulncheck"
+)
+
+// LatestFixed returns the latest fixed version in the list of affected ranges,
+// or the empty string if there are no fixed versions.
+func LatestFixed(as []osv.Affected) string {
+ v := ""
+ for _, a := range as {
+ for _, r := range a.Ranges {
+ if r.Type == osv.TypeSemver {
+ for _, e := range r.Events {
+ if e.Fixed != "" && (v == "" || semver.Compare(e.Fixed, v) > 0) {
+ v = e.Fixed
+ }
+ }
+ }
+ }
+ }
+ return v
+}
+
+// SummarizeCallStack returns a short description of the call stack.
+// It uses one of two forms, depending on what the lowest function F in topPkgs
+// calls:
+// - If it calls a function V from the vulnerable package, then summarizeCallStack
+// returns "F calls V".
+// - If it calls a function G in some other package, which eventually calls V,
+// it returns "F calls G, which eventually calls V".
+//
+// If it can't find any of these functions, summarizeCallStack returns the empty string.
+func SummarizeCallStack(cs vulncheck.CallStack, topPkgs map[string]bool, vulnPkg string) string {
+ // Find the lowest function in the top packages.
+ iTop := lowest(cs, func(e vulncheck.StackEntry) bool {
+ return topPkgs[PkgPath(e.Function)]
+ })
+ if iTop < 0 {
+ return ""
+ }
+ // Find the highest function in the vulnerable package that is below iTop.
+ iVuln := highest(cs[iTop+1:], func(e vulncheck.StackEntry) bool {
+ return PkgPath(e.Function) == vulnPkg
+ })
+ if iVuln < 0 {
+ return ""
+ }
+ iVuln += iTop + 1 // adjust for slice in call to highest.
+ topName := FuncName(cs[iTop].Function)
+ vulnName := FuncName(cs[iVuln].Function)
+ if iVuln == iTop+1 {
+ return fmt.Sprintf("%s calls %s", topName, vulnName)
+ }
+ return fmt.Sprintf("%s calls %s, which eventually calls %s",
+ topName, FuncName(cs[iTop+1].Function), vulnName)
+}
+
+// highest returns the highest (one with the smallest index) entry in the call
+// stack for which f returns true.
+func highest(cs vulncheck.CallStack, f func(e vulncheck.StackEntry) bool) int {
+ for i := 0; i < len(cs); i++ {
+ if f(cs[i]) {
+ return i
+ }
+ }
+ return -1
+}
+
+// lowest returns the lowest (one with the largets index) entry in the call
+// stack for which f returns true.
+func lowest(cs vulncheck.CallStack, f func(e vulncheck.StackEntry) bool) int {
+ for i := len(cs) - 1; i >= 0; i-- {
+ if f(cs[i]) {
+ return i
+ }
+ }
+ return -1
+}
+
+// PkgPath returns the package path from fn.
+func PkgPath(fn *vulncheck.FuncNode) string {
+ if fn.PkgPath != "" {
+ return fn.PkgPath
+ }
+ s := strings.TrimPrefix(fn.RecvType, "*")
+ if i := strings.LastIndexByte(s, '.'); i > 0 {
+ s = s[:i]
+ }
+ return s
+}
+
+// FuncName returns the function name from fn, adjusted
+// to remove pointer annotations.
+func FuncName(fn *vulncheck.FuncNode) string {
+ return strings.TrimPrefix(fn.String(), "*")
+}
diff --git a/gopls/internal/govulncheck/util_test.go b/gopls/internal/govulncheck/util_test.go
new file mode 100644
index 0000000..3288cd8
--- /dev/null
+++ b/gopls/internal/govulncheck/util_test.go
@@ -0,0 +1,85 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.18
+// +build go1.18
+
+package govulncheck
+
+import (
+ "strings"
+ "testing"
+
+ "golang.org/x/vuln/vulncheck"
+)
+
+func TestPkgPath(t *testing.T) {
+ for _, test := range []struct {
+ in vulncheck.FuncNode
+ want string
+ }{
+ {
+ vulncheck.FuncNode{PkgPath: "math", Name: "Floor"},
+ "math",
+ },
+ {
+ vulncheck.FuncNode{RecvType: "a.com/b.T", Name: "M"},
+ "a.com/b",
+ },
+ {
+ vulncheck.FuncNode{RecvType: "*a.com/b.T", Name: "M"},
+ "a.com/b",
+ },
+ } {
+ got := PkgPath(&test.in)
+ if got != test.want {
+ t.Errorf("%+v: got %q, want %q", test.in, got, test.want)
+ }
+ }
+}
+
+func TestSummarizeCallStack(t *testing.T) {
+ topPkgs := map[string]bool{"t1": true, "t2": true}
+ vulnPkg := "v"
+
+ for _, test := range []struct {
+ in, want string
+ }{
+ {"a.F", ""},
+ {"t1.F", ""},
+ {"v.V", ""},
+ {
+ "t1.F v.V",
+ "t1.F calls v.V",
+ },
+ {
+ "t1.F t2.G v.V1 v.v2",
+ "t2.G calls v.V1",
+ },
+ {
+ "t1.F x.Y t2.G a.H b.I c.J v.V",
+ "t2.G calls a.H, which eventually calls v.V",
+ },
+ } {
+ in := stringToCallStack(test.in)
+ got := SummarizeCallStack(in, topPkgs, vulnPkg)
+ if got != test.want {
+ t.Errorf("%s:\ngot %s\nwant %s", test.in, got, test.want)
+ }
+ }
+}
+
+func stringToCallStack(s string) vulncheck.CallStack {
+ var cs vulncheck.CallStack
+ for _, e := range strings.Fields(s) {
+ parts := strings.Split(e, ".")
+ cs = append(cs, vulncheck.StackEntry{
+ Function: &vulncheck.FuncNode{
+ PkgPath: parts[0],
+ Name: parts[1],
+ },
+ })
+ }
+ return cs
+}
diff --git a/gopls/internal/vulncheck/command.go b/gopls/internal/vulncheck/command.go
index 459ecca..a89354f 100644
--- a/gopls/internal/vulncheck/command.go
+++ b/gopls/internal/vulncheck/command.go
@@ -17,7 +17,6 @@
gvc "golang.org/x/tools/gopls/internal/govulncheck"
"golang.org/x/tools/internal/lsp/command"
"golang.org/x/vuln/client"
- "golang.org/x/vuln/vulncheck"
)
func init() {
@@ -73,53 +72,36 @@
packages.NeedTypesSizes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedDeps
log.Println("loading packages...")
-
- loadedPkgs, err := packages.Load(cfg, patterns...)
+ loadedPkgs, err := gvc.LoadPackages(cfg, patterns...)
if err != nil {
log.Printf("package load failed: %v", err)
return nil, err
}
log.Printf("loaded %d packages\n", len(loadedPkgs))
- pkgs := vulncheck.Convert(loadedPkgs)
- r, err := vulncheck.Source(ctx, pkgs, &vulncheck.Config{
- Client: c.Client,
- })
+ r, err := gvc.Source(ctx, loadedPkgs, c.Client)
if err != nil {
return nil, err
}
-
- // Skip vulns that are in the import graph but have no calls to them.
- var vulns []*vulncheck.Vuln
- for _, v := range r.Vulns {
- if v.CallSink != 0 {
- vulns = append(vulns, v)
- }
- }
-
- callStacks := vulncheck.CallStacks(r)
- // Create set of top-level packages, used to find representative symbols
- topPackages := map[string]bool{}
- for _, p := range pkgs {
- topPackages[p.PkgPath] = true
- }
- vulnGroups := groupByIDAndPackage(vulns)
- moduleVersions := moduleVersionMap(r.Modules)
-
- return toVulns(callStacks, moduleVersions, topPackages, vulnGroups)
+ callInfo := gvc.GetCallInfo(r, loadedPkgs)
+ return toVulns(callInfo)
// TODO: add import graphs.
}
-func toVulns(callStacks map[*vulncheck.Vuln][]vulncheck.CallStack, moduleVersions map[string]string, topPackages map[string]bool, vulnGroups [][]*vulncheck.Vuln) ([]Vuln, error) {
+func toVulns(ci *gvc.CallInfo) ([]Vuln, error) {
var vulns []Vuln
- for _, vg := range vulnGroups {
+ for _, vg := range ci.VulnGroups {
v0 := vg[0]
+ lf := gvc.LatestFixed(v0.OSV.Affected)
+ if lf != "" && lf[0] != 'v' {
+ lf = "v" + lf
+ }
vuln := Vuln{
ID: v0.OSV.ID,
PkgPath: v0.PkgPath,
- CurrentVersion: moduleVersions[v0.ModPath],
- FixedVersion: latestFixed(v0.OSV.Affected),
+ CurrentVersion: ci.ModuleVersions[v0.ModPath],
+ FixedVersion: lf,
Details: v0.OSV.Details,
Aliases: v0.OSV.Aliases,
@@ -130,9 +112,9 @@
// Keep first call stack for each vuln.
for _, v := range vg {
- if css := callStacks[v]; len(css) > 0 {
+ if css := ci.CallStacks[v]; len(css) > 0 {
vuln.CallStacks = append(vuln.CallStacks, toCallStack(css[0]))
- vuln.CallStackSummaries = append(vuln.CallStackSummaries, summarizeCallStack(css[0], topPackages, v.PkgPath))
+ vuln.CallStackSummaries = append(vuln.CallStackSummaries, gvc.SummarizeCallStack(css[0], ci.TopPackages, v.PkgPath))
}
}
vulns = append(vulns, vuln)
diff --git a/gopls/internal/vulncheck/util.go b/gopls/internal/vulncheck/util.go
index e2a437b..c329461 100644
--- a/gopls/internal/vulncheck/util.go
+++ b/gopls/internal/vulncheck/util.go
@@ -10,134 +10,13 @@
import (
"fmt"
"go/token"
- "sort"
- "strings"
- "golang.org/x/mod/semver"
+ gvc "golang.org/x/tools/gopls/internal/govulncheck"
"golang.org/x/tools/internal/lsp/protocol"
"golang.org/x/vuln/osv"
"golang.org/x/vuln/vulncheck"
)
-// TODO(hyangah): automate copy of golang.org/x/vuln/cmd.
-
-// moduleVersionMap builds a map from module paths to versions.
-func moduleVersionMap(mods []*vulncheck.Module) map[string]string {
- moduleVersions := map[string]string{}
- for _, m := range mods {
- v := m.Version
- if m.Replace != nil {
- v = m.Replace.Version
- }
- moduleVersions[m.Path] = v
- }
- return moduleVersions
-}
-
-func groupByIDAndPackage(vs []*vulncheck.Vuln) [][]*vulncheck.Vuln {
- groups := map[[2]string][]*vulncheck.Vuln{}
- for _, v := range vs {
- key := [2]string{v.OSV.ID, v.PkgPath}
- groups[key] = append(groups[key], v)
- }
-
- var res [][]*vulncheck.Vuln
- for _, g := range groups {
- res = append(res, g)
- }
- sort.Slice(res, func(i, j int) bool {
- return res[i][0].PkgPath < res[j][0].PkgPath
- })
- return res
-}
-
-// latestFixed returns the latest fixed version in the list of affected ranges,
-// or the empty string if there are no fixed versions.
-func latestFixed(as []osv.Affected) string {
- v := ""
- for _, a := range as {
- for _, r := range a.Ranges {
- if r.Type == osv.TypeSemver {
- for _, e := range r.Events {
- if e.Fixed != "" && (v == "" || semver.Compare(e.Fixed, v) > 0) {
- v = e.Fixed
- }
- }
- }
- }
- }
- if v == "" || v[0] == 'v' {
- return v
- }
- return "v" + v
-}
-
-// summarizeCallStack returns a short description of the call stack.
-// It uses one of two forms, depending on what the lowest function F in topPkgs
-// calls:
-// - If it calls a function V from the vulnerable package, then summarizeCallStack
-// returns "F calls V".
-// - If it calls a function G in some other package, which eventually calls V,
-// it returns "F calls G, which eventually calls V".
-//
-// If it can't find any of these functions, summarizeCallStack returns the empty string.
-func summarizeCallStack(cs vulncheck.CallStack, topPkgs map[string]bool, vulnPkg string) string {
- // Find the lowest function in the top packages.
- iTop := lowest(cs, func(e vulncheck.StackEntry) bool {
- return topPkgs[pkgPath(e.Function)]
- })
- if iTop < 0 {
- return ""
- }
- // Find the highest function in the vulnerable package that is below iTop.
- iVuln := highest(cs[iTop+1:], func(e vulncheck.StackEntry) bool {
- return pkgPath(e.Function) == vulnPkg
- })
- if iVuln < 0 {
- return ""
- }
- iVuln += iTop + 1 // adjust for slice in call to highest.
- topName := funcName(cs[iTop].Function)
- vulnName := funcName(cs[iVuln].Function)
- if iVuln == iTop+1 {
- return fmt.Sprintf("%s calls %s", topName, vulnName)
- }
- return fmt.Sprintf("%s calls %s, which eventually calls %s",
- topName, funcName(cs[iTop+1].Function), vulnName)
-}
-
-// highest returns the highest (one with the smallest index) entry in the call
-// stack for which f returns true.
-func highest(cs vulncheck.CallStack, f func(e vulncheck.StackEntry) bool) int {
- for i := 0; i < len(cs); i++ {
- if f(cs[i]) {
- return i
- }
- }
- return -1
-}
-
-// lowest returns the lowest (one with the largets index) entry in the call
-// stack for which f returns true.
-func lowest(cs vulncheck.CallStack, f func(e vulncheck.StackEntry) bool) int {
- for i := len(cs) - 1; i >= 0; i-- {
- if f(cs[i]) {
- return i
- }
- }
- return -1
-}
-func pkgPath(fn *vulncheck.FuncNode) string {
- if fn.PkgPath != "" {
- return fn.PkgPath
- }
- s := strings.TrimPrefix(fn.RecvType, "*")
- if i := strings.LastIndexByte(s, '.'); i > 0 {
- s = s[:i]
- }
- return s
-}
-
func toCallStack(src vulncheck.CallStack) CallStack {
var dest []StackEntry
for _, e := range src {
@@ -149,7 +28,7 @@
func toStackEntry(src vulncheck.StackEntry) StackEntry {
f, call := src.Function, src.Call
pos := f.Pos
- desc := funcName(f)
+ desc := gvc.FuncName(f)
if src.Call != nil {
pos = src.Call.Pos // Exact call site position is helpful.
if !call.Resolved {
@@ -166,10 +45,6 @@
}
}
-func funcName(fn *vulncheck.FuncNode) string {
- return strings.TrimPrefix(fn.String(), "*")
-}
-
// href returns a URL embedded in the entry if any.
// If no suitable URL is found, it returns a default entry in
// pkg.go.dev/vuln.