internal/lsp: check all package variants in find-implementations

We previously only searched for implementations of the object we found
in the "widest" package variant. We instead need to search all
variants because each variant is type checked separately, and
implementations can be located in packages associated with different
variants.

For example, say you have:

-- foo/foo.go --
package foo
type Foo int
type Fooer interface { Foo() Foo }

-- foo/foo_test.go --
package foo
func TestFoo(t *testing.T) {}

-- bar/bar.go --
package bar
import "foo"
type impl struct {}
func (impl) Foo() foo.Foo { return 0 }

When you run find-implementations on the Fooer interface, we
previously would start from the (widest) foo.test's Fooer named
type. Unfortunately bar imports foo, not foo.test, so bar.impl
does not implement foo.test.Fooer. The specific reason is that
bar.impl.Foo returns foo.Foo, whereas foo.test.Fooer.Foo returns
foo.test.Foo, which are distinct *types.Named objects.

Starting our search instead from foo.Fooer resolves this issue.
However, we also need to search from foo.test.Fooer so we match any
implementations in foo_test.go.

Change-Id: I0b0039c98925410751c8f643c8ebd185340e409f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/210459
Run-TryBot: Muir Manders <muir@mnd.rs>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/definition.go b/internal/lsp/definition.go
index 78aae36..ab2b442 100644
--- a/internal/lsp/definition.go
+++ b/internal/lsp/definition.go
@@ -23,7 +23,7 @@
 	if err != nil {
 		return nil, err
 	}
-	ident, err := source.Identifier(ctx, snapshot, f, params.Position)
+	ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle)
 	if err != nil {
 		return nil, err
 	}
@@ -50,7 +50,7 @@
 	if err != nil {
 		return nil, err
 	}
-	ident, err := source.Identifier(ctx, snapshot, f, params.Position)
+	ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/lsp/hover.go b/internal/lsp/hover.go
index 7334d92..b8ab702 100644
--- a/internal/lsp/hover.go
+++ b/internal/lsp/hover.go
@@ -26,7 +26,7 @@
 	if err != nil {
 		return nil, err
 	}
-	ident, err := source.Identifier(ctx, snapshot, f, params.Position)
+	ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle)
 	if err != nil {
 		return nil, nil
 	}
diff --git a/internal/lsp/implementation.go b/internal/lsp/implementation.go
index 9af0f70..010be7d 100644
--- a/internal/lsp/implementation.go
+++ b/internal/lsp/implementation.go
@@ -9,7 +9,9 @@
 
 	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/tools/internal/lsp/source"
+	"golang.org/x/tools/internal/lsp/telemetry"
 	"golang.org/x/tools/internal/span"
+	"golang.org/x/tools/internal/telemetry/log"
 )
 
 func (s *Server) implementation(ctx context.Context, params *protocol.ImplementationParams) ([]protocol.Location, error) {
@@ -23,9 +25,45 @@
 	if err != nil {
 		return nil, err
 	}
-	ident, err := source.Identifier(ctx, snapshot, f, params.Position)
+
+	phs, err := snapshot.PackageHandles(ctx, snapshot.Handle(ctx, f))
 	if err != nil {
 		return nil, err
 	}
-	return ident.Implementation(ctx)
+
+	var (
+		allLocs []protocol.Location
+		seen    = make(map[protocol.Location]bool)
+	)
+	for _, ph := range phs {
+		ctx := telemetry.Package.With(ctx, ph.ID())
+
+		ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.SpecificPackageHandle(ph.ID()))
+		if err != nil {
+			if err == source.ErrNoIdentFound {
+				return nil, err
+			}
+			log.Error(ctx, "failed to find Identifer", err)
+			continue
+		}
+
+		locs, err := ident.Implementation(ctx)
+		if err != nil {
+			if err == source.ErrNotAMethod {
+				return nil, err
+			}
+			log.Error(ctx, "failed to find Implemenation", err)
+			continue
+		}
+
+		for _, loc := range locs {
+			if seen[loc] {
+				continue
+			}
+			seen[loc] = true
+			allLocs = append(allLocs, loc)
+		}
+	}
+
+	return allLocs, nil
 }
diff --git a/internal/lsp/references.go b/internal/lsp/references.go
index 2dc37f2..ea6754b 100644
--- a/internal/lsp/references.go
+++ b/internal/lsp/references.go
@@ -26,7 +26,7 @@
 		return nil, err
 	}
 	// Find all references to the identifier at the position.
-	ident, err := source.Identifier(ctx, snapshot, f, params.Position)
+	ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/lsp/rename.go b/internal/lsp/rename.go
index 1db1abd..18f9c79 100644
--- a/internal/lsp/rename.go
+++ b/internal/lsp/rename.go
@@ -23,7 +23,7 @@
 	if err != nil {
 		return nil, err
 	}
-	ident, err := source.Identifier(ctx, snapshot, f, params.Position)
+	ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle)
 	if err != nil {
 		return nil, err
 	}
@@ -56,7 +56,7 @@
 	if err != nil {
 		return nil, err
 	}
-	ident, err := source.Identifier(ctx, snapshot, f, params.Position)
+	ident, err := source.Identifier(ctx, snapshot, f, params.Position, source.WidestCheckPackageHandle)
 	if err != nil {
 		return nil, nil // ignore errors
 	}
diff --git a/internal/lsp/source/identifier.go b/internal/lsp/source/identifier.go
index a942f5e..26d8183 100644
--- a/internal/lsp/source/identifier.go
+++ b/internal/lsp/source/identifier.go
@@ -58,11 +58,11 @@
 
 // Identifier returns identifier information for a position
 // in a file, accounting for a potentially incomplete selector.
