internal/lsp/source: show references to interface methods

This change adds ability to show references to interface methods.
Instead of just showing direct references to a type, we also show
references to the type through an interface now.

Change-Id: I9d313b3b77c75adb9971dc56ee86caa697d03c90
Reviewed-on: https://go-review.googlesource.com/c/tools/+/259998
Trust: Danish Dua <danishdua@google.com>
Run-TryBot: Danish Dua <danishdua@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
diff --git a/internal/lsp/source/references.go b/internal/lsp/source/references.go
index cd8a34c..32db1d0 100644
--- a/internal/lsp/source/references.go
+++ b/internal/lsp/source/references.go
@@ -14,7 +14,7 @@
 	"golang.org/x/tools/internal/event"
 	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/tools/internal/span"
-	"golang.org/x/xerrors"
+	errors "golang.org/x/xerrors"
 )
 
 // ReferenceInfo holds information about reference to an identifier in Go source.
@@ -35,16 +35,18 @@
 
 	qualifiedObjs, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp)
 	// Don't return references for builtin types.
-	if xerrors.Is(err, errBuiltin) {
+	if errors.Is(err, errBuiltin) {
 		return nil, nil
 	}
 	if err != nil {
 		return nil, err
 	}
-	refs, err := references(ctx, s, qualifiedObjs, includeDeclaration)
+
+	refs, err := references(ctx, s, qualifiedObjs, includeDeclaration, true)
 	if err != nil {
 		return nil, err
 	}
+
 	toSort := refs
 	if includeDeclaration {
 		toSort = refs[1:]
@@ -60,32 +62,33 @@
 }
 
 // references is a helper function to avoid recomputing qualifiedObjsAtProtocolPos.
-func references(ctx context.Context, snapshot Snapshot, qos []qualifiedObject, includeDeclaration bool) ([]*ReferenceInfo, error) {
+func references(ctx context.Context, snapshot Snapshot, qos []qualifiedObject, includeDeclaration, includeInterfaceRefs bool) ([]*ReferenceInfo, error) {
 	var (
 		references []*ReferenceInfo
 		seen       = make(map[token.Position]bool)
 	)
 
+	filename := snapshot.FileSet().Position(qos[0].obj.Pos()).Filename
+	pgf, err := qos[0].pkg.File(span.URIFromPath(filename))
+	if err != nil {
+		return nil, err
+	}
+	declIdent, err := findIdentifier(ctx, snapshot, qos[0].pkg, pgf.File, qos[0].obj.Pos())
+	if err != nil {
+		return nil, err
+	}
 	// Make sure declaration is the first item in the response.
 	if includeDeclaration {
-		filename := snapshot.FileSet().Position(qos[0].obj.Pos()).Filename
-		pgf, err := qos[0].pkg.File(span.URIFromPath(filename))
-		if err != nil {
-			return nil, err
-		}
-		ident, err := findIdentifier(ctx, snapshot, qos[0].pkg, pgf.File, qos[0].obj.Pos())
-		if err != nil {
-			return nil, err
-		}
 		references = append(references, &ReferenceInfo{
-			MappedRange:   ident.MappedRange,
+			MappedRange:   declIdent.MappedRange,
 			Name:          qos[0].obj.Name(),
-			ident:         ident.ident,
+			ident:         declIdent.ident,
 			obj:           qos[0].obj,
-			pkg:           ident.pkg,
+			pkg:           declIdent.pkg,
 			isDeclaration: true,
 		})
 	}
+
 	for _, qo := range qos {
 		var searchPkgs []Package
 
@@ -123,5 +126,44 @@
 			}
 		}
 	}
+
+	if includeInterfaceRefs {
+		declRange, err := declIdent.Range()
+		if err != nil {
+			return nil, err
+		}
+		fh, err := snapshot.GetFile(ctx, declIdent.URI())
+		if err != nil {
+			return nil, err
+		}
+		interfaceRefs, err := interfaceReferences(ctx, snapshot, fh, declRange.Start)
+		if err != nil {
+			return nil, err
+		}
+		references = append(references, interfaceRefs...)
+	}
+
 	return references, nil
 }
