// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package source

import (
	"context"
	"fmt"
	"go/ast"
	"go/token"
	"go/types"

	"golang.org/x/tools/internal/lsp/protocol"
	"golang.org/x/tools/internal/telemetry/log"
	"golang.org/x/tools/internal/telemetry/trace"
	errors "golang.org/x/xerrors"
)

// Implementation returns the protocol locations of the objects that implement
// the interface or interface method found at position pp in file f.
//
// Results whose package is nil or has no compiled Go files are omitted.
// A result whose range cannot be computed is logged and omitted rather than
// failing the whole request.
func Implementation(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]protocol.Location, error) {
	ctx, done := trace.StartSpan(ctx, "source.Implementation")
	defer done()

	found, err := implementations(ctx, s, f, pp)
	if err != nil {
		return nil, err
	}

	var locs []protocol.Location
	for _, candidate := range found {
		// Without a package that has compiled files there is nothing to
		// map the object's position against.
		if candidate.pkg == nil {
			continue
		}
		if len(candidate.pkg.CompiledGoFiles()) == 0 {
			continue
		}

		mapped, err := objToMappedRange(s.View(), candidate.pkg, candidate.obj)
		if err != nil {
			log.Error(ctx, "Error getting range for object", err)
			continue
		}

		protoRange, err := mapped.Range()
		if err != nil {
			log.Error(ctx, "Error getting protocol range for object", err)
			continue
		}

		loc := protocol.Location{
			URI:   protocol.NewURI(mapped.URI()),
			Range: protoRange,
		}
		locs = append(locs, loc)
	}
	return locs, nil
}

// ErrNotAnInterface is the error returned by implementations when an object
// at the queried position is neither a type whose underlying type is an
// interface nor a method whose receiver is an interface.
var ErrNotAnInterface = errors.New("not an interface or interface method")

// implementations finds the concrete named types that implement the interface
// at position pp in file f. When the object at that position is an interface
// method, the corresponding concrete methods are returned instead.
//
// It returns ErrNotAnInterface if an object at the position is neither an
// interface type nor a method with an interface receiver. It returns no
// results and a nil error if the interface has no methods.
func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]implementation, error) {
	var (
		impls []implementation
		seen  = make(map[token.Position]bool)
		fset  = s.View().Session().Cache().FileSet()

		// The candidate types do not depend on the queried object, so they
		// are collected lazily and at most once, instead of once per object.
		allNamed []*types.Named
		pkgs     map[*types.Package]Package
		loaded   bool
	)

	objs, err := objectsAtProtocolPos(ctx, s, f, pp)
	if err != nil {
		return nil, err
	}

	for _, obj := range objs {
		var (
			T      *types.Interface
			method *types.Func
		)

		switch obj := obj.(type) {
		case *types.Func:
			method = obj
			if recv := obj.Type().(*types.Signature).Recv(); recv != nil {
				T, _ = recv.Type().Underlying().(*types.Interface)
			}
		case *types.TypeName:
			T, _ = obj.Type().Underlying().(*types.Interface)
		}

		if T == nil {
			return nil, ErrNotAnInterface
		}

		// NOTE(review): presumably an empty interface is skipped because
		// every type would match it; confirm intent.
		if T.NumMethods() == 0 {
			return nil, nil
		}

		if !loaded {
			allNamed, pkgs, err = concreteNamedTypes(ctx, s)
			if err != nil {
				return nil, err
			}
			loaded = true
		}

		// Find all the named types that implement our interface.
		for _, U := range allNamed {
			var concrete types.Type = U
			if !types.AssignableTo(concrete, T) {
				// We also accept T if *T implements our interface.
				concrete = types.NewPointer(concrete)
				if !types.AssignableTo(concrete, T) {
					continue
				}
			}

			var implObj types.Object = U.Obj()
			if method != nil {
				// Lookup returns nil when the method set has no matching
				// method; skip the candidate instead of dereferencing nil.
				sel := types.NewMethodSet(concrete).Lookup(method.Pkg(), method.Name())
				if sel == nil {
					continue
				}
				implObj = sel.Obj()
			}

			// Skip the queried method itself and anything already reported
			// at the same source position.
			pos := fset.Position(implObj.Pos())
			if implObj == method || seen[pos] {
				continue
			}

			seen[pos] = true

			impls = append(impls, implementation{
				obj: implObj,
				pkg: pkgs[implObj.Pkg()],
			})
		}
	}

	return impls, nil
}

