internal/lsp: use protocol.Range for diagnostics instead of span.Span

This is the first in a series of many changes that will change the API
of the source package to use different types for positions. Using
token.Pos is particularly fragile, since the pos has to refer to the
specific *ast.File from which it was derived.

Change-Id: I70c9b806f7dd45b2e229954ebdcdd86e2cf3bbbb
Reviewed-on: https://go-review.googlesource.com/c/tools/+/190340
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/cmd/check_test.go b/internal/lsp/cmd/check_test.go
index 9771ca9..7f39ee8 100644
--- a/internal/lsp/cmd/check_test.go
+++ b/internal/lsp/cmd/check_test.go
@@ -6,6 +6,7 @@
 
 import (
 	"fmt"
+	"io/ioutil"
 	"strings"
 	"testing"
 
@@ -37,13 +38,25 @@
 			if len(bits) == 2 {
 				spn := span.Parse(strings.TrimSpace(bits[0]))
 				spn = span.New(spn.URI(), spn.Start(), span.Point{})
-				l = fmt.Sprintf("%s: %s", spn, strings.TrimSpace(bits[1]))
+				data, err := ioutil.ReadFile(fname)
+				if err != nil {
+					t.Fatal(err)
+				}
+				converter := span.NewContentConverter(fname, data)
+				s, err := spn.WithPosition(converter)
+				if err != nil {
+					t.Fatal(err)
+				}
+				l = fmt.Sprintf("%s: %s", s, strings.TrimSpace(bits[1]))
 			}
 			got[l] = struct{}{}
 		}
 		for _, diag := range want {
-			spn := span.New(diag.Span.URI(), diag.Span.Start(), diag.Span.Start())
-			expect := fmt.Sprintf("%v: %v", spn, diag.Message)
+			// TODO: This is a hack, fix this.
+			expect := fmt.Sprintf("%v:%v:%v: %v", diag.URI.Filename(), diag.Range.Start.Line+1, diag.Range.Start.Character+1, diag.Message)
+			if diag.Range.Start.Character == 0 {
+				expect = fmt.Sprintf("%v:%v: %v", diag.URI.Filename(), diag.Range.Start.Line+1, diag.Message)
+			}
 			_, found := got[expect]
 			if !found {
 				t.Errorf("missing diagnostic %q", expect)
diff --git a/internal/lsp/code_action.go b/internal/lsp/code_action.go
index 4613d24..fe8eef6 100644
--- a/internal/lsp/code_action.go
+++ b/internal/lsp/code_action.go
@@ -224,12 +224,12 @@
 		return nil, err
 	}
 	for _, diag := range pkg.GetDiagnostics() {
-		pdiag, err := toProtocolDiagnostic(ctx, view, diag)
+		pdiag, err := toProtocolDiagnostic(ctx, diag)
 		if err != nil {
 			return nil, err
 		}
 		for _, ca := range diag.SuggestedFixes {
-			_, m, err := getGoFile(ctx, view, diag.URI())
+			_, m, err := getGoFile(ctx, view, diag.URI)
 			if err != nil {
 				return nil, err
 			}
@@ -242,7 +242,7 @@
 				Kind:  protocol.QuickFix, // TODO(matloob): Be more accurate about these?
 				Edit: &protocol.WorkspaceEdit{
 					Changes: &map[string][]protocol.TextEdit{
-						string(diag.URI()): edits,
+						protocol.NewURI(diag.URI): edits,
 					},
 				},
 				Diagnostics: []protocol.Diagnostic{pdiag},
diff --git a/internal/lsp/diagnostics.go b/internal/lsp/diagnostics.go
index 79f3774..c3a897c 100644
--- a/internal/lsp/diagnostics.go
+++ b/internal/lsp/diagnostics.go
@@ -60,7 +60,7 @@
 }
 
 func (s *Server) publishDiagnostics(ctx context.Context, view source.View, uri span.URI, diagnostics []source.Diagnostic) error {
-	protocolDiagnostics, err := toProtocolDiagnostics(ctx, view, diagnostics)
+	protocolDiagnostics, err := toProtocolDiagnostics(ctx, diagnostics)
 	if err != nil {
 		return err
 	}
@@ -71,10 +71,10 @@
 	return nil
 }
 
-func toProtocolDiagnostics(ctx context.Context, v source.View, diagnostics []source.Diagnostic) ([]protocol.Diagnostic, error) {
+func toProtocolDiagnostics(ctx context.Context, diagnostics []source.Diagnostic) ([]protocol.Diagnostic, error) {
 	reports := []protocol.Diagnostic{}
 	for _, diag := range diagnostics {
-		diagnostic, err := toProtocolDiagnostic(ctx, v, diag)
+		diagnostic, err := toProtocolDiagnostic(ctx, diag)
 		if err != nil {
 			return nil, err
 		}
@@ -83,11 +83,7 @@
 	return reports, nil
 }
 
-func toProtocolDiagnostic(ctx context.Context, v source.View, diag source.Diagnostic) (protocol.Diagnostic, error) {
-	_, m, err := getSourceFile(ctx, v, diag.Span.URI())
-	if err != nil {
-		return protocol.Diagnostic{}, err
-	}
+func toProtocolDiagnostic(ctx context.Context, diag source.Diagnostic) (protocol.Diagnostic, error) {
 	var severity protocol.DiagnosticSeverity
 	switch diag.Severity {
 	case source.SeverityError:
@@ -95,13 +91,9 @@
 	case source.SeverityWarning:
 		severity = protocol.SeverityWarning
 	}
-	rng, err := m.Range(diag.Span)
-	if err != nil {
-		return protocol.Diagnostic{}, err
-	}
 	return protocol.Diagnostic{
 		Message:  strings.TrimSpace(diag.Message), // go list returns errors prefixed by newline
-		Range:    rng,
+		Range:    diag.Range,
 		Severity: severity,
 		Source:   diag.Source,
 	}, nil
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index 800c5e9..56db483 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -91,76 +91,12 @@
 			}
 			continue
 		}
-		if diff := diffDiagnostics(uri, want, got); diff != "" {
+		if diff := tests.DiffDiagnostics(uri, want, got); diff != "" {
 			t.Error(diff)
 		}
 	}
 }
 
-func sortDiagnostics(d []source.Diagnostic) {
-	sort.Slice(d, func(i int, j int) bool {
-		if r := span.Compare(d[i].Span, d[j].Span); r != 0 {
-			return r < 0
-		}
-		return d[i].Message < d[j].Message
-	})
-}
-
-// diffDiagnostics prints the diff between expected and actual diagnostics test
-// results.
-func diffDiagnostics(uri span.URI, want, got []source.Diagnostic) string {
-	sortDiagnostics(want)
-	sortDiagnostics(got)
-	if len(got) != len(want) {
-		return summarizeDiagnostics(-1, want, got, "different lengths got %v want %v", len(got), len(want))
-	}
-	for i, w := range want {
-		g := got[i]
-		if w.Message != g.Message {
-			return summarizeDiagnostics(i, want, got, "incorrect Message got %v want %v", g.Message, w.Message)
-		}
-		if span.ComparePoint(w.Start(), g.Start()) != 0 {
-			return summarizeDiagnostics(i, want, got, "incorrect Start got %v want %v", g.Start(), w.Start())
-		}
-		// Special case for diagnostics on parse errors.
-		if strings.Contains(string(uri), "noparse") {
-			if span.ComparePoint(g.Start(), g.End()) != 0 || span.ComparePoint(w.Start(), g.End()) != 0 {
-				return summarizeDiagnostics(i, want, got, "incorrect End got %v want %v", g.End(), w.Start())
-			}
-		} else if !g.IsPoint() { // Accept any 'want' range if the diagnostic returns a zero-length range.
-			if span.ComparePoint(w.End(), g.End()) != 0 {
-				return summarizeDiagnostics(i, want, got, "incorrect End got %v want %v", g.End(), w.End())
-			}
-		}
-		if w.Severity != g.Severity {
-			return summarizeDiagnostics(i, want, got, "incorrect Severity got %v want %v", g.Severity, w.Severity)
-		}
-		if w.Source != g.Source {
-			return summarizeDiagnostics(i, want, got, "incorrect Source got %v want %v", g.Source, w.Source)
-		}
-	}
-	return ""
-}
-
-func summarizeDiagnostics(i int, want []source.Diagnostic, got []source.Diagnostic, reason string, args ...interface{}) string {
-	msg := &bytes.Buffer{}
-	fmt.Fprint(msg, "diagnostics failed")
-	if i >= 0 {
-		fmt.Fprintf(msg, " at %d", i)
-	}
-	fmt.Fprint(msg, " because of ")
-	fmt.Fprintf(msg, reason, args...)
-	fmt.Fprint(msg, ":\nexpected:\n")
-	for _, d := range want {
-		fmt.Fprintf(msg, "  %v: %s\n", d.Span, d.Message)
-	}
-	fmt.Fprintf(msg, "got:\n")
-	for _, d := range got {
-		fmt.Fprintf(msg, "  %v: %s\n", d.Span, d.Message)
-	}
-	return msg.String()
-}
-
 func (r *runner) Completion(t *testing.T, data tests.Completions, snippets tests.CompletionSnippets, items tests.CompletionItems) {
 	defer func() {
 		r.server.useDeepCompletions = false
diff --git a/internal/lsp/protocol/span.go b/internal/lsp/protocol/span.go
index d0e4e84..feb25b0 100644
--- a/internal/lsp/protocol/span.go
+++ b/internal/lsp/protocol/span.go
@@ -7,6 +7,7 @@
 package protocol
 
 import (
+	"fmt"
 	"go/token"
 
 	"golang.org/x/tools/internal/span"
@@ -108,3 +109,34 @@
 	lineStart := span.NewPoint(line, 1, offset)
 	return span.FromUTF16Column(lineStart, int(p.Character)+1, m.Content)
 }
+
+func IsPoint(r Range) bool {
+	return r.Start.Line == r.End.Line && r.Start.Character == r.End.Character
+}
+
+func CompareRange(a, b Range) int {
+	if r := ComparePosition(a.Start, b.Start); r != 0 {
+		return r
+	}
+	return ComparePosition(a.End, b.End)
+}
+
+func ComparePosition(a, b Position) int {
+	if a.Line < b.Line {
+		return -1
+	}
+	if a.Line > b.Line {
+		return 1
+	}
+	if a.Character < b.Character {
+		return -1
+	}
+	if a.Character > b.Character {
+		return 1
+	}
+	return 0
+}
+
+func (r Range) Format(f fmt.State, _ rune) {
+	fmt.Fprintf(f, "%v:%v-%v:%v", r.Start.Line, r.Start.Character, r.End.Line, r.End.Character)
+}
diff --git a/internal/lsp/source/diagnostics.go b/internal/lsp/source/diagnostics.go
index 92f04f1..442ae78 100644
--- a/internal/lsp/source/diagnostics.go
+++ b/internal/lsp/source/diagnostics.go
@@ -8,6 +8,7 @@
 	"bytes"
 	"context"
 	"fmt"
+	"go/ast"
 	"strings"
 
 	"golang.org/x/tools/go/analysis"
@@ -34,14 +35,17 @@
 	"golang.org/x/tools/go/analysis/passes/unsafeptr"
 	"golang.org/x/tools/go/analysis/passes/unusedresult"
 	"golang.org/x/tools/go/packages"
+	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/tools/internal/lsp/telemetry"
 	"golang.org/x/tools/internal/span"
 	"golang.org/x/tools/internal/telemetry/log"
 	"golang.org/x/tools/internal/telemetry/trace"
+	errors "golang.org/x/xerrors"
 )
 
 type Diagnostic struct {
-	span.Span
+	URI      span.URI
+	Range    protocol.Range
 	Message  string
 	Source   string
 	Severity DiagnosticSeverity
@@ -111,7 +115,7 @@
 }
 
 type diagnosticSet struct {
-	listErrors, parseErrors, typeErrors []Diagnostic
+	listErrors, parseErrors, typeErrors []*Diagnostic
 }
 
 func diagnostics(ctx context.Context, view View, pkg Package, reports map[span.URI][]Diagnostic) bool {
@@ -120,28 +124,32 @@
 
 	diagSets := make(map[span.URI]*diagnosticSet)
 	for _, err := range pkg.GetErrors() {
-		diag := Diagnostic{
-			Span:     packagesErrorSpan(err),
+		spn := packagesErrorSpan(err)
+		diag := &Diagnostic{
+			URI:      spn.URI(),
 			Message:  err.Msg,
 			Source:   "LSP",
 			Severity: SeverityError,
 		}
-		set, ok := diagSets[diag.Span.URI()]
+		set, ok := diagSets[diag.URI]
 		if !ok {
 			set = &diagnosticSet{}
-			diagSets[diag.Span.URI()] = set
+			diagSets[diag.URI] = set
 		}
 		switch err.Kind {
 		case packages.ParseError:
 			set.parseErrors = append(set.parseErrors, diag)
 		case packages.TypeError:
-			if diag.Span.IsPoint() {
-				diag.Span = pointToSpan(ctx, view, diag.Span)
-			}
 			set.typeErrors = append(set.typeErrors, diag)
 		default:
 			set.listErrors = append(set.listErrors, diag)
 		}
+		rng, err := spanToRange(ctx, view, pkg, spn, err.Kind == packages.TypeError)
+		if err != nil {
+			log.Error(ctx, "failed to convert span to range", err)
+			continue
+		}
+		diag.Range = rng
 	}
 	var nonEmptyDiagnostics bool // track if we actually send non-empty diagnostics
 	for uri, set := range diagSets {
@@ -157,21 +165,64 @@
 		}
 		for _, diag := range diags {
 			if _, ok := reports[uri]; ok {
-				reports[uri] = append(reports[uri], diag)
+				reports[uri] = append(reports[uri], *diag)
 			}
 		}
 	}
 	return nonEmptyDiagnostics
 }
 
-func analyses(ctx context.Context, v View, cph CheckPackageHandle, disabledAnalyses map[string]struct{}, reports map[span.URI][]Diagnostic) error {
+// spanToRange converts a span.Span to a protocol.Range,
+// assuming that the span belongs to the package whose diagnostics are being computed.
+func spanToRange(ctx context.Context, view View, pkg Package, spn span.Span, isTypeError bool) (protocol.Range, error) {
+	var (
+		fh   FileHandle
+		file *ast.File
+		err  error
+	)
+	for _, ph := range pkg.GetHandles() {
+		if ph.File().Identity().URI == spn.URI() {
+			fh = ph.File()
+			file, err = ph.Cached(ctx)
+		}
+	}
+	if file == nil {
+		return protocol.Range{}, err
+	}
+	fset := view.Session().Cache().FileSet()
+	tok := fset.File(file.Pos())
+	if tok == nil {
+		return protocol.Range{}, errors.Errorf("no token.File for %s", spn.URI())
+	}
+	data, _, err := fh.Read(ctx)
+	if err != nil {
+		return protocol.Range{}, err
+	}
+	uri := fh.Identity().URI
+	m := protocol.NewColumnMapper(uri, uri.Filename(), fset, tok, data)
+
+	// Try to get a range for the diagnostic.
+	// TODO: Don't just limit ranges to type errors.
+	if spn.IsPoint() && isTypeError {
+		if s, err := spn.WithOffset(m.Converter); err == nil {
+			start := s.Start()
+			offset := start.Offset()
+			if width := bytes.IndexAny(data[offset:], " \n,():;[]"); width > 0 {
+				spn = span.New(spn.URI(), start, span.NewPoint(start.Line(), start.Column()+width, offset+width))
+			}
+		}
+	}
+	return m.Range(spn)
+}
+
+func analyses(ctx context.Context, view View, cph CheckPackageHandle, disabledAnalyses map[string]struct{}, reports map[span.URI][]Diagnostic) error {
 	// Type checking and parsing succeeded. Run analyses.
-	if err := runAnalyses(ctx, v, cph, disabledAnalyses, func(a *analysis.Analyzer, diag analysis.Diagnostic) error {
-		diagnostic, err := toDiagnostic(a, v, diag)
+	if err := runAnalyses(ctx, view, cph, disabledAnalyses, func(a *analysis.Analyzer, diag analysis.Diagnostic) error {
+		diagnostic, err := toDiagnostic(ctx, view, diag, a.Name)
 		if err != nil {
 			return err
 		}
-		addReport(v, reports, diagnostic.Span.URI(), diagnostic)
+		addReport(view, reports, diagnostic.URI, diagnostic)
 		return nil
 	}); err != nil {
 		return err
@@ -179,24 +230,42 @@
 	return nil
 }
 
-func toDiagnostic(a *analysis.Analyzer, v View, diag analysis.Diagnostic) (Diagnostic, error) {
-	r := span.NewRange(v.Session().Cache().FileSet(), diag.Pos, diag.End)
-	s, err := r.Span()
+func toDiagnostic(ctx context.Context, view View, diag analysis.Diagnostic, category string) (Diagnostic, error) {
+	r := span.NewRange(view.Session().Cache().FileSet(), diag.Pos, diag.End)
+	spn, err := r.Span()
 	if err != nil {
 		// The diagnostic has an invalid position, so we don't have a valid span.
 		return Diagnostic{}, err
 	}
-	category := a.Name
 	if diag.Category != "" {
 		category += "." + category
 	}
-	ca, err := getCodeActions(v.Session().Cache().FileSet(), diag)
+	ca, err := getCodeActions(view.Session().Cache().FileSet(), diag)
+	if err != nil {
+		return Diagnostic{}, err
+	}
+	f, err := view.GetFile(ctx, spn.URI())
+	if err != nil {
+		return Diagnostic{}, err
+	}
+	gof, ok := f.(GoFile)
+	if !ok {
+		return Diagnostic{}, errors.Errorf("%s is not a Go file", f.URI())
+	}
+	// If the package has changed since these diagnostics were computed,
+	// this may be incorrect. Should the package be associated with the diagnostic?
+	pkg, err := gof.GetCachedPackage(ctx)
+	if err != nil {
+		return Diagnostic{}, err
+	}
+	rng, err := spanToRange(ctx, view, pkg, spn, false)
 	if err != nil {
 		return Diagnostic{}, err
 	}
 	return Diagnostic{
+		URI:            spn.URI(),
+		Range:          rng,
 		Source:         category,
-		Span:           s,
 		Message:        diag.Message,
 		Severity:       SeverityWarning,
 		SuggestedFixes: ca,
@@ -214,7 +283,9 @@
 	if v.Ignore(uri) {
 		return
 	}
-	reports[uri] = append(reports[uri], diagnostic)
+	if _, ok := reports[uri]; ok {
+		reports[uri] = append(reports[uri], diagnostic)
+	}
 }
 
 func packagesErrorSpan(err packages.Error) span.Span {
@@ -241,49 +312,12 @@
 	return span.Parse(input[:msgIndex])
 }
 
-func pointToSpan(ctx context.Context, view View, spn span.Span) span.Span {
-	f, err := view.GetFile(ctx, spn.URI())
-	ctx = telemetry.File.With(ctx, spn.URI())
-	if err != nil {
-		log.Error(ctx, "could not find file for diagnostic", nil, telemetry.File)
-		return spn
-	}
-	diagFile, ok := f.(GoFile)
-	if !ok {
-		log.Error(ctx, "not a Go file", nil, telemetry.File)
-		return spn
-	}
-	tok, err := diagFile.GetToken(ctx)
-	if err != nil {
-		log.Error(ctx, "could not find token.File for diagnostic", err, telemetry.File)
-		return spn
-	}
-	data, _, err := diagFile.Handle(ctx).Read(ctx)
-	if err != nil {
-		log.Error(ctx, "could not find content for diagnostic", err, telemetry.File)
-		return spn
-	}
-	c := span.NewTokenConverter(diagFile.FileSet(), tok)
-	s, err := spn.WithOffset(c)
-	//we just don't bother producing an error if this failed
-	if err != nil {
-		log.Error(ctx, "invalid span for diagnostic", err, telemetry.File)
-		return spn
-	}
-	start := s.Start()
-	offset := start.Offset()
-	width := bytes.IndexAny(data[offset:], " \n,():;[]")
-	if width <= 0 {
-		return spn
-	}
-	return span.New(spn.URI(), start, span.NewPoint(start.Line(), start.Column()+width, offset+width))
-}
-
 func singleDiagnostic(uri span.URI, format string, a ...interface{}) map[span.URI][]Diagnostic {
 	return map[span.URI][]Diagnostic{
 		uri: []Diagnostic{{
 			Source:   "LSP",
-			Span:     span.New(uri, span.Point{}, span.Point{}),
+			URI:      uri,
+			Range:    protocol.Range{},
 			Message:  fmt.Sprintf(format, a...),
 			Severity: SeverityError,
 		}},
@@ -316,7 +350,7 @@
 	unusedresult.Analyzer,
 }
 
-func runAnalyses(ctx context.Context, v View, cph CheckPackageHandle, disabledAnalyses map[string]struct{}, report func(a *analysis.Analyzer, diag analysis.Diagnostic) error) error {
+func runAnalyses(ctx context.Context, view View, cph CheckPackageHandle, disabledAnalyses map[string]struct{}, report func(a *analysis.Analyzer, diag analysis.Diagnostic) error) error {
 	var analyzers []*analysis.Analyzer
 	for _, a := range Analyzers {
 		if _, ok := disabledAnalyses[a.Name]; ok {
@@ -325,7 +359,7 @@
 		analyzers = append(analyzers, a)
 	}
 
-	roots, err := analyze(ctx, v, []CheckPackageHandle{cph}, analyzers)
+	roots, err := analyze(ctx, view, []CheckPackageHandle{cph}, analyzers)
 	if err != nil {
 		return err
 	}
@@ -342,7 +376,7 @@
 			if err := report(r.Analyzer, diag); err != nil {
 				return err
 			}
-			sdiag, err := toDiagnostic(r.Analyzer, v, diag)
+			sdiag, err := toDiagnostic(ctx, view, diag, r.Analyzer.Name)
 			if err != nil {
 				return err
 			}
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index d32a2d8..a53ed7e 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -70,76 +70,12 @@
 			}
 			continue
 		}
-		if diff := diffDiagnostics(uri, want, got); diff != "" {
+		if diff := tests.DiffDiagnostics(uri, want, got); diff != "" {
 			t.Error(diff)
 		}
 	}
 }
 
-func sortDiagnostics(d []source.Diagnostic) {
-	sort.Slice(d, func(i int, j int) bool {
-		if r := span.Compare(d[i].Span, d[j].Span); r != 0 {
-			return r < 0
-		}
-		return d[i].Message < d[j].Message
-	})
-}
-
-// diffDiagnostics prints the diff between expected and actual diagnostics test
-// results.
-func diffDiagnostics(uri span.URI, want, got []source.Diagnostic) string {
-	sortDiagnostics(want)
-	sortDiagnostics(got)
-	if len(got) != len(want) {
-		return summarizeDiagnostics(-1, want, got, "different lengths got %v want %v", len(got), len(want))
-	}
-	for i, w := range want {
-		g := got[i]
-		if w.Message != g.Message {
-			return summarizeDiagnostics(i, want, got, "incorrect Message got %v want %v", g.Message, w.Message)
-		}
-		if span.ComparePoint(w.Start(), g.Start()) != 0 {
-			return summarizeDiagnostics(i, want, got, "incorrect Start got %v want %v", g.Start(), w.Start())
-		}
-		// Special case for diagnostics on parse errors.
-		if strings.Contains(string(uri), "noparse") {
-			if span.ComparePoint(g.Start(), g.End()) != 0 || span.ComparePoint(w.Start(), g.End()) != 0 {
-				return summarizeDiagnostics(i, want, got, "incorrect End got %v want %v", g.End(), w.Start())
-			}
-		} else if !g.IsPoint() { // Accept any 'want' range if the diagnostic returns a zero-length range.
-			if span.ComparePoint(w.End(), g.End()) != 0 {
-				return summarizeDiagnostics(i, want, got, "incorrect End got %v want %v", g.End(), w.End())
-			}
-		}
-		if w.Severity != g.Severity {
-			return summarizeDiagnostics(i, want, got, "incorrect Severity got %v want %v", g.Severity, w.Severity)
-		}
-		if w.Source != g.Source {
-			return summarizeDiagnostics(i, want, got, "incorrect Source got %v want %v", g.Source, w.Source)
-		}
-	}
-	return ""
-}
-
-func summarizeDiagnostics(i int, want []source.Diagnostic, got []source.Diagnostic, reason string, args ...interface{}) string {
-	msg := &bytes.Buffer{}
-	fmt.Fprint(msg, "diagnostics failed")
-	if i >= 0 {
-		fmt.Fprintf(msg, " at %d", i)
-	}
-	fmt.Fprint(msg, " because of ")
-	fmt.Fprintf(msg, reason, args...)
-	fmt.Fprint(msg, ":\nexpected:\n")
-	for _, d := range want {
-		fmt.Fprintf(msg, "  %v\n", d)
-	}
-	fmt.Fprintf(msg, "got:\n")
-	for _, d := range got {
-		fmt.Fprintf(msg, "  %v\n", d)
-	}
-	return msg.String()
-}
-
 func (r *runner) Completion(t *testing.T, data tests.Completions, snippets tests.CompletionSnippets, items tests.CompletionItems) {
 	ctx := r.ctx
 	for src, itemList := range data {
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 2207604..2233ad9 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -5,8 +5,10 @@
 package tests
 
 import (
+	"bytes"
 	"context"
 	"flag"
+	"fmt"
 	"go/ast"
 	"go/token"
 	"io/ioutil"
@@ -18,6 +20,7 @@
 	"golang.org/x/tools/go/expect"
 	"golang.org/x/tools/go/packages"
 	"golang.org/x/tools/go/packages/packagestest"
+	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/tools/internal/lsp/source"
 	"golang.org/x/tools/internal/span"
 	"golang.org/x/tools/internal/txtar"
@@ -424,8 +427,20 @@
 	if strings.Contains(string(spn.URI()), "analyzer") {
 		severity = source.SeverityWarning
 	}
+	// This is not the correct way to do this,
+	// but it seems excessive to do the full conversion here.
 	want := source.Diagnostic{
-		Span:     spn,
+		URI: spn.URI(),
+		Range: protocol.Range{
+			Start: protocol.Position{
+				Line:      float64(spn.Start().Line()) - 1,
+				Character: float64(spn.Start().Column()) - 1,
+			},
+			End: protocol.Position{
+				Line:      float64(spn.End().Line()) - 1,
+				Character: float64(spn.End().Column()) - 1,
+			},
+		},
 		Severity: severity,
 		Source:   msgSource,
 		Message:  msg,
@@ -433,6 +448,85 @@
 	data.Diagnostics[spn.URI()] = append(data.Diagnostics[spn.URI()], want)
 }
 
+// diffDiagnostics prints the diff between expected and actual diagnostics test
+// results.
+func DiffDiagnostics(uri span.URI, want, got []source.Diagnostic) string {
+	sortDiagnostics(want)
+	sortDiagnostics(got)
+
+	if len(got) != len(want) {
+		return summarizeDiagnostics(-1, want, got, "different lengths got %v want %v", len(got), len(want))
+	}
+	for i, w := range want {
+		g := got[i]
+		if w.Message != g.Message {
+			return summarizeDiagnostics(i, want, got, "incorrect Message got %v want %v", g.Message, w.Message)
+		}
+		if protocol.ComparePosition(w.Range.Start, g.Range.Start) != 0 {
+			return summarizeDiagnostics(i, want, got, "incorrect Start got %v want %v", g.Range.Start, w.Range.Start)
+		}
+		// Special case for diagnostics on parse errors.
+		if strings.Contains(string(uri), "noparse") {
+			if protocol.ComparePosition(g.Range.Start, g.Range.End) != 0 || protocol.ComparePosition(w.Range.Start, g.Range.End) != 0 {
+				return summarizeDiagnostics(i, want, got, "incorrect End got %v want %v", g.Range.End, w.Range.Start)
+			}
+		} else if !protocol.IsPoint(g.Range) { // Accept any 'want' range if the diagnostic returns a zero-length range.
+			if protocol.ComparePosition(w.Range.End, g.Range.End) != 0 {
+				return summarizeDiagnostics(i, want, got, "incorrect End got %v want %v", g.Range.End, w.Range.End)
+			}
+		}
+		if w.Severity != g.Severity {
+			return summarizeDiagnostics(i, want, got, "incorrect Severity got %v want %v", g.Severity, w.Severity)
+		}
+		if w.Source != g.Source {
+			return summarizeDiagnostics(i, want, got, "incorrect Source got %v want %v", g.Source, w.Source)
+		}
+	}
+	return ""
+}
+
+func sortDiagnostics(d []source.Diagnostic) {
+	sort.Slice(d, func(i int, j int) bool {
+		return compareDiagnostic(d[i], d[j]) < 0
+	})
+}
+
+func compareDiagnostic(a, b source.Diagnostic) int {
+	if r := span.CompareURI(a.URI, b.URI); r != 0 {
+		return r
+	}
+	if r := protocol.CompareRange(a.Range, b.Range); r != 0 {
+		return r
+	}
+	if a.Message < b.Message {
+		return -1
+	}
+	if a.Message == b.Message {
+		return 0
+	} else {
+		return 1
+	}
+}
+
+func summarizeDiagnostics(i int, want []source.Diagnostic, got []source.Diagnostic, reason string, args ...interface{}) string {
+	msg := &bytes.Buffer{}
+	fmt.Fprint(msg, "diagnostics failed")
+	if i >= 0 {
+		fmt.Fprintf(msg, " at %d", i)
+	}
+	fmt.Fprint(msg, " because of ")
+	fmt.Fprintf(msg, reason, args...)
+	fmt.Fprint(msg, ":\nexpected:\n")
+	for _, d := range want {
+		fmt.Fprintf(msg, "  %s:%v: %s\n", d.URI, d.Range, d.Message)
+	}
+	fmt.Fprintf(msg, "got:\n")
+	for _, d := range got {
+		fmt.Fprintf(msg, "  %s:%v: %s\n", d.URI, d.Range, d.Message)
+	}
+	return msg.String()
+}
+
 func (data *Data) collectCompletions(src span.Span, expected []token.Pos) {
 	data.Completions[src] = expected
 }