gopls: dedup upgrade code actions for vulncheck
Make sure we aren't sending multiple code actions
that do the same thing. This also adds a upgrade
to latest code action.
Change-Id: Ic9cecd0a9410648673d4afe63da5a940960a4afc
Reviewed-on: https://go-review.googlesource.com/c/tools/+/436776
Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
gopls-CI: kokoro <noreply+kokoro@google.com>
Run-TryBot: Suzy Mueller <suzmue@golang.org>
diff --git a/gopls/internal/lsp/code_action.go b/gopls/internal/lsp/code_action.go
index 1ad0cb4..d19cafc 100644
--- a/gopls/internal/lsp/code_action.go
+++ b/gopls/internal/lsp/code_action.go
@@ -81,16 +81,30 @@
if err != nil {
return nil, err
}
- vdiags, err := mod.ModVulnerabilityDiagnostics(ctx, snapshot, fh)
- if err != nil {
- return nil, err
- }
- // TODO(suzmue): Consider deduping upgrades from ModUpgradeDiagnostics and ModVulnerabilityDiagnostics.
- quickFixes, err := codeActionsMatchingDiagnostics(ctx, snapshot, diagnostics, append(append(diags, udiags...), vdiags...))
+ quickFixes, err := codeActionsMatchingDiagnostics(ctx, snapshot, diagnostics, append(diags, udiags...))
if err != nil {
return nil, err
}
codeActions = append(codeActions, quickFixes...)
+
+ vdiags, err := mod.ModVulnerabilityDiagnostics(ctx, snapshot, fh)
+ if err != nil {
+ return nil, err
+ }
+ // Group vulnerabilities by location and then limit which code actions we return
+ // for each location.
+ m := make(map[protocol.Range][]*source.Diagnostic)
+ for _, v := range vdiags {
+ m[v.Range] = append(m[v.Range], v)
+ }
+ for _, sdiags := range m {
+ quickFixes, err = codeActionsMatchingDiagnostics(ctx, snapshot, diagnostics, sdiags)
+ if err != nil {
+ return nil, err
+ }
+ quickFixes = mod.SelectUpgradeCodeActions(quickFixes)
+ codeActions = append(codeActions, quickFixes...)
+ }
}
case source.Go:
// Don't suggest fixes for generated files, since they are generally
diff --git a/gopls/internal/lsp/mod/diagnostics.go b/gopls/internal/lsp/mod/diagnostics.go
index 546c84c..d794424 100644
--- a/gopls/internal/lsp/mod/diagnostics.go
+++ b/gopls/internal/lsp/mod/diagnostics.go
@@ -10,8 +10,10 @@
"bytes"
"context"
"fmt"
+ "sort"
"strings"
+ "golang.org/x/mod/modfile"
"golang.org/x/mod/semver"
"golang.org/x/tools/gopls/internal/govulncheck"
"golang.org/x/tools/gopls/internal/lsp/command"
@@ -139,7 +141,7 @@
return nil, err
}
// Upgrade to the exact version we offer the user, not the most recent.
- title := fmt.Sprintf("Upgrade to %v", ver)
+ title := fmt.Sprintf("%s%v", upgradeCodeActionPrefix, ver)
cmd, err := command.NewUpgradeDependencyCommand(title, command.DependencyArgs{
URI: protocol.URIFromSpanURI(fh.URI()),
AddRequire: false,
@@ -176,6 +178,8 @@
}
}
+const upgradeCodeActionPrefix = "Upgrade to "
+
// ModVulnerabilityDiagnostics adds diagnostics for vulnerabilities in individual modules
// if the vulnerability is recorded in the view.
func ModVulnerabilityDiagnostics(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle) (vulnDiagnostics []*source.Diagnostic, err error) {
@@ -212,20 +216,24 @@
continue
}
// Upgrade to the exact version we offer the user, not the most recent.
- // TODO(suzmue): Add an upgrade for module@latest.
// TODO(hakim): Produce fixes only for affecting vulnerabilities (if len(v.Trace) > 0)
var fixes []source.SuggestedFix
if fixedVersion := v.FixedIn; semver.IsValid(fixedVersion) && semver.Compare(req.Mod.Version, fixedVersion) < 0 {
- title := fmt.Sprintf("Upgrade to %v", fixedVersion)
- cmd, err := command.NewUpgradeDependencyCommand(title, command.DependencyArgs{
- URI: protocol.URIFromSpanURI(fh.URI()),
- AddRequire: false,
- GoCmdArgs: []string{req.Mod.Path + "@" + fixedVersion},
- })
+ cmd, err := getUpgradeCodeAction(fh, req, fixedVersion)
if err != nil {
return nil, err
}
- fixes = append(fixes, source.SuggestedFixFromCommand(cmd, protocol.QuickFix))
+ // Add an upgrade for module@latest.
+ // TODO(suzmue): verify if latest is the same as fixedVersion.
+ latest, err := getUpgradeCodeAction(fh, req, "latest")
+ if err != nil {
+ return nil, err
+ }
+
+ fixes = []source.SuggestedFix{
+ source.SuggestedFixFromCommand(cmd, protocol.QuickFix),
+ source.SuggestedFixFromCommand(latest, protocol.QuickFix),
+ }
}
severity := protocol.SeverityInformation
@@ -277,3 +285,48 @@
}
return fmt.Sprintf("https://pkg.go.dev/vuln/%s", vuln.ID)
}
+
+func getUpgradeCodeAction(fh source.FileHandle, req *modfile.Require, version string) (protocol.Command, error) {
+ cmd, err := command.NewUpgradeDependencyCommand(upgradeTitle(version), command.DependencyArgs{
+ URI: protocol.URIFromSpanURI(fh.URI()),
+ AddRequire: false,
+ GoCmdArgs: []string{req.Mod.Path + "@" + version},
+ })
+ if err != nil {
+ return protocol.Command{}, err
+ }
+ return cmd, nil
+}
+
+func upgradeTitle(fixedVersion string) string {
+ title := fmt.Sprintf("%s%v", upgradeCodeActionPrefix, fixedVersion)
+ return title
+}
+
+// SelectUpgradeCodeActions takes a list of upgrade code actions for a
+// required module and returns a more selective list of upgrade code actions,
+// where the code actions have been deduped.
+func SelectUpgradeCodeActions(actions []protocol.CodeAction) []protocol.CodeAction {
+ // TODO(suzmue): we can further limit the code actions to only return the most
+ // recent version that will fix all the vulnerabilities.
+
+ set := make(map[string]protocol.CodeAction)
+ for _, action := range actions {
+ set[action.Command.Title] = action
+ }
+ var result []protocol.CodeAction
+ for _, action := range set {
+ result = append(result, action)
+ }
+ // Sort results by version number, latest first.
+ // There should be no duplicates at this point.
+ sort.Slice(result, func(i, j int) bool {
+ vi, vj := getUpgradeVersion(result[i]), getUpgradeVersion(result[j])
+ return vi == "latest" || (vj != "latest" && semver.Compare(vi, vj) > 0)
+ })
+ return result
+}
+
+func getUpgradeVersion(p protocol.CodeAction) string {
+ return strings.TrimPrefix(p.Title, upgradeCodeActionPrefix)
+}
diff --git a/gopls/internal/regtest/misc/vuln_test.go b/gopls/internal/regtest/misc/vuln_test.go
index 1b55a6c..ecebfb5 100644
--- a/gopls/internal/regtest/misc/vuln_test.go
+++ b/gopls/internal/regtest/misc/vuln_test.go
@@ -9,7 +9,6 @@
import (
"context"
- "strings"
"testing"
"golang.org/x/tools/gopls/internal/lsp/command"
@@ -310,7 +309,6 @@
},
).Run(t, workspace1, func(t *testing.T, env *Env) {
env.OpenFile("go.mod")
- env.ExecuteCodeLensCommand("go.mod", command.Tidy)
env.ExecuteCodeLensCommand("go.mod", command.RunVulncheckExp)
d := &protocol.PublishDiagnosticsParams{}
@@ -318,20 +316,117 @@
CompletedWork("govulncheck", 1, true),
ShownMessage("Found"),
OnceMet(
- env.DiagnosticAtRegexpWithMessage("go.mod", `golang.org/amod`, "golang.org/amod has a known vulnerability: vuln in amod"),
- env.DiagnosticAtRegexpWithMessage("go.mod", `golang.org/amod`, "golang.org/amod has a known vulnerability: unaffecting vulnerability"),
- env.DiagnosticAtRegexpWithMessage("go.mod", `golang.org/bmod`, "golang.org/bmod has a known vulnerability: vuln in bmod\n\nThis is a long description of this vulnerability."),
+ env.DiagnosticAtRegexp("go.mod", `golang.org/amod`),
ReadDiagnostics("go.mod", d),
),
)
- var toFix []protocol.Diagnostic
- for _, diag := range d.Diagnostics {
- if strings.Contains(diag.Message, "vuln in ") {
- toFix = append(toFix, diag)
- }
+ type diagnostic struct {
+ msg string
+ severity protocol.DiagnosticSeverity
+ // codeActions is a list titles of code actions that we get with this
+ // diagnostics as the context.
+ codeActions []string
}
- env.ApplyQuickFixes("go.mod", toFix)
+ // wantDiagnostics maps a module path in the require
+ // section of a go.mod to diagnostics that will be returned
+ // when running vulncheck.
+ wantDiagnostics := map[string]struct {
+ // applyAction is the title of the code action to run for this module.
+ // If empty, no code actions will be executed.
+ applyAction string
+ // diagnostics is the list of diagnostics we expect at the require line for
+ // the module path.
+ diagnostics []diagnostic
+ // codeActions is a list titles of code actions that we get with context
+ // diagnostics.
+ codeActions []string
+ }{
+ "golang.org/amod": {
+ applyAction: "Upgrade to v1.0.4",
+ diagnostics: []diagnostic{
+ {
+ msg: "golang.org/amod has a known vulnerability: vuln in amod",
+ severity: protocol.SeverityWarning,
+ codeActions: []string{
+ "Upgrade to latest",
+ "Upgrade to v1.0.4",
+ },
+ },
+ {
+ msg: "golang.org/amod has a known vulnerability: unaffecting vulnerability",
+ severity: protocol.SeverityInformation,
+ codeActions: []string{
+ "Upgrade to latest",
+ "Upgrade to v1.0.6",
+ },
+ },
+ },
+ codeActions: []string{
+ "Upgrade to latest",
+ "Upgrade to v1.0.6",
+ "Upgrade to v1.0.4",
+ },
+ },
+ "golang.org/bmod": {
+ diagnostics: []diagnostic{
+ {
+ msg: "golang.org/bmod has a known vulnerability: vuln in bmod\n\nThis is a long description of this vulnerability.",
+ severity: protocol.SeverityWarning,
+ },
+ },
+ },
+ }
+
+ for mod, want := range wantDiagnostics {
+ pos := env.RegexpSearch("go.mod", mod)
+ var modPathDiagnostics []protocol.Diagnostic
+ for _, w := range want.diagnostics {
+ // Find the diagnostics at pos.
+ var diag *protocol.Diagnostic
+ for _, g := range d.Diagnostics {
+ g := g
+ if g.Range.Start == pos.ToProtocolPosition() && w.msg == g.Message {
+ modPathDiagnostics = append(modPathDiagnostics, g)
+ diag = &g
+ break
+ }
+ }
+ if diag == nil {
+ t.Errorf("no diagnostic at %q matching %q found\n", mod, w.msg)
+ continue
+ }
+ if diag.Severity != w.severity {
+ t.Errorf("incorrect severity for %q, expected %s got %s\n", w.msg, w.severity, diag.Severity)
+ }
+
+ gotActions := env.CodeAction("go.mod", []protocol.Diagnostic{*diag})
+ if !sameCodeActions(gotActions, w.codeActions) {
+ t.Errorf("code actions for %q do not match, expected %v, got %v\n", w.msg, w.codeActions, gotActions)
+ continue
+ }
+ }
+
+ // Check that the actions we get when including all diagnostics at a location return the same result
+ gotActions := env.CodeAction("go.mod", modPathDiagnostics)
+ if !sameCodeActions(gotActions, want.codeActions) {
+ t.Errorf("code actions for %q do not match, expected %v, got %v\n", mod, want.codeActions, gotActions)
+ continue
+ }
+
+ // Apply the code action matching applyAction.
+ if want.applyAction == "" {
+ continue
+ }
+ for _, action := range gotActions {
+ if action.Title == want.applyAction {
+ env.ApplyCodeAction(action)
+ break
+ }
+ }
+
+ }
+
env.Await(env.DoneWithChangeWatchedFiles())
wantGoMod := `module golang.org/entry
@@ -349,3 +444,19 @@
}
})
}
+
+func sameCodeActions(gotActions []protocol.CodeAction, want []string) bool {
+ gotTitles := make([]string, len(gotActions))
+ for i, ca := range gotActions {
+ gotTitles[i] = ca.Title
+ }
+ if len(gotTitles) != len(want) {
+ return false
+ }
+ for i := range want {
+ if gotTitles[i] != want[i] {
+ return false
+ }
+ }
+ return true
+}