internal/lsp: add find all references

This change implements the find all references feature by finding all of
the uses and definitions of the identifier within the current package.

Testing for references is done using "refs" in the testdata files and
marking the references in the package.

Change-Id: Ieb44b68608e940df5f65c3052eb9ec974f6fae6c
Reviewed-on: https://go-review.googlesource.com/c/tools/+/181122
Run-TryBot: Suzy Mueller <suzmue@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/cmd/cmd_test.go b/internal/lsp/cmd/cmd_test.go
index 8863145..c1a882f 100644
--- a/internal/lsp/cmd/cmd_test.go
+++ b/internal/lsp/cmd/cmd_test.go
@@ -49,6 +49,11 @@
 func (r *runner) Highlight(t *testing.T, data tests.Highlights) {
 	//TODO: add command line highlight tests when it works
 }
+
+func (r *runner) Reference(t *testing.T, data tests.References) {
+	//TODO: add command line references tests when it works
+}
+
 func (r *runner) Symbol(t *testing.T, data tests.Symbols) {
 	//TODO: add command line symbol tests when it works
 }
diff --git a/internal/lsp/general.go b/internal/lsp/general.go
index 0964de7..92e28a5 100644
--- a/internal/lsp/general.go
+++ b/internal/lsp/general.go
@@ -69,6 +69,7 @@
 			HoverProvider:              true,
 			DocumentHighlightProvider:  true,
 			DocumentLinkProvider:       &protocol.DocumentLinkOptions{},
+			ReferencesProvider:         true,
 			SignatureHelpProvider: &protocol.SignatureHelpOptions{
 				TriggerCharacters: []string{"(", ","},
 			},
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index 2548b46..a7d12b0 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -460,6 +460,48 @@
 	}
 }
 
