internal/lsp: respect References.IncludeDeclaration setting

Previously, (*IdentifierInfo).References was returning the declaration
of the identifier among the reference results. This change alters the
behavior of this function to only ever return non-declaration
references. Declarations can be accessed through the
IdentifierInfo.Declaration field.

Fixes golang/go#36007

Change-Id: I91d82b7e6d0d51a2468d3df67f666834d2905250
Reviewed-on: https://go-review.googlesource.com/c/tools/+/210238
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/internal/lsp/cmd/test/references.go b/internal/lsp/cmd/test/references.go
index 3039f53..d48a0a9 100644
--- a/internal/lsp/cmd/test/references.go
+++ b/internal/lsp/cmd/test/references.go
@@ -27,7 +27,7 @@
 	uri := spn.URI()
 	filename := uri.Filename()
 	target := filename + fmt.Sprintf(":%v:%v", spn.Start().Line(), spn.Start().Column())
-	got, _ := r.NormalizeGoplsCmd(t, "references", target)
+	got, _ := r.NormalizeGoplsCmd(t, "references", "-d", target)
 	if expect != got {
 		t.Errorf("references failed for %s expected:\n%s\ngot:\n%s", target, expect, got)
 	}
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index b2414ad..c53a1b8 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -541,7 +541,6 @@
 	if err != nil {
 		t.Fatalf("failed for %v: %v", src, err)
 	}
-
 	want := make(map[protocol.Location]bool)
 	for _, pos := range itemList {
 		m, err := r.data.Mapper(pos.URI())
@@ -559,6 +558,9 @@
 			TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI},
 			Position:     loc.Range.Start,
 		},
+		Context: protocol.ReferenceContext{
+			IncludeDeclaration: true,
+		},
 	}
 	got, err := r.server.References(r.ctx, params)
 	if err != nil {
diff --git a/internal/lsp/references.go b/internal/lsp/references.go
index 9d831b5..2dc37f2 100644
--- a/internal/lsp/references.go
+++ b/internal/lsp/references.go
@@ -56,27 +56,18 @@
 			Range: refRange,
 		})
 	}
-	// The declaration of this identifier may not be in the
-	// scope that we search for references, so make sure
-	// it is added to the beginning of the list if IncludeDeclaration
-	// was specified.
+	// Only add the identifier's declaration if the client requests it.
 	if params.Context.IncludeDeclaration {
-		decSpan, err := ident.Declaration.Span()
+		rng, err := ident.Declaration.Range()
 		if err != nil {
 			return nil, err
 		}
-		if !seen[decSpan] {
-			rng, err := ident.Declaration.Range()
-			if err != nil {
-				return nil, err
-			}
-			locations = append([]protocol.Location{
-				{
-					URI:   protocol.NewURI(ident.Declaration.URI()),
-					Range: rng,
-				},
-			}, locations...)
-		}
+		locations = append([]protocol.Location{
+			{
+				URI:   protocol.NewURI(ident.Declaration.URI()),
+				Range: rng,
+			},
+		}, locations...)
 	}
 	return locations, nil
 }
diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go
index a0720c0..a942f5e 100644
--- a/internal/lsp/source/identifier.go
+++ b/internal/lsp/source/identifier.go
@@ -45,6 +45,17 @@
 	wasImplicit bool
 }
 
