internal/lsp: associate code action diagnostics with suggested fixes

Instead of relying on the diagnostics cached on the package, use the
diagnostics sent by the code action when computing suggested fixes.

Change-Id: I77f7fd468b34b824c6c5000a51edbe0f8cc6f637
Reviewed-on: https://go-review.googlesource.com/c/tools/+/197097
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/cache/pkg.go b/internal/lsp/cache/pkg.go
index 12b8590..5cf1660 100644
--- a/internal/lsp/cache/pkg.go
+++ b/internal/lsp/cache/pkg.go
@@ -13,6 +13,7 @@
 
 	"golang.org/x/tools/go/analysis"
 	"golang.org/x/tools/go/packages"
+	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/tools/internal/lsp/source"
 	"golang.org/x/tools/internal/span"
 	errors "golang.org/x/xerrors"
@@ -200,15 +201,25 @@
 	pkg.diagnostics[a] = diags
 }
 
-func (pkg *pkg) GetDiagnostics() []source.Diagnostic {
-	pkg.diagMu.Lock()
-	defer pkg.diagMu.Unlock()
+func (p *pkg) FindDiagnostic(pdiag protocol.Diagnostic) (*source.Diagnostic, error) {
+	p.diagMu.Lock()
+	defer p.diagMu.Unlock()
 
-	var diags []source.Diagnostic
-	for _, d := range pkg.diagnostics {
-		diags = append(diags, d...)
+	for a, diagnostics := range p.diagnostics {
+		if a.Name != pdiag.Source {
+			continue
+		}
+		for _, d := range diagnostics {
+			if d.Message != pdiag.Message {
+				continue
+			}
+			if protocol.CompareRange(d.Range, pdiag.Range) != 0 {
+				continue
+			}
+			return &d, nil
+		}
 	}
-	return diags
+	return nil, errors.Errorf("no matching diagnostic for %v", pdiag)
 }
 
 func (p *pkg) FindFile(ctx context.Context, uri span.URI) (source.ParseGoHandle, source.Package, error) {
diff --git a/internal/lsp/code_action.go b/internal/lsp/code_action.go
index c1889a5..0e6523c 100644
--- a/internal/lsp/code_action.go
+++ b/internal/lsp/code_action.go
@@ -75,21 +75,21 @@
 		if err != nil {
 			return nil, err
 		}
-		if wanted[protocol.QuickFix] {
+		if diagnostics := params.Context.Diagnostics; wanted[protocol.QuickFix] && len(diagnostics) > 0 {
 			// First, add the quick fixes reported by go/analysis.
-			qf, err := quickFixes(ctx, view, gof)
+			qf, err := quickFixes(ctx, view, gof, diagnostics)
 			if err != nil {
 				log.Error(ctx, "quick fixes failed", err, telemetry.File.Of(uri))
 			}
 			codeActions = append(codeActions, qf...)
 
 			// If we also have diagnostics for missing imports, we can associate them with quick fixes.
-			if findImportErrors(params.Context.Diagnostics) {
+			if findImportErrors(diagnostics) {
 				// Separate this into a set of codeActions per diagnostic, where
 				// each action is the addition, removal, or renaming of one import.
 				for _, importFix := range editsPerFix {
 					// Get the diagnostics this fix would affect.
-					if fixDiagnostics := importDiagnostics(importFix.Fix, params.Context.Diagnostics); len(fixDiagnostics) > 0 {
+					if fixDiagnostics := importDiagnostics(importFix.Fix, diagnostics); len(fixDiagnostics) > 0 {
 						codeActions = append(codeActions, protocol.CodeAction{
 							Title: importFixTitle(importFix.Fix),
 							Kind:  protocol.QuickFix,
@@ -207,36 +207,36 @@
 	return results
 }
 
-func quickFixes(ctx context.Context, view source.View, gof source.GoFile) ([]protocol.CodeAction, error) {
+func quickFixes(ctx context.Context, view source.View, gof source.GoFile, diagnostics []protocol.Diagnostic) ([]protocol.CodeAction, error) {
 	var codeActions []protocol.CodeAction
-
-	// TODO: This is technically racy because the diagnostics provided by the code action
-	// may not be the same as the ones that gopls is aware of.
-	// We need to figure out some way to solve this problem.
 	cphs, err := gof.CheckPackageHandles(ctx)
 	if err != nil {
 		return nil, err
 	}
-	cph := source.NarrowestCheckPackageHandle(cphs)
+	// We get the package that source.Diagnostics would've used. This is hack.
+	// TODO(golang/go#32443): The correct solution will be to cache diagnostics per-file per-snapshot.
+	cph := source.WidestCheckPackageHandle(cphs)
 	pkg, err := cph.Cached(ctx)
 	if err != nil {
 		return nil, err
 	}
-	for _, diag := range pkg.GetDiagnostics() {
-		pdiag, err := toProtocolDiagnostic(ctx, diag)
+	for _, diag := range diagnostics {
+		sdiag, err := pkg.FindDiagnostic(diag)
 		if err != nil {
-			return nil, err
+			continue
 		}
-		for _, fix := range diag.SuggestedFixes {
+		for _, fix := range sdiag.SuggestedFixes {
+			edits := make(map[string][]protocol.TextEdit)
+			for uri, e := range fix.Edits {
+				edits[protocol.NewURI(uri)] = e
+			}
 			codeActions = append(codeActions, protocol.CodeAction{
-				Title: fix.Title,
-				Kind:  protocol.QuickFix, // TODO(matloob): Be more accurate about these?
+				Title:       fix.Title,
+				Kind:        protocol.QuickFix,
+				Diagnostics: []protocol.Diagnostic{diag},
 				Edit: &protocol.WorkspaceEdit{
-					Changes: &map[string][]protocol.TextEdit{
-						protocol.NewURI(diag.URI): fix.Edits,
-					},
+					Changes: &edits,
 				},
-				Diagnostics: []protocol.Diagnostic{pdiag},
 			})
 		}
 	}
diff --git a/internal/lsp/diagnostics.go b/internal/lsp/diagnostics.go
index b625439..7818027 100644
--- a/internal/lsp/diagnostics.go
+++ b/internal/lsp/diagnostics.go
@@ -74,30 +74,22 @@
 }
 
 func (s *Server) publishDiagnostics(ctx context.Context, uri span.URI, diagnostics []source.Diagnostic) error {
-	protocolDiagnostics, err := toProtocolDiagnostics(ctx, diagnostics)
-	if err != nil {
-		return err
-	}
 	s.client.PublishDiagnostics(ctx, &protocol.PublishDiagnosticsParams{
-		Diagnostics: protocolDiagnostics,
+		Diagnostics: toProtocolDiagnostics(ctx, diagnostics),
 		URI:         protocol.NewURI(uri),
 	})
 	return nil
 }
 
-func toProtocolDiagnostics(ctx context.Context, diagnostics []source.Diagnostic) ([]protocol.Diagnostic, error) {
+func toProtocolDiagnostics(ctx context.Context, diagnostics []source.Diagnostic) []protocol.Diagnostic {
 	reports := []protocol.Diagnostic{}
 	for _, diag := range diagnostics {
-		diagnostic, err := toProtocolDiagnostic(ctx, diag)
-		if err != nil {
-			return nil, err
-		}
-		reports = append(reports, diagnostic)
+		reports = append(reports, toProtocolDiagnostic(ctx, diag))
 	}
-	return reports, nil
+	return reports
 }
 
-func toProtocolDiagnostic(ctx context.Context, diag source.Diagnostic) (protocol.Diagnostic, error) {
+func toProtocolDiagnostic(ctx context.Context, diag source.Diagnostic) protocol.Diagnostic {
 	var severity protocol.DiagnosticSeverity
 	switch diag.Severity {
 	case source.SeverityError:
@@ -110,5 +102,6 @@
 		Range:    diag.Range,
 		Severity: severity,
 		Source:   diag.Source,
-	}, nil
+		Tags:     diag.Tags,
+	}
 }
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index c133bac..44a5729 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -331,26 +331,28 @@
 	for _, spn := range data {
 		uri := spn.URI()
 		filename := uri.Filename()
-		v := r.server.session.ViewOf(uri)
+		view := r.server.session.ViewOf(uri)
 		fixed := string(r.data.Golden("suggestedfix", filename, func() ([]byte, error) {
 			cmd := exec.Command("suggestedfix", filename) // TODO(matloob): what do we do here?
 			out, _ := cmd.Output()                        // ignore error, sometimes we have intentionally ungofmt-able files
 			return out, nil
 		}))
-		f, err := getGoFile(r.ctx, v, uri)
+		f, err := getGoFile(r.ctx, view, uri)
 		if err != nil {
 			t.Fatal(err)
 		}
-		results, _, err := source.Diagnostics(r.ctx, v, f, nil)
+		diagnostics, _, err := source.Diagnostics(r.ctx, view, f, nil)
 		if err != nil {
 			t.Fatal(err)
 		}
-		_ = results
 		actions, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{
 			TextDocument: protocol.TextDocumentIdentifier{
 				URI: protocol.NewURI(uri),
 			},
-			Context: protocol.CodeActionContext{Only: []protocol.CodeActionKind{protocol.QuickFix}},
+			Context: protocol.CodeActionContext{
+				Only:        []protocol.CodeActionKind{protocol.QuickFix},
+				Diagnostics: toProtocolDiagnostics(r.ctx, diagnostics[uri]),
+			},
 		})
 		if err != nil {
 			if fixed != "" {
diff --git a/internal/lsp/source/analysis.go b/internal/lsp/source/analysis.go
index e9c9ecc..2c7961e 100644
--- a/internal/lsp/source/analysis.go
+++ b/internal/lsp/source/analysis.go
@@ -61,17 +61,18 @@
 // package (as different analyzers are applied, either in sequence or
 // parallel), and across packages (as dependencies are analyzed).
 type Action struct {
-	once         sync.Once
-	Analyzer     *analysis.Analyzer
-	Pkg          Package
-	Deps         []*Action
+	once        sync.Once
+	Analyzer    *analysis.Analyzer
+	Pkg         Package
+	Deps        []*Action
+	diagnostics []analysis.Diagnostic
+
 	pass         *analysis.Pass
 	isroot       bool
 	objectFacts  map[objectFactKey]analysis.Fact
 	packageFacts map[packageFactKey]analysis.Fact
 	inputs       map[*analysis.Analyzer]interface{}
 	result       interface{}
-	diagnostics  []analysis.Diagnostic
 	err          error
 	duration     time.Duration
 	view         View
diff --git a/internal/lsp/source/diagnostics.go b/internal/lsp/source/diagnostics.go
index 922fd72..c48a2cc 100644
--- a/internal/lsp/source/diagnostics.go
+++ b/internal/lsp/source/diagnostics.go
@@ -26,15 +26,11 @@
 	Message  string
 	Source   string
 	Severity DiagnosticSeverity
+	Tags     []protocol.DiagnosticTag
 
 	SuggestedFixes []SuggestedFix
 }
 
-type SuggestedFix struct {
-	Title string
-	Edits []protocol.TextEdit
-}
-
 type DiagnosticSeverity int
 
 const (
@@ -238,22 +234,29 @@
 	if err != nil {
 		return Diagnostic{}, err
 	}
-	ca, err := getCodeActions(ctx, view, pkg, diag)
-	if err != nil {
-		return Diagnostic{}, err
-	}
-
 	rng, err := spanToRange(ctx, view, pkg, spn, false)
 	if err != nil {
 		return Diagnostic{}, err
 	}
+	fixes, err := suggestedFixes(ctx, view, pkg, diag)
+	if err != nil {
+		return Diagnostic{}, err
+	}
+	// This is a bit of a hack, but clients > 3.15 will be able to grey out unnecessary code.
+	// If we are deleting code as part of all of our suggested fixes, assume that this is dead code.
+	// TODO(golang/go/#34508): Return these codes from the diagnostics themselves.
+	var tags []protocol.DiagnosticTag
+	if onlyDeletions(fixes) {
+		tags = append(tags, protocol.Unnecessary)
+	}
 	return Diagnostic{
 		URI:            spn.URI(),
 		Range:          rng,
 		Source:         category,
 		Message:        diag.Message,
 		Severity:       SeverityWarning,
-		SuggestedFixes: ca,
+		SuggestedFixes: fixes,
+		Tags:           tags,
 	}, nil
 }
 
diff --git a/internal/lsp/source/suggested_fix.go b/internal/lsp/source/suggested_fix.go
index 99eab5c..dd5f54a 100644
--- a/internal/lsp/source/suggested_fix.go
+++ b/internal/lsp/source/suggested_fix.go
@@ -8,13 +8,19 @@
 	"golang.org/x/tools/internal/span"
 )
 
-func getCodeActions(ctx context.Context, view View, pkg Package, diag analysis.Diagnostic) ([]SuggestedFix, error) {
+type SuggestedFix struct {
+	Title string
+	Edits map[span.URI][]protocol.TextEdit
+}
+
+func suggestedFixes(ctx context.Context, view View, pkg Package, diag analysis.Diagnostic) ([]SuggestedFix, error) {
 	var fixes []SuggestedFix
 	for _, fix := range diag.SuggestedFixes {
-		var edits []protocol.TextEdit
+		edits := make(map[span.URI][]protocol.TextEdit)
 		for _, e := range fix.TextEdits {
 			posn := view.Session().Cache().FileSet().Position(e.Pos)
-			ph, _, err := pkg.FindFile(ctx, span.FileURI(posn.Filename))
+			uri := span.FileURI(posn.Filename)
+			ph, _, err := pkg.FindFile(ctx, uri)
 			if err != nil {
 				return nil, err
 			}
@@ -30,7 +36,7 @@
 			if err != nil {
 				return nil, err
 			}
-			edits = append(edits, protocol.TextEdit{
+			edits[uri] = append(edits[uri], protocol.TextEdit{
 				Range:   rng,
 				NewText: string(e.NewText),
 			})
@@ -42,3 +48,20 @@
 	}
 	return fixes, nil
 }
+
+// onlyDeletions returns true if all of the suggested fixes are deletions.
+func onlyDeletions(fixes []SuggestedFix) bool {
+	for _, fix := range fixes {
+		for _, edits := range fix.Edits {
+			for _, edit := range edits {
+				if edit.NewText != "" {
+					return false
+				}
+				if protocol.ComparePosition(edit.Range.Start, edit.Range.End) == 0 {
+					return false
+				}
+			}
+		}
+	}
+	return true
+}
diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go
index dfacffd..aa8a872 100644
--- a/internal/lsp/source/view.go
+++ b/internal/lsp/source/view.go
@@ -297,8 +297,9 @@
 	GetTypesInfo() *types.Info
 	GetTypesSizes() types.Sizes
 	IsIllTyped() bool
-	GetDiagnostics() []Diagnostic
-	SetDiagnostics(a *analysis.Analyzer, diag []Diagnostic)
+
+	SetDiagnostics(*analysis.Analyzer, []Diagnostic)
+	FindDiagnostic(protocol.Diagnostic) (*Diagnostic, error)
 
 	// GetImport returns the CheckPackageHandle for a package imported by this package.
 	GetImport(ctx context.Context, pkgPath string) (Package, error)
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 87cd807..0f08e19 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -295,28 +295,26 @@
 
 	// Collect any data that needs to be used by subsequent tests.
 	if err := data.Exported.Expect(map[string]interface{}{
-		"diag":       data.collectDiagnostics,
-		"item":       data.collectCompletionItems,
-		"complete":   data.collectCompletions(CompletionDefault),
-		"unimported": data.collectCompletions(CompletionUnimported),
-		"deep":       data.collectCompletions(CompletionDeep),
-		"fuzzy":      data.collectCompletions(CompletionFuzzy),
-		"rank":       data.collectCompletions(CompletionRank),
-		"snippet":    data.collectCompletionSnippets,
-		"fold":       data.collectFoldingRanges,
-		"format":     data.collectFormats,
-		"import":     data.collectImports,
-		"godef":      data.collectDefinitions,
-		"typdef":     data.collectTypeDefinitions,
-		"hover":      data.collectHoverDefinitions,
-		"highlight":  data.collectHighlights,
-		"refs":       data.collectReferences,
-		"rename":     data.collectRenames,
-		"prepare":    data.collectPrepareRenames,
-		"symbol":     data.collectSymbols,
-		"signature":  data.collectSignatures,
-
-		// LSP-only features.
+		"diag":         data.collectDiagnostics,
+		"item":         data.collectCompletionItems,
+		"complete":     data.collectCompletions(CompletionDefault),
+		"unimported":   data.collectCompletions(CompletionUnimported),
+		"deep":         data.collectCompletions(CompletionDeep),
+		"fuzzy":        data.collectCompletions(CompletionFuzzy),
+		"rank":         data.collectCompletions(CompletionRank),
+		"snippet":      data.collectCompletionSnippets,
+		"fold":         data.collectFoldingRanges,
+		"format":       data.collectFormats,
+		"import":       data.collectImports,
+		"godef":        data.collectDefinitions,
+		"typdef":       data.collectTypeDefinitions,
+		"hover":        data.collectHoverDefinitions,
+		"highlight":    data.collectHighlights,
+		"refs":         data.collectReferences,
+		"rename":       data.collectRenames,
+		"prepare":      data.collectPrepareRenames,
+		"symbol":       data.collectSymbols,
+		"signature":    data.collectSignatures,
 		"link":         data.collectLinks,
 		"suggestedfix": data.collectSuggestedFixes,
 	}); err != nil {