internal/lsp: implement type definitions

Extend definition tests to add typdef test.

Change-Id: Ibad988ae68f91d18f2c6b4739d758a536172fb35
Reviewed-on: https://go-review.googlesource.com/c/152239
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index 03c862f..666792a 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -38,6 +38,7 @@
 	const expectedDiagnosticsCount = 14
 	const expectedFormatCount = 3
 	const expectedDefinitionsCount = 16
+	const expectedTypeDefinitionsCount = 2
 
 	files := packagestest.MustCopyFileTree(dir)
 	for fragment, operation := range files {
@@ -78,6 +79,7 @@
 	expectedCompletions := make(completions)
 	expectedFormat := make(formats)
 	expectedDefinitions := make(definitions)
+	expectedTypeDefinitions := make(definitions)
 
 	// Collect any data that needs to be used by subsequent tests.
 	if err := exported.Expect(map[string]interface{}{
@@ -86,6 +88,7 @@
 		"complete": expectedCompletions.collect,
 		"format":   expectedFormat.collect,
 		"godef":    expectedDefinitions.collect,
+		"typdef":   expectedTypeDefinitions.collect,
 	}); err != nil {
 		t.Fatal(err)
 	}
@@ -127,7 +130,17 @@
 				t.Errorf("got %v definitions expected %v", len(expectedDefinitions), expectedDefinitionsCount)
 			}
 		}
-		expectedDefinitions.test(t, s)
+		expectedDefinitions.test(t, s, false)
+	})
+
+	t.Run("TypeDefinitions", func(t *testing.T) {
+		t.Helper()
+		if goVersion111 { // TODO(rstambler): Remove this when we no longer support Go 1.10.
+			if len(expectedTypeDefinitions) != expectedTypeDefinitionsCount {
+				t.Errorf("got %v type definitions expected %v", len(expectedTypeDefinitions), expectedTypeDefinitionsCount)
+			}
+		}
+		expectedTypeDefinitions.test(t, s, true)
 	})
 }
 
@@ -290,14 +303,21 @@
 	f[pos.Filename] = stdout.String()
 }
 
-func (d definitions) test(t *testing.T, s *server) {
+func (d definitions) test(t *testing.T, s *server, typ bool) {
 	for src, target := range d {
-		locs, err := s.Definition(context.Background(), &protocol.TextDocumentPositionParams{
+		params := &protocol.TextDocumentPositionParams{
 			TextDocument: protocol.TextDocumentIdentifier{
 				URI: src.URI,
 			},
 			Position: src.Range.Start,
-		})
+		}
+		var locs []protocol.Location
+		var err error
+		if typ {
+			locs, err = s.TypeDefinition(context.Background(), params)
+		} else {
+			locs, err = s.Definition(context.Background(), params)
+		}
 		if err != nil {
 			t.Fatal(err)
 		}
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index f5bdf0d..b8b0952 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -64,6 +64,7 @@
 				Change:    float64(protocol.Full), // full contents of file sent on each update
 				OpenClose: true,
 			},
+			TypeDefinitionProvider: true,
 		},
 	}, nil
 }
@@ -215,8 +216,18 @@
 	return []protocol.Location{toProtocolLocation(s.view.Config.Fset, r)}, nil
 }
 
-func (s *server) TypeDefinition(context.Context, *protocol.TextDocumentPositionParams) ([]protocol.Location, error) {
-	return nil, notImplemented("TypeDefinition")
+func (s *server) TypeDefinition(ctx context.Context, params *protocol.TextDocumentPositionParams) ([]protocol.Location, error) {
+	f := s.view.GetFile(source.URI(params.TextDocument.URI))
+	tok, err := f.GetToken()
+	if err != nil {
+		return nil, err
+	}
+	pos := fromProtocolPosition(tok, params.Position)
+	r, err := source.TypeDefinition(ctx, f, pos)
+	if err != nil {
+		return nil, err
+	}
+	return []protocol.Location{toProtocolLocation(s.view.Config.Fset, r)}, nil
 }
 
 func (s *server) Implementation(context.Context, *protocol.TextDocumentPositionParams) ([]protocol.Location, error) {
diff --git a/internal/lsp/source/definition.go b/internal/lsp/source/definition.go
index 4354638..b4f05f4 100644
--- a/internal/lsp/source/definition.go
+++ b/internal/lsp/source/definition.go
@@ -48,6 +48,43 @@
 	return objToRange(f.view.Config.Fset, obj), nil
 }
 
+func TypeDefinition(ctx context.Context, f *File, pos token.Pos) (Range, error) {
+	fAST, err := f.GetAST()
+	if err != nil {
+		return Range{}, err
+	}
+	pkg, err := f.GetPackage()
+	if err != nil {
+		return Range{}, err
+	}
+	i, err := findIdentifier(fAST, pos)
+	if err != nil {
+		return Range{}, err
+	}
+	if i.ident == nil {
+		return Range{}, fmt.Errorf("not a valid identifier")
+	}
+	typ := pkg.TypesInfo.TypeOf(i.ident)
+	if typ == nil {
+		return Range{}, fmt.Errorf("no type for %s", i.ident.Name)
+	}
+	obj := typeToObject(typ)
+	if obj == nil {
+		return Range{}, fmt.Errorf("no object for type %s", typ.String())
+	}
+	return objToRange(f.view.Config.Fset, obj), nil
+}
+
+func typeToObject(typ types.Type) (obj types.Object) {
+	switch typ := typ.(type) {
+	case *types.Named:
+		obj = typ.Obj()
+	case *types.Pointer:
+		obj = typeToObject(typ.Elem())
+	}
+	return obj
+}
+
 // ident returns the ident plus any extra information needed
 type ident struct {
 	ident            *ast.Ident
diff --git a/internal/lsp/testdata/baz/baz.go.in b/internal/lsp/testdata/baz/baz.go.in
index 1af3bc4..90d952b 100644
--- a/internal/lsp/testdata/baz/baz.go.in
+++ b/internal/lsp/testdata/baz/baz.go.in
@@ -12,7 +12,7 @@
 	defer bar.Bar() //@complete("B", Bar)
 	// TODO(rstambler): Test completion here.
 	defer bar.B
-	var _ f.IntFoo  //@complete("n", IntFoo)
+	var x f.IntFoo  //@complete("n", IntFoo),typdef("x", IntFoo)
 	bar.Bar()       //@complete("B", Bar)
 }
 
diff --git a/internal/lsp/testdata/foo/foo.go b/internal/lsp/testdata/foo/foo.go
index 27c1b42..e02099b 100644
--- a/internal/lsp/testdata/foo/foo.go
+++ b/internal/lsp/testdata/foo/foo.go
@@ -14,7 +14,7 @@
 
 func _() {
 	var sFoo StructFoo           //@complete("t", StructFoo)
-	if x := sFoo; x.Value == 1 { //@complete("V", Value)
+	if x := sFoo; x.Value == 1 { //@complete("V", Value),typdef("sFoo", StructFoo)
 		return
 	}
 }