+func (i *IdentifierInfo) DeclarationReferenceInfo() *ReferenceInfo {
+	return &ReferenceInfo{
+		Name:          i.Declaration.obj.Name(),
+		mappedRange:   i.Declaration.mappedRange,
+		obj:           i.Declaration.obj,
+		ident:         i.ident,
+		pkg:           i.pkg,
+		isDeclaration: true,
+	}
+}
+
 // Identifier returns identifier information for a position
 // in a file, accounting for a potentially incomplete selector.
 func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Position) (*IdentifierInfo, error) {
diff --git a/internal/lsp/source/references.go b/internal/lsp/source/references.go
index 5539f5f..6163425 100644
--- a/internal/lsp/source/references.go
+++ b/internal/lsp/source/references.go
@@ -32,8 +32,6 @@
 	ctx, done := trace.StartSpan(ctx, "source.References")
 	defer done()
 
-	var references []*ReferenceInfo
-
 	// If the object declaration is nil, assume it is an import spec and do not look for references.
 	if i.Declaration.obj == nil {
 		return nil, errors.Errorf("no references for an import spec")
@@ -42,36 +40,6 @@
 	if info == nil {
 		return nil, errors.Errorf("package %s has no types info", i.pkg.PkgPath())
 	}
-	if i.Declaration.wasImplicit {
-		// The definition is implicit, so we must add it separately.
-		// This occurs when the variable is declared in a type switch statement
-		// or is an implicit package name. Both implicits are local to a file.
-		references = append(references, &ReferenceInfo{
-			Name:          i.Declaration.obj.Name(),
-			mappedRange:   i.Declaration.mappedRange,
-			obj:           i.Declaration.obj,
-			pkg:           i.pkg,
-			isDeclaration: true,
-		})
-	}
-	for ident, obj := range info.Defs {
-		if obj == nil || !sameObj(obj, i.Declaration.obj) {
-			continue
-		}
-		rng, err := posToMappedRange(i.Snapshot.View(), i.pkg, ident.Pos(), ident.End())
-		if err != nil {
-			return nil, err
-		}
-		// Add the declarations at the beginning of the references list.
-		references = append([]*ReferenceInfo{{
-			Name:          ident.Name,
-			ident:         ident,
-			obj:           obj,
-			pkg:           i.pkg,
-			isDeclaration: true,
-			mappedRange:   rng,
-		}}, references...)
-	}
 	var searchpkgs []Package
 	if i.Declaration.obj.Exported() {
 		// Only search all packages if the identifier is exported.
@@ -91,9 +59,11 @@
 	}
 	// Add the package in which the identifier is declared.
 	searchpkgs = append(searchpkgs, i.pkg)
+
+	var references []*ReferenceInfo
 	for _, pkg := range searchpkgs {
 		for ident, obj := range pkg.GetTypesInfo().Uses {
-			if obj == nil || !(sameObj(obj, i.Declaration.obj)) {
+			if !sameObj(obj, i.Declaration.obj) {
 				continue
 			}
 			rng, err := posToMappedRange(i.Snapshot.View(), pkg, ident.Pos(), ident.End())
@@ -117,6 +87,9 @@
 // and their objectpath and package are the same; or if they don't
 // have object paths and they have the same Pos and Name.
 func sameObj(obj, declObj types.Object) bool {
+	if obj == nil || declObj == nil {
+		return false
+	}
 	// TODO(suzmue): support the case where an identifier may have two different
 	// declaration positions.
 	if obj.Pkg() == nil || declObj.Pkg() == nil {
diff --git a/internal/lsp/source/rename.go b/internal/lsp/source/rename.go
index 29a7d8c..986a286 100644
--- a/internal/lsp/source/rename.go
+++ b/internal/lsp/source/rename.go
@@ -120,6 +120,9 @@
 		return nil, err
 	}
 
+	// Make sure to add the declaration of the identifier.
+	refs = append(refs, i.DeclarationReferenceInfo())
+
 	r := renamer{
 		ctx:          ctx,
 		fset:         i.Snapshot.View().Session().Cache().FileSet(),
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index 48dda36..de1408d 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -653,17 +653,16 @@
 	if err != nil {
 		t.Fatalf("failed for %v: %v", src, err)
 	}
-
 	want := make(map[span.Span]bool)
 	for _, pos := range itemList {
 		want[pos] = true
 	}
-
 	refs, err := ident.References(ctx)
 	if err != nil {
 		t.Fatalf("failed for %v: %v", src, err)
 	}
-
+	// Add the item's declaration, since References omits it.
+	refs = append([]*source.ReferenceInfo{ident.DeclarationReferenceInfo()}, refs...)
 	got := make(map[span.Span]bool)
 	for _, refInfo := range refs {
 		refSpan, err := refInfo.Span()
@@ -672,11 +671,9 @@
 		}
 		got[refSpan] = true
 	}
-
 	if len(got) != len(want) {
 		t.Errorf("references failed: different lengths got %v want %v", len(got), len(want))
 	}
-
 	for spn := range got {
 		if !want[spn] {
 			t.Errorf("references failed: incorrect references got %v want locations %v", got, want)