gopls: dedup upgrade code actions for vulncheck

Make sure we aren't sending multiple code actions
that do the same thing. This also adds a upgrade
to latest code action.

Change-Id: Ic9cecd0a9410648673d4afe63da5a940960a4afc
Reviewed-on: https://go-review.googlesource.com/c/tools/+/436776
Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
gopls-CI: kokoro <noreply+kokoro@google.com>
Run-TryBot: Suzy Mueller <suzmue@golang.org>
diff --git a/gopls/internal/lsp/code_action.go b/gopls/internal/lsp/code_action.go
index 1ad0cb4..d19cafc 100644
--- a/gopls/internal/lsp/code_action.go
+++ b/gopls/internal/lsp/code_action.go
@@ -81,16 +81,30 @@
 			if err != nil {
 				return nil, err
 			}
-			vdiags, err := mod.ModVulnerabilityDiagnostics(ctx, snapshot, fh)
-			if err != nil {
-				return nil, err
-			}
-			// TODO(suzmue): Consider deduping upgrades from ModUpgradeDiagnostics and ModVulnerabilityDiagnostics.
-			quickFixes, err := codeActionsMatchingDiagnostics(ctx, snapshot, diagnostics, append(append(diags, udiags...), vdiags...))
+			quickFixes, err := codeActionsMatchingDiagnostics(ctx, snapshot, diagnostics, append(diags, udiags...))
 			if err != nil {
 				return nil, err
 			}
 			codeActions = append(codeActions, quickFixes...)
+
+			vdiags, err := mod.ModVulnerabilityDiagnostics(ctx, snapshot, fh)
+			if err != nil {
+				return nil, err
+			}
+			// Group vulnerabilities by location and then limit which code actions we return
+			// for each location.
+			m := make(map[protocol.Range][]*source.Diagnostic)
+			for _, v := range vdiags {
+				m[v.Range] = append(m[v.Range], v)
+			}
+			for _, sdiags := range m {
+				quickFixes, err = codeActionsMatchingDiagnostics(ctx, snapshot, diagnostics, sdiags)
+				if err != nil {
+					return nil, err
+				}
+				quickFixes = mod.SelectUpgradeCodeActions(quickFixes)
+				codeActions = append(codeActions, quickFixes...)
+			}
 		}
 	case source.Go:
 		// Don't suggest fixes for generated files, since they are generally
diff --git a/gopls/internal/lsp/mod/diagnostics.go b/gopls/internal/lsp/mod/diagnostics.go
index 546c84c..d794424 100644
--- a/gopls/internal/lsp/mod/diagnostics.go
+++ b/gopls/internal/lsp/mod/diagnostics.go
@@ -10,8 +10,10 @@
 	"bytes"
 	"context"
 	"fmt"
+	"sort"
 	"strings"
 
+	"golang.org/x/mod/modfile"
 	"golang.org/x/mod/semver"
 	"golang.org/x/tools/gopls/internal/govulncheck"
 	"golang.org/x/tools/gopls/internal/lsp/command"
@@ -139,7 +141,7 @@
 			return nil, err
 		}
 		// Upgrade to the exact version we offer the user, not the most recent.
-		title := fmt.Sprintf("Upgrade to %v", ver)
+		title := fmt.Sprintf("%s%v", upgradeCodeActionPrefix, ver)
 		cmd, err := command.NewUpgradeDependencyCommand(title, command.DependencyArgs{
 			URI:        protocol.URIFromSpanURI(fh.URI()),
 			AddRequire: false,
@@ -176,6 +178,8 @@
 	}
 }
 
+const upgradeCodeActionPrefix = "Upgrade to "
+
 // ModVulnerabilityDiagnostics adds diagnostics for vulnerabilities in individual modules
 // if the vulnerability is recorded in the view.
 func ModVulnerabilityDiagnostics(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle) (vulnDiagnostics []*source.Diagnostic, err error) {
@@ -212,20 +216,24 @@
 				continue
 			}
 			// Upgrade to the exact version we offer the user, not the most recent.