+
+// interfaceReferences returns the references to the interfaces implemeneted by
+// the type or method at the given position.
+func interfaceReferences(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]*ReferenceInfo, error) {
+	implementations, err := implementations(ctx, s, f, pp)
+	if err != nil {
+		if errors.Is(err, ErrNotAType) {
+			return nil, nil
+		}
+		return nil, err
+	}
+
+	var refs []*ReferenceInfo
+	for _, impl := range implementations {
+		implRefs, err := references(ctx, s, []qualifiedObject{impl}, false, false)
+		if err != nil {
+			return nil, err
+		}
+		refs = append(refs, implRefs...)
+	}
+	return refs, nil
+}
diff --git a/internal/lsp/source/rename.go b/internal/lsp/source/rename.go
index 23e7425..7fdcc70 100644
--- a/internal/lsp/source/rename.go
+++ b/internal/lsp/source/rename.go
@@ -91,7 +91,7 @@
 	if pkg == nil || pkg.IsIllTyped() {
 		return nil, errors.Errorf("package for %s is ill typed", f.URI())
 	}
-	refs, err := references(ctx, s, qos, true)
+	refs, err := references(ctx, s, qos, true, false)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/lsp/testdata/references/interfaces/interfaces.go b/internal/lsp/testdata/references/interfaces/interfaces.go
new file mode 100644
index 0000000..6661dcc
--- /dev/null
+++ b/internal/lsp/testdata/references/interfaces/interfaces.go
@@ -0,0 +1,34 @@
+package interfaces
+
+type first interface {
+	common() //@mark(firCommon, "common"),refs("common", firCommon, xCommon, zCommon)
+	firstMethod() //@mark(firMethod, "firstMethod"),refs("firstMethod", firMethod, xfMethod, zfMethod)
+}
+
+type second interface {
+	common() //@mark(secCommon, "common"),refs("common", secCommon, yCommon, zCommon)
+	secondMethod() //@mark(secMethod, "secondMethod"),refs("secondMethod", secMethod, ysMethod, zsMethod)
+}
+
+type s struct {}
+
+func (*s) common() {} //@mark(sCommon, "common"),refs("common", sCommon, xCommon, yCommon, zCommon)
+
+func (*s) firstMethod() {} //@mark(sfMethod, "firstMethod"),refs("firstMethod", sfMethod, xfMethod, zfMethod)
+
+func (*s) secondMethod() {} //@mark(ssMethod, "secondMethod"),refs("secondMethod", ssMethod, ysMethod, zsMethod)
+
+func main() {
+	var x first = &s{}
+	var y second = &s{}
+
+	x.common() //@mark(xCommon, "common"),refs("common", firCommon, xCommon, zCommon)
+	x.firstMethod() //@mark(xfMethod, "firstMethod"),refs("firstMethod", firMethod, xfMethod, zfMethod)
+	y.common() //@mark(yCommon, "common"),refs("common", secCommon, yCommon, zCommon)
+	y.secondMethod() //@mark(ysMethod, "secondMethod"),refs("secondMethod", secMethod, ysMethod, zsMethod)
+
+	var z *s = &s{}
+	z.firstMethod() //@mark(zfMethod, "firstMethod"),refs("firstMethod", sfMethod, xfMethod, zfMethod)
+	z.secondMethod() //@mark(zsMethod, "secondMethod"),refs("secondMethod", ssMethod, ysMethod, zsMethod)
+	z.common() //@mark(zCommon, "common"),refs("common", sCommon, xCommon, yCommon, zCommon)
+}
diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden
index d302acf..b7795a1 100644
--- a/internal/lsp/testdata/summary.txt.golden
+++ b/internal/lsp/testdata/summary.txt.golden
@@ -17,7 +17,7 @@
 DefinitionsCount = 63
 TypeDefinitionsCount = 2
 HighlightsCount = 69
-ReferencesCount = 11
+ReferencesCount = 25
 RenamesCount = 29
 PrepareRenamesCount = 7
 SymbolsCount = 5