+func (r *runner) Reference(t *testing.T, data tests.References) {
+	for src, itemList := range data {
+		sm, err := r.mapper(src.URI())
+		if err != nil {
+			t.Fatal(err)
+		}
+		loc, err := sm.Location(src)
+		if err != nil {
+			t.Fatalf("failed for %v: %v", src, err)
+		}
+
+		want := make(map[protocol.Location]bool)
+		for _, pos := range itemList {
+			loc, err := sm.Location(pos)
+			if err != nil {
+				t.Fatalf("failed for %v: %v", src, err)
+			}
+			want[loc] = true
+		}
+
+		params := &protocol.ReferenceParams{
+			TextDocumentPositionParams: protocol.TextDocumentPositionParams{
+				TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI},
+				Position:     loc.Range.Start,
+			},
+		}
+		got, err := r.server.References(context.Background(), params)
+		if err != nil {
+			t.Fatalf("failed for %v: %v", src, err)
+		}
+
+		if len(got) != len(itemList) {
+			t.Errorf("references failed: different lengths got %v want %v", len(got), len(itemList))
+		}
+		for _, loc := range got {
+			if !want[loc] {
+				t.Errorf("references failed: incorrect references got %v want %v", got, want)
+			}
+		}
+	}
+}
+
 func (r *runner) Symbol(t *testing.T, data tests.Symbols) {
 	for uri, expectedSymbols := range data {
 		params := &protocol.DocumentSymbolParams{
diff --git a/internal/lsp/references.go b/internal/lsp/references.go
new file mode 100644
index 0000000..2786244
--- /dev/null
+++ b/internal/lsp/references.go
@@ -0,0 +1,56 @@
+package lsp
+
+import (
+	"context"
+
+	"golang.org/x/tools/internal/lsp/protocol"
+	"golang.org/x/tools/internal/lsp/source"
+	"golang.org/x/tools/internal/span"
+)
+
+func (s *Server) references(ctx context.Context, params *protocol.ReferenceParams) ([]protocol.Location, error) {
+	uri := span.NewURI(params.TextDocument.URI)
+	view := s.session.ViewOf(uri)
+	f, m, err := getGoFile(ctx, view, uri)
+	if err != nil {
+		return nil, err
+	}
+	spn, err := m.PointSpan(params.Position)
+	if err != nil {
+		return nil, err
+	}
+	rng, err := spn.Range(m.Converter)
+	if err != nil {
+		return nil, err
+	}
+
+	// Find all references to the identifier at the position.
+	ident, err := source.Identifier(ctx, view, f, rng.Start)
+	if err != nil {
+		return nil, err
+	}
+	references, err := ident.References(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// Get the location of each reference to return as the result.
+	locations := make([]protocol.Location, 0, len(references))
+	for _, ref := range references {
+		refSpan, err := ref.Range.Span()
+		if err != nil {
+			return nil, err
+		}
+		_, refM, err := getSourceFile(ctx, view, refSpan.URI())
+		if err != nil {
+			return nil, err
+		}
+		loc, err := refM.Location(refSpan)
+		if err != nil {
+			return nil, err
+		}
+
+		locations = append(locations, loc)
+	}
+	return locations, nil
+}
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index 13bcaef..121cc13 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -186,8 +186,8 @@
 	return nil, notImplemented("Implementation")
 }
 
-func (s *Server) References(context.Context, *protocol.ReferenceParams) ([]protocol.Location, error) {
-	return nil, notImplemented("References")
+func (s *Server) References(ctx context.Context, params *protocol.ReferenceParams) ([]protocol.Location, error) {
+	return s.references(ctx, params)
 }
 
 func (s *Server) DocumentHighlight(ctx context.Context, params *protocol.TextDocumentPositionParams) ([]protocol.DocumentHighlight, error) {
diff --git a/internal/lsp/source/references.go b/internal/lsp/source/references.go
new file mode 100644
index 0000000..8a3ae73
--- /dev/null
+++ b/internal/lsp/source/references.go
@@ -0,0 +1,56 @@
+package source
+
+import (
+	"context"
+	"fmt"
+	"go/ast"
+
+	"golang.org/x/tools/internal/span"
+)
+
+// ReferenceInfo holds information about reference to an identifier in Go source.
+type ReferenceInfo struct {
+	Name  string
+	Range span.Range
+	ident *ast.Ident
+}
+
+// References returns a list of references for a given identifier within a package.
+func (i *IdentifierInfo) References(ctx context.Context) ([]*ReferenceInfo, error) {
+	pkg := i.File.GetPackage(ctx)
+	if pkg == nil || pkg.IsIllTyped() {
+		return nil, fmt.Errorf("package for %s is ill typed", i.File.URI())
+	}
+	pkgInfo := pkg.GetTypesInfo()
+	if pkgInfo == nil {
+		return nil, fmt.Errorf("package %s has no types info", pkg.PkgPath())
+	}
+
+	// If the object declaration is nil, assume it is an import spec and do not look for references.
+	declObj := i.decl.obj
+	if declObj == nil {
+		return []*ReferenceInfo{}, nil
+	}
+
+	var references []*ReferenceInfo
+	for ident, obj := range pkgInfo.Defs {
+		if obj == declObj {
+			references = append(references, &ReferenceInfo{
+				Name:  ident.Name,
+				Range: span.NewRange(i.File.FileSet(), ident.Pos(), ident.End()),
+				ident: ident,
+			})
+		}
+	}
+	for ident, obj := range pkgInfo.Uses {
+		if obj == declObj {
+			references = append(references, &ReferenceInfo{
+				Name:  ident.Name,
+				Range: span.NewRange(i.File.FileSet(), ident.Pos(), ident.End()),
+				ident: ident,
+			})
+		}
+	}
+
+	return references, nil
+}
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index 8135e1d..be9c0f6 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -417,6 +417,46 @@
 	}
 }
 
+func (r *runner) Reference(t *testing.T, data tests.References) {
+	ctx := context.Background()
+	for src, itemList := range data {
+		f, err := r.view.GetFile(ctx, src.URI())
+		if err != nil {
+			t.Fatalf("failed for %v: %v", src, err)
+		}
+
+		tok := f.GetToken(ctx)
+		pos := tok.Pos(src.Start().Offset())
+		ident, err := source.Identifier(ctx, r.view, f.(source.GoFile), pos)
+		if err != nil {
+			t.Fatalf("failed for %v: %v", src, err)
+		}
+
+		want := make(map[span.Span]bool)
+		for _, pos := range itemList {
+			want[pos] = true
+		}
+
+		got, err := ident.References(ctx)
+		if err != nil {
+			t.Fatalf("failed for %v: %v", src, err)
+		}
+
+		if len(got) != len(itemList) {
+			t.Errorf("references failed: different lengths got %v want %v", len(got), len(itemList))
+		}
+		for _, refInfo := range got {
+			refSpan, err := refInfo.Range.Span()
+			if err != nil {
+				t.Errorf("failed for %v item %v: %v", src, refInfo.Name, err)
+			}
+			if !want[refSpan] {
+				t.Errorf("references failed: incorrect references got %v want locations %v", got, want)
+			}
+		}
+	}
+}
+
 func (r *runner) Symbol(t *testing.T, data tests.Symbols) {
 	ctx := context.Background()
 	for uri, expectedSymbols := range data {
diff --git a/internal/lsp/testdata/foo/foo.go b/internal/lsp/testdata/foo/foo.go
index 0e33467..094623e 100644
--- a/internal/lsp/testdata/foo/foo.go
+++ b/internal/lsp/testdata/foo/foo.go
@@ -13,8 +13,8 @@
 }
 
 func _() {
-	var sFoo StructFoo           //@complete("t", StructFoo)
-	if x := sFoo; x.Value == 1 { //@complete("V", Value),typdef("sFoo", StructFoo)
+	var sFoo StructFoo           //@mark(sFoo1, "sFoo"),complete("t", StructFoo)
+	if x := sFoo; x.Value == 1 { //@mark(sFoo2, "sFoo"),complete("V", Value),typdef("sFoo", StructFoo),refs("sFo", sFoo1, sFoo2)
 		return
 	}
 }
@@ -22,7 +22,7 @@
 func _() {
 	shadowed := 123
 	{
-		shadowed := "hi" //@item(shadowed, "shadowed", "string", "var")
+		shadowed := "hi" //@item(shadowed, "shadowed", "string", "var"),refs("shadowed", shadowed)
 		sha              //@complete("a", shadowed)
 	}
 }
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index a8dc349..81f1a10 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -34,6 +34,7 @@
 	ExpectedDefinitionsCount       = 35
 	ExpectedTypeDefinitionsCount   = 2
 	ExpectedHighlightsCount        = 2
+	ExpectedReferencesCount        = 2
 	ExpectedSymbolsCount           = 1
 	ExpectedSignaturesCount        = 20
 	ExpectedLinksCount             = 2
@@ -56,6 +57,7 @@
 type Imports []span.Span
 type Definitions map[span.Span]Definition
 type Highlights map[string][]span.Span
+type References map[span.Span][]span.Span
 type Symbols map[span.URI][]source.Symbol
 type SymbolsChildren map[string][]source.Symbol
 type Signatures map[span.Span]source.SignatureInformation
@@ -72,6 +74,7 @@
 	Imports            Imports
 	Definitions        Definitions
 	Highlights         Highlights
+	References         References
 	Symbols            Symbols
 	symbolsChildren    SymbolsChildren
 	Signatures         Signatures
@@ -90,6 +93,7 @@
 	Import(*testing.T, Imports)
 	Definition(*testing.T, Definitions)
 	Highlight(*testing.T, Highlights)
+	Reference(*testing.T, References)
 	Symbol(*testing.T, Symbols)
 	SignatureHelp(*testing.T, Signatures)
 	Link(*testing.T, Links)
@@ -130,6 +134,7 @@
 		CompletionSnippets: make(CompletionSnippets),
 		Definitions:        make(Definitions),
 		Highlights:         make(Highlights),
+		References:         make(References),
 		Symbols:            make(Symbols),
 		symbolsChildren:    make(SymbolsChildren),
 		Signatures:         make(Signatures),
@@ -209,6 +214,7 @@
 		"typdef":    data.collectTypeDefinitions,
 		"hover":     data.collectHoverDefinitions,
 		"highlight": data.collectHighlights,
+		"refs":      data.collectReferences,
 		"symbol":    data.collectSymbols,
 		"signature": data.collectSignatures,
 		"snippet":   data.collectCompletionSnippets,
@@ -289,6 +295,14 @@
 		tests.Highlight(t, data.Highlights)
 	})
 
+	t.Run("References", func(t *testing.T) {
+		t.Helper()
+		if len(data.References) != ExpectedReferencesCount {
+			t.Errorf("got %v references expected %v", len(data.References), ExpectedReferencesCount)
+		}
+		tests.Reference(t, data.References)
+	})
+
 	t.Run("Symbols", func(t *testing.T) {
 		t.Helper()
 		if len(data.Symbols) != ExpectedSymbolsCount {
@@ -456,6 +470,10 @@
 	data.Highlights[name] = append(data.Highlights[name], rng)
 }
 
+func (data *Data) collectReferences(src span.Span, expected []span.Span) {
+	data.References[src] = expected
+}
+
 func (data *Data) collectSymbols(name string, spn span.Span, kind string, parentName string) {
 	sym := source.Symbol{
 		Name:          name,