-func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Position) (*IdentifierInfo, error) {
+func Identifier(ctx context.Context, snapshot Snapshot, f File, pos protocol.Position, selectPackage PackagePolicy) (*IdentifierInfo, error) {
 	ctx, done := trace.StartSpan(ctx, "source.Identifier")
 	defer done()
 
-	pkg, pgh, err := getParsedFile(ctx, snapshot, f, WidestCheckPackageHandle)
+	pkg, pgh, err := getParsedFile(ctx, snapshot, f, selectPackage)
 	if err != nil {
 		return nil, fmt.Errorf("getting file for Identifier: %v", err)
 	}
@@ -81,6 +81,8 @@
 	return findIdentifier(snapshot, pkg, file, rng.Start)
 }
 
+var ErrNoIdentFound = errors.New("no identifier found")
+
 func findIdentifier(snapshot Snapshot, pkg Package, file *ast.File, pos token.Pos) (*IdentifierInfo, error) {
 	if result, err := identifier(snapshot, pkg, file, pos); err != nil || result != nil {
 		return result, err
@@ -90,7 +92,7 @@
 	// requesting a completion), use the path to the preceding node.
 	ident, err := identifier(snapshot, pkg, file, pos-1)
 	if ident == nil && err == nil {
-		err = errors.New("no identifier found")
+		err = ErrNoIdentFound
 	}
 	return ident, err
 }
diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go
index ed8b163..e1578cd 100644
--- a/internal/lsp/source/implementation.go
+++ b/internal/lsp/source/implementation.go
@@ -19,6 +19,7 @@
 	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/tools/internal/lsp/telemetry"
 	"golang.org/x/tools/internal/telemetry/log"
+	errors "golang.org/x/xerrors"
 )
 
 func (i *IdentifierInfo) Implementation(ctx context.Context) ([]protocol.Location, error) {
@@ -101,6 +102,8 @@
 	return locations, nil
 }
 
+var ErrNotAMethod = errors.New("this function is not a method")
+
 func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, error) {
 	var T types.Type
 	var method *types.Func
@@ -112,7 +115,7 @@
 		}
 		recv := obj.Type().(*types.Signature).Recv()
 		if recv == nil {
-			return implementsResult{}, fmt.Errorf("this function is not a method")
+			return implementsResult{}, ErrNotAMethod
 		}
 		method = obj
 		T = recv.Type()
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index bdb3c66..f30e291 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -499,7 +499,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start)
+	ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle)
 	if err != nil {
 		t.Fatalf("failed for %v: %v", d.Src, err)
 	}
@@ -562,7 +562,7 @@
 	if err != nil {
 		t.Fatalf("failed for %v: %v", spn, err)
 	}
-	ident, err := source.Identifier(ctx, r.view.Snapshot(), f, loc.Range.Start)
+	ident, err := source.Identifier(ctx, r.view.Snapshot(), f, loc.Range.Start, source.WidestCheckPackageHandle)
 	if err != nil {
 		t.Fatalf("failed for %v: %v", spn, err)
 	}
@@ -649,7 +649,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start)
+	ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle)
 	if err != nil {
 		t.Fatalf("failed for %v: %v", src, err)
 	}
@@ -693,7 +693,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	ident, err := source.Identifier(r.ctx, r.view.Snapshot(), f, srcRng.Start)
+	ident, err := source.Identifier(r.ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle)
 	if err != nil {
 		t.Error(err)
 		return
@@ -782,7 +782,7 @@
 		t.Fatal(err)
 	}
 	// Find the identifier at the position.
-	ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start)
+	ident, err := source.Identifier(ctx, r.view.Snapshot(), f, srcRng.Start, source.WidestCheckPackageHandle)
 	if err != nil {
 		if want.Text != "" { // expected an ident.
 			t.Errorf("prepare rename failed for %v: got error: %v", src, err)
diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go
index 191e529..ada8444 100644
--- a/internal/lsp/source/util.go
+++ b/internal/lsp/source/util.go
@@ -67,7 +67,7 @@
 
 // getParsedFile is a convenience function that extracts the Package and ParseGoHandle for a File in a Snapshot.
 // selectPackage is typically Narrowest/WidestCheckPackageHandle below.
-func getParsedFile(ctx context.Context, snapshot Snapshot, f File, selectPackage func([]PackageHandle) (PackageHandle, error)) (Package, ParseGoHandle, error) {
+func getParsedFile(ctx context.Context, snapshot Snapshot, f File, selectPackage PackagePolicy) (Package, ParseGoHandle, error) {
 	fh := snapshot.Handle(ctx, f)
 	phs, err := snapshot.PackageHandles(ctx, fh)
 	if err != nil {
@@ -85,6 +85,8 @@
 	return pkg, pgh, err
 }
 
+type PackagePolicy func([]PackageHandle) (PackageHandle, error)
+
 // NarrowestCheckPackageHandle picks the "narrowest" package for a given file.
 //
 // By "narrowest" package, we mean the package with the fewest number of files
@@ -126,6 +128,20 @@
 	return result, nil
 }
 
+// SpecificPackageHandle creates a PackagePolicy to select a
+// particular PackageHandle when you alread know the one you want.
+func SpecificPackageHandle(desiredID string) PackagePolicy {
+	return func(handles []PackageHandle) (PackageHandle, error) {
+		for _, h := range handles {
+			if h.ID() == desiredID {
+				return h, nil
+			}
+		}
+
+		return nil, fmt.Errorf("no package handle with expected id %q", desiredID)
+	}
+}
+
 func IsGenerated(ctx context.Context, view View, uri span.URI) bool {
 	f, err := view.GetFile(ctx, uri)
 	if err != nil {