| // Copyright 2019 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package source |
| |
| import ( |
| "context" |
| "fmt" |
| "go/ast" |
| "go/token" |
| "go/types" |
| |
| "golang.org/x/tools/internal/lsp/protocol" |
| "golang.org/x/tools/internal/telemetry/trace" |
| errors "golang.org/x/xerrors" |
| ) |
| |
| func Implementation(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]protocol.Location, error) { |
| ctx, done := trace.StartSpan(ctx, "source.Implementation") |
| defer done() |
| |
| impls, err := implementations(ctx, s, f, pp) |
| if err != nil { |
| return nil, err |
| } |
| |
| var locations []protocol.Location |
| for _, impl := range impls { |
| if impl.pkg == nil || len(impl.pkg.CompiledGoFiles()) == 0 { |
| continue |
| } |
| rng, err := objToMappedRange(s.View(), impl.pkg, impl.obj) |
| if err != nil { |
| return nil, err |
| } |
| pr, err := rng.Range() |
| if err != nil { |
| return nil, err |
| } |
| locations = append(locations, protocol.Location{ |
| URI: protocol.NewURI(rng.URI()), |
| Range: pr, |
| }) |
| } |
| return locations, nil |
| } |
| |
// ErrNotAnInterface is returned by implementations when the object at the
// requested position is neither an interface type nor an interface method.
var ErrNotAnInterface = errors.New("not an interface or interface method")
| |
| func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]implementation, error) { |
| var ( |
| impls []implementation |
| seen = make(map[token.Position]bool) |
| fset = s.View().Session().Cache().FileSet() |
| ) |
| |
| objs, err := objectsAtProtocolPos(ctx, s, f, pp) |
| if err != nil { |
| return nil, err |
| } |
| |
| for _, obj := range objs { |
| var ( |
| T *types.Interface |
| method *types.Func |
| ) |
| |
| switch obj := obj.(type) { |
| case *types.Func: |
| method = obj |
| if recv := obj.Type().(*types.Signature).Recv(); recv != nil { |
| T, _ = recv.Type().Underlying().(*types.Interface) |
| } |
| case *types.TypeName: |
| T, _ = obj.Type().Underlying().(*types.Interface) |
| } |
| |
| if T == nil { |
| return nil, ErrNotAnInterface |
| } |
| |
| if T.NumMethods() == 0 { |
| return nil, nil |
| } |
| |
| // Find all named types, even local types (which can have methods |
| // due to promotion). |
| var ( |
| allNamed []*types.Named |
| pkgs = make(map[*types.Package]Package) |
| ) |
| knownPkgs, err := s.KnownPackages(ctx) |
| if err != nil { |
| return nil, err |
| } |
| for _, ph := range knownPkgs { |
| pkg, err := ph.Check(ctx) |
| if err != nil { |
| return nil, err |
| } |
| pkgs[pkg.GetTypes()] = pkg |
| info := pkg.GetTypesInfo() |
| for _, obj := range info.Defs { |
| obj, ok := obj.(*types.TypeName) |
| // We ignore aliases 'type M = N' to avoid duplicate reporting |
| // of the Named type N. |
| if !ok || obj.IsAlias() { |
| continue |
| } |
| named, ok := obj.Type().(*types.Named) |
| // We skip interface types since we only want concrete |
| // implementations. |
| if !ok || isInterface(named) { |
| continue |
| } |
| allNamed = append(allNamed, named) |
| } |
| } |
| |
| // Find all the named types that implement our interface. |
| for _, U := range allNamed { |
| var concrete types.Type = U |
| if !types.AssignableTo(concrete, T) { |
| // We also accept T if *T implements our interface. |
| concrete = types.NewPointer(concrete) |
| if !types.AssignableTo(concrete, T) { |
| continue |
| } |
| } |
| |
| var obj types.Object = U.Obj() |
| if method != nil { |
| obj = types.NewMethodSet(concrete).Lookup(method.Pkg(), method.Name()).Obj() |
| } |
| |
| pos := fset.Position(obj.Pos()) |
| if obj == method || seen[pos] { |
| continue |
| } |
| |
| seen[pos] = true |
| |
| impls = append(impls, implementation{ |
| obj: obj, |
| pkg: pkgs[obj.Pkg()], |
| }) |
| } |
| } |
| |
| return impls, nil |
| } |
| |
| type implementation struct { |
| // obj is the implementation, either a *types.TypeName or *types.Func. |
| obj types.Object |
| |
| // pkg is the Package that contains obj's definition. |
| pkg Package |
| } |
| |
| // objectsAtProtocolPos returns all the type.Objects referenced at the given position. |
| // An object will be returned for every package that the file belongs to. |
| func objectsAtProtocolPos(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]types.Object, error) { |
| phs, err := s.PackageHandles(ctx, f) |
| if err != nil { |
| return nil, err |
| } |
| |
| var objs []types.Object |
| |
| // Check all the packages that the file belongs to. |
| for _, ph := range phs { |
| pkg, err := ph.Check(ctx) |
| if err != nil { |
| return nil, err |
| } |
| |
| astFile, pos, err := getASTFile(pkg, f, pp) |
| if err != nil { |
| return nil, err |
| } |
| |
| path := pathEnclosingIdent(astFile, pos) |
| if len(path) == 0 { |
| return nil, ErrNoIdentFound |
| } |
| |
| ident := path[len(path)-1].(*ast.Ident) |
| |
| obj := pkg.GetTypesInfo().ObjectOf(ident) |
| if obj == nil { |
| return nil, fmt.Errorf("no object for %q", ident.Name) |
| } |
| |
| objs = append(objs, obj) |
| } |
| |
| return objs, nil |
| } |
| |
| func getASTFile(pkg Package, f FileHandle, pos protocol.Position) (*ast.File, token.Pos, error) { |
| pgh, err := pkg.File(f.Identity().URI) |
| if err != nil { |
| return nil, 0, err |
| } |
| |
| file, m, _, err := pgh.Cached() |
| if err != nil { |
| return nil, 0, err |
| } |
| |
| spn, err := m.PointSpan(pos) |
| if err != nil { |
| return nil, 0, err |
| } |
| |
| rng, err := spn.Range(m.Converter) |
| if err != nil { |
| return nil, 0, err |
| } |
| |
| return file, rng.Start, nil |
| } |
| |
// pathEnclosingIdent returns the chain of AST nodes leading from f down to
// the first identifier (in traversal order) that contains pos. The file is
// the first element and the identifier is the last. The result is empty when
// no identifier contains pos.
//
// It is a simplified relative of astutil.PathEnclosingInterval that only
// stops on *ast.Ident nodes, and that treats a position equal to the
// identifier's End() as being inside it.
func pathEnclosingIdent(f *ast.File, pos token.Pos) []ast.Node {
	var stack []ast.Node
	matched := false

	ast.Inspect(f, func(node ast.Node) bool {
		switch {
		case matched:
			// Once the identifier is found the stack must stay as is,
			// so neither push nor pop for the rest of the traversal.
			return false
		case node == nil:
			// Leaving a node whose subtree held no match: pop it.
			stack = stack[:len(stack)-1]
			return false
		}

		if id, ok := node.(*ast.Ident); ok && id.Pos() <= pos && pos <= id.End() {
			matched = true
		}
		stack = append(stack, node)

		return !matched
	})

	return stack
}