-			// TODO(suzmue): Add an upgrade for module@latest.
 			// TODO(hakim): Produce fixes only for affecting vulnerabilities (if len(v.Trace) > 0)
 			var fixes []source.SuggestedFix
 			if fixedVersion := v.FixedIn; semver.IsValid(fixedVersion) && semver.Compare(req.Mod.Version, fixedVersion) < 0 {
-				title := fmt.Sprintf("Upgrade to %v", fixedVersion)
-				cmd, err := command.NewUpgradeDependencyCommand(title, command.DependencyArgs{
-					URI:        protocol.URIFromSpanURI(fh.URI()),
-					AddRequire: false,
-					GoCmdArgs:  []string{req.Mod.Path + "@" + fixedVersion},
-				})
+				cmd, err := getUpgradeCodeAction(fh, req, fixedVersion)
 				if err != nil {
 					return nil, err
 				}
-				fixes = append(fixes, source.SuggestedFixFromCommand(cmd, protocol.QuickFix))
+				// Add an upgrade for module@latest.
+				// TODO(suzmue): verify if latest is the same as fixedVersion.
+				latest, err := getUpgradeCodeAction(fh, req, "latest")
+				if err != nil {
+					return nil, err
+				}
+
+				fixes = []source.SuggestedFix{
+					source.SuggestedFixFromCommand(cmd, protocol.QuickFix),
+					source.SuggestedFixFromCommand(latest, protocol.QuickFix),
+				}
 			}
 
 			severity := protocol.SeverityInformation
@@ -277,3 +285,48 @@
 	}
 	return fmt.Sprintf("https://pkg.go.dev/vuln/%s", vuln.ID)
 }
