internal/lsp: implement type definitions
Extend definition tests to add typdef test.
Change-Id: Ibad988ae68f91d18f2c6b4739d758a536172fb35
Reviewed-on: https://go-review.googlesource.com/c/152239
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go
index 03c862f..666792a 100644
--- a/internal/lsp/lsp_test.go
+++ b/internal/lsp/lsp_test.go
@@ -38,6 +38,7 @@
const expectedDiagnosticsCount = 14
const expectedFormatCount = 3
const expectedDefinitionsCount = 16
+ const expectedTypeDefinitionsCount = 2
files := packagestest.MustCopyFileTree(dir)
for fragment, operation := range files {
@@ -78,6 +79,7 @@
expectedCompletions := make(completions)
expectedFormat := make(formats)
expectedDefinitions := make(definitions)
+ expectedTypeDefinitions := make(definitions)
// Collect any data that needs to be used by subsequent tests.
if err := exported.Expect(map[string]interface{}{
@@ -86,6 +88,7 @@
"complete": expectedCompletions.collect,
"format": expectedFormat.collect,
"godef": expectedDefinitions.collect,
+ "typdef": expectedTypeDefinitions.collect,
}); err != nil {
t.Fatal(err)
}
@@ -127,7 +130,17 @@
t.Errorf("got %v definitions expected %v", len(expectedDefinitions), expectedDefinitionsCount)
}
}
- expectedDefinitions.test(t, s)
+ expectedDefinitions.test(t, s, false)
+ })
+
+ t.Run("TypeDefinitions", func(t *testing.T) {
+ t.Helper()
+ if goVersion111 { // TODO(rstambler): Remove this when we no longer support Go 1.10.
+ if len(expectedTypeDefinitions) != expectedTypeDefinitionsCount {
+ t.Errorf("got %v type definitions expected %v", len(expectedTypeDefinitions), expectedTypeDefinitionsCount)
+ }
+ }
+ expectedTypeDefinitions.test(t, s, true)
})
}
@@ -290,14 +303,21 @@
f[pos.Filename] = stdout.String()
}
-func (d definitions) test(t *testing.T, s *server) {
+func (d definitions) test(t *testing.T, s *server, typ bool) {
for src, target := range d {
- locs, err := s.Definition(context.Background(), &protocol.TextDocumentPositionParams{
+ params := &protocol.TextDocumentPositionParams{
TextDocument: protocol.TextDocumentIdentifier{
URI: src.URI,
},
Position: src.Range.Start,
- })
+ }
+ var locs []protocol.Location
+ var err error
+ if typ {
+ locs, err = s.TypeDefinition(context.Background(), params)
+ } else {
+ locs, err = s.Definition(context.Background(), params)
+ }
if err != nil {
t.Fatal(err)
}
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index f5bdf0d..b8b0952 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -64,6 +64,7 @@
Change: float64(protocol.Full), // full contents of file sent on each update
OpenClose: true,
},
+ TypeDefinitionProvider: true,
},
}, nil
}
@@ -215,8 +216,18 @@
return []protocol.Location{toProtocolLocation(s.view.Config.Fset, r)}, nil
}
-func (s *server) TypeDefinition(context.Context, *protocol.TextDocumentPositionParams) ([]protocol.Location, error) {
- return nil, notImplemented("TypeDefinition")
+func (s *server) TypeDefinition(ctx context.Context, params *protocol.TextDocumentPositionParams) ([]protocol.Location, error) {
+ f := s.view.GetFile(source.URI(params.TextDocument.URI))
+ tok, err := f.GetToken()
+ if err != nil {
+ return nil, err
+ }
+ pos := fromProtocolPosition(tok, params.Position)
+ r, err := source.TypeDefinition(ctx, f, pos)
+ if err != nil {
+ return nil, err
+ }
+ return []protocol.Location{toProtocolLocation(s.view.Config.Fset, r)}, nil
}
func (s *server) Implementation(context.Context, *protocol.TextDocumentPositionParams) ([]protocol.Location, error) {
diff --git a/internal/lsp/source/definition.go b/internal/lsp/source/definition.go
index 4354638..b4f05f4 100644
--- a/internal/lsp/source/definition.go
+++ b/internal/lsp/source/definition.go
@@ -48,6 +48,43 @@
return objToRange(f.view.Config.Fset, obj), nil
}
+func TypeDefinition(ctx context.Context, f *File, pos token.Pos) (Range, error) {
+ fAST, err := f.GetAST()
+ if err != nil {
+ return Range{}, err
+ }
+ pkg, err := f.GetPackage()
+ if err != nil {
+ return Range{}, err
+ }
+ i, err := findIdentifier(fAST, pos)
+ if err != nil {
+ return Range{}, err
+ }
+ if i.ident == nil {
+ return Range{}, fmt.Errorf("not a valid identifier")
+ }
+ typ := pkg.TypesInfo.TypeOf(i.ident)
+ if typ == nil {
+ return Range{}, fmt.Errorf("no type for %s", i.ident.Name)
+ }
+ obj := typeToObject(typ)
+ if obj == nil {
+ return Range{}, fmt.Errorf("no object for type %s", typ.String())
+ }
+ return objToRange(f.view.Config.Fset, obj), nil
+}
+
+func typeToObject(typ types.Type) (obj types.Object) {
+ switch typ := typ.(type) {
+ case *types.Named:
+ obj = typ.Obj()
+ case *types.Pointer:
+ obj = typeToObject(typ.Elem())
+ }
+ return obj
+}
+
// ident returns the ident plus any extra information needed
type ident struct {
ident *ast.Ident
diff --git a/internal/lsp/testdata/baz/baz.go.in b/internal/lsp/testdata/baz/baz.go.in
index 1af3bc4..90d952b 100644
--- a/internal/lsp/testdata/baz/baz.go.in
+++ b/internal/lsp/testdata/baz/baz.go.in
@@ -12,7 +12,7 @@
defer bar.Bar() //@complete("B", Bar)
// TODO(rstambler): Test completion here.
defer bar.B
- var _ f.IntFoo //@complete("n", IntFoo)
+ var x f.IntFoo //@complete("n", IntFoo),typdef("x", IntFoo)
bar.Bar() //@complete("B", Bar)
}
diff --git a/internal/lsp/testdata/foo/foo.go b/internal/lsp/testdata/foo/foo.go
index 27c1b42..e02099b 100644
--- a/internal/lsp/testdata/foo/foo.go
+++ b/internal/lsp/testdata/foo/foo.go
@@ -14,7 +14,7 @@
func _() {
var sFoo StructFoo //@complete("t", StructFoo)
- if x := sFoo; x.Value == 1 { //@complete("V", Value)
+ if x := sFoo; x.Value == 1 { //@complete("V", Value),typdef("sFoo", StructFoo)
return
}
}