// concreteNamedTypes calls Check on every package handle known to s and
// returns all non-alias, non-interface named types defined in those packages,
// along with a map from each checked *types.Package to its Package.
func concreteNamedTypes(ctx context.Context, s Snapshot) ([]*types.Named, map[*types.Package]Package, error) {
	// Find all named types, even local types (which can have methods
	// due to promotion).
	var (
		allNamed []*types.Named
		pkgs     = make(map[*types.Package]Package)
	)
	for _, ph := range s.KnownPackages(ctx) {
		pkg, err := ph.Check(ctx)
		if err != nil {
			return nil, nil, err
		}
		pkgs[pkg.GetTypes()] = pkg

		for _, def := range pkg.GetTypesInfo().Defs {
			// We ignore aliases 'type M = N' to avoid duplicate reporting
			// of the Named type N.
			tn, ok := def.(*types.TypeName)
			if !ok || tn.IsAlias() {
				continue
			}
			// We skip interface types since we only want concrete
			// implementations.
			if named, ok := tn.Type().(*types.Named); ok && !isInterface(named) {
				allNamed = append(allNamed, named)
			}
		}
	}
	return allNamed, pkgs, nil
}

// implementation is a single result produced by implementations: an object
// paired with the package it was found in.
type implementation struct {
	// obj is the implementation, either a *types.TypeName or *types.Func.
	obj types.Object

	// pkg is the Package that contains obj's definition. It is taken from a
	// map keyed by obj.Pkg() and is nil when that map has no entry for it;
	// Implementation skips results whose pkg is nil.
	pkg Package
}

// objectsAtProtocolPos resolves the identifier at position pp of file f to
// its types.Object, once for each package handle the snapshot reports for f.
//
// The first failure aborts the whole lookup: a package that does not check,
// a position that cannot be mapped, no identifier at the position
// (ErrNoIdentFound), or an identifier with no associated object.
func objectsAtProtocolPos(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]types.Object, error) {
	handles, err := s.PackageHandles(ctx, f)
	if err != nil {
		return nil, err
	}

	var result []types.Object
	for _, handle := range handles {
		checked, err := handle.Check(ctx)
		if err != nil {
			return nil, err
		}

		file, tokPos, err := getASTFile(checked, f, pp)
		if err != nil {
			return nil, err
		}

		nodes := pathEnclosingIdent(file, tokPos)
		if len(nodes) == 0 {
			return nil, ErrNoIdentFound
		}

		// A non-empty path from pathEnclosingIdent always ends in the
		// identifier that matched the position.
		leaf := nodes[len(nodes)-1].(*ast.Ident)

		if o := checked.GetTypesInfo().ObjectOf(leaf); o != nil {
			result = append(result, o)
			continue
		}
		return nil, fmt.Errorf("no object for %q", leaf.Name)
	}

	return result, nil
}

// getASTFile returns the cached syntax tree of f within pkg, together with
// the token.Pos that corresponds to the protocol position pos in that file.
// Any failure to find the file, read its cached parse, or map the position
// is returned as the error.
func getASTFile(pkg Package, f FileHandle, pos protocol.Position) (*ast.File, token.Pos, error) {
	handle, err := pkg.File(f.Identity().URI)
	if err != nil {
		return nil, token.NoPos, err
	}

	// The third result of Cached is deliberately ignored here.
	parsed, mapper, _, err := handle.Cached()
	if err != nil {
		return nil, token.NoPos, err
	}

	// Convert the protocol position to a point span, then to a token range.
	point, err := mapper.PointSpan(pos)
	if err != nil {
		return nil, token.NoPos, err
	}

	tokRange, err := point.Range(mapper.Converter)
	if err != nil {
		return nil, token.NoPos, err
	}

	return parsed, tokRange.Start, nil
}

// pathEnclosingIdent returns the chain of AST nodes from f down to the first
// *ast.Ident (in traversal order) whose extent contains pos. The root comes
// first and the identifier last. The result is empty if no identifier
// contains pos.
//
// It is similar to astutil.PathEnclosingInterval, but simpler, and an
// identifier also matches when pos is equal to its End().
func pathEnclosingIdent(f *ast.File, pos token.Pos) []ast.Node {
	var stack []ast.Node
	done := false

	visit := func(node ast.Node) bool {
		switch {
		case done:
			// The path is frozen once the identifier has been found.
			return false
		case node == nil:
			// ast.Inspect signals the end of a node's children with nil:
			// drop that node from the path.
			stack = stack[:len(stack)-1]
			return false
		}

		stack = append(stack, node)
		if id, ok := node.(*ast.Ident); ok {
			done = id.Pos() <= pos && pos <= id.End()
		}
		return !done
	}

	ast.Inspect(f, visit)
	return stack
}