+
+func getUpgradeCodeAction(fh source.FileHandle, req *modfile.Require, version string) (protocol.Command, error) {
+	cmd, err := command.NewUpgradeDependencyCommand(upgradeTitle(version), command.DependencyArgs{
+		URI:        protocol.URIFromSpanURI(fh.URI()),
+		AddRequire: false,
+		GoCmdArgs:  []string{req.Mod.Path + "@" + version},
+	})
+	if err != nil {
+		return protocol.Command{}, err
+	}
+	return cmd, nil
+}
+
+func upgradeTitle(fixedVersion string) string {
+	title := fmt.Sprintf("%s%v", upgradeCodeActionPrefix, fixedVersion)
+	return title
+}
+
+// SelectUpgradeCodeActions takes a list of upgrade code actions for a
+// required module and returns a more selective list of upgrade code actions,
+// where the code actions have been deduped.
+func SelectUpgradeCodeActions(actions []protocol.CodeAction) []protocol.CodeAction {
+	// TODO(suzmue): we can further limit the code actions to only return the most
+	// recent version that will fix all the vulnerabilities.
+
+	set := make(map[string]protocol.CodeAction)
+	for _, action := range actions {
+		set[action.Command.Title] = action
+	}
+	var result []protocol.CodeAction
+	for _, action := range set {
+		result = append(result, action)
+	}
+	// Sort results by version number, latest first.
+	// There should be no duplicates at this point.
+	sort.Slice(result, func(i, j int) bool {
+		vi, vj := getUpgradeVersion(result[i]), getUpgradeVersion(result[j])
+		return vi == "latest" || (vj != "latest" && semver.Compare(vi, vj) > 0)
+	})
+	return result
+}
+
+func getUpgradeVersion(p protocol.CodeAction) string {
+	return strings.TrimPrefix(p.Title, upgradeCodeActionPrefix)
+}
diff --git a/gopls/internal/regtest/misc/vuln_test.go b/gopls/internal/regtest/misc/vuln_test.go
index 1b55a6c..ecebfb5 100644
--- a/gopls/internal/regtest/misc/vuln_test.go
+++ b/gopls/internal/regtest/misc/vuln_test.go
@@ -9,7 +9,6 @@
 
 import (
 	"context"
-	"strings"
 	"testing"
 
 	"golang.org/x/tools/gopls/internal/lsp/command"
@@ -310,7 +309,6 @@
 		},
 	).Run(t, workspace1, func(t *testing.T, env *Env) {
 		env.OpenFile("go.mod")
-		env.ExecuteCodeLensCommand("go.mod", command.Tidy)
 
 		env.ExecuteCodeLensCommand("go.mod", command.RunVulncheckExp)
 		d := &protocol.PublishDiagnosticsParams{}
@@ -318,20 +316,117 @@
 			CompletedWork("govulncheck", 1, true),
 			ShownMessage("Found"),
 			OnceMet(
-				env.DiagnosticAtRegexpWithMessage("go.mod", `golang.org/amod`, "golang.org/amod has a known vulnerability: vuln in amod"),
-				env.DiagnosticAtRegexpWithMessage("go.mod", `golang.org/amod`, "golang.org/amod has a known vulnerability: unaffecting vulnerability"),
-				env.DiagnosticAtRegexpWithMessage("go.mod", `golang.org/bmod`, "golang.org/bmod has a known vulnerability: vuln in bmod\n\nThis is a long description of this vulnerability."),
+				env.DiagnosticAtRegexp("go.mod", `golang.org/amod`),
 				ReadDiagnostics("go.mod", d),
 			),
 		)
 
-		var toFix []protocol.Diagnostic
-		for _, diag := range d.Diagnostics {
-			if strings.Contains(diag.Message, "vuln in ") {
-				toFix = append(toFix, diag)
-			}
+		type diagnostic struct {
+			msg      string
+			severity protocol.DiagnosticSeverity
+			// codeActions is a list titles of code actions that we get with this
+			// diagnostics as the context.
+			codeActions []string
 		}
-		env.ApplyQuickFixes("go.mod", toFix)
+		// wantDiagnostics maps a module path in the require
+		// section of a go.mod to diagnostics that will be returned
+		// when running vulncheck.
+		wantDiagnostics := map[string]struct {
+			// applyAction is the title of the code action to run for this module.
+			// If empty, no code actions will be executed.
+			applyAction string
+			// diagnostics is the list of diagnostics we expect at the require line for
+			// the module path.
+			diagnostics []diagnostic
+			// codeActions is a list titles of code actions that we get with context
+			// diagnostics.
+			codeActions []string
+		}{
+			"golang.org/amod": {
+				applyAction: "Upgrade to v1.0.4",
+				diagnostics: []diagnostic{
+					{
+						msg:      "golang.org/amod has a known vulnerability: vuln in amod",
+						severity: protocol.SeverityWarning,
+						codeActions: []string{
+							"Upgrade to latest",
+							"Upgrade to v1.0.4",
+						},
+					},
+					{
+						msg:      "golang.org/amod has a known vulnerability: unaffecting vulnerability",
+						severity: protocol.SeverityInformation,
+						codeActions: []string{
+							"Upgrade to latest",
+							"Upgrade to v1.0.6",
+						},
+					},
+				},
+				codeActions: []string{
+					"Upgrade to latest",
+					"Upgrade to v1.0.6",
+					"Upgrade to v1.0.4",
+				},
+			},
+			"golang.org/bmod": {
+				diagnostics: []diagnostic{
+					{
+						msg:      "golang.org/bmod has a known vulnerability: vuln in bmod\n\nThis is a long description of this vulnerability.",
+						severity: protocol.SeverityWarning,
+					},
+				},
+			},
+		}
+
+		for mod, want := range wantDiagnostics {
+			pos := env.RegexpSearch("go.mod", mod)
+			var modPathDiagnostics []protocol.Diagnostic
+			for _, w := range want.diagnostics {
+				// Find the diagnostics at pos.
+				var diag *protocol.Diagnostic
+				for _, g := range d.Diagnostics {
+					g := g
+					if g.Range.Start == pos.ToProtocolPosition() && w.msg == g.Message {
+						modPathDiagnostics = append(modPathDiagnostics, g)
+						diag = &g
+						break
+					}
+				}
+				if diag == nil {
+					t.Errorf("no diagnostic at %q matching %q found\n", mod, w.msg)
+					continue
+				}
+				if diag.Severity != w.severity {
+					t.Errorf("incorrect severity for %q, expected %s got %s\n", w.msg, w.severity, diag.Severity)
+				}
+
+				gotActions := env.CodeAction("go.mod", []protocol.Diagnostic{*diag})
+				if !sameCodeActions(gotActions, w.codeActions) {
+					t.Errorf("code actions for %q do not match, expected %v, got %v\n", w.msg, w.codeActions, gotActions)
+					continue
+				}
+			}
+
+			// Check that the actions we get when including all diagnostics at a location return the same result
+			gotActions := env.CodeAction("go.mod", modPathDiagnostics)
+			if !sameCodeActions(gotActions, want.codeActions) {
+				t.Errorf("code actions for %q do not match, expected %v, got %v\n", mod, want.codeActions, gotActions)
+				continue
+			}
+
+			// Apply the code action matching applyAction.
+			if want.applyAction == "" {
+				continue
+			}
+			for _, action := range gotActions {
+				if action.Title == want.applyAction {
+					env.ApplyCodeAction(action)
+					break
+				}
+			}
+
+		}
+
 		env.Await(env.DoneWithChangeWatchedFiles())
 		wantGoMod := `module golang.org/entry
 
@@ -349,3 +444,19 @@
 		}
 	})
 }
+
+func sameCodeActions(gotActions []protocol.CodeAction, want []string) bool {
+	gotTitles := make([]string, len(gotActions))
+	for i, ca := range gotActions {
+		gotTitles[i] = ca.Title
+	}
+	if len(gotTitles) != len(want) {
+		return false
+	}
+	for i := range want {
+		if gotTitles[i] != want[i] {
+			return false
+		}
+	}
+	return true
+}