internal/lsp: improve completion support for type assertions
In type assertion expressions and type switch clauses we now infer the
type from which candidates must be assertable. For example in:
var foo io.Writer
bar := foo.(<>)
When suggesting concrete types we will prefer types that actually
implement io.Writer.
I also added support for the "*" type name modifier. Using the above
example:
bar := foo.(*<>)
we will prefer type T such that *T implements io.Writer.
Change-Id: Ib483bf5e7b339338adc1bfb17b34bc4050d05ad1
GitHub-Last-Rev: 965b028cc00b036019bfdc97561d9e09b7b912ec
GitHub-Pull-Request: golang/tools#123
Reviewed-on: https://go-review.googlesource.com/c/tools/+/183137
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index 81fbd99..fb8fb55 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -211,10 +211,6 @@
cand.score *= highScore
}
- if c.wantTypeName() && !isTypeName(obj) {
- cand.score *= lowScore
- }
-
c.items = append(c.items, c.item(cand))
}
@@ -673,9 +669,9 @@
type typeModifier int
const (
- dereference typeModifier = iota // dereference ("*") operator
- reference // reference ("&") operator
- chanRead // channel read ("<-") operator
+ star typeModifier = iota // dereference operator for expressions, pointer indicator for types
+ reference // reference ("&") operator
+ chanRead // channel read ("<-") operator
)
// typeInference holds information we have inferred about a type that can be
@@ -690,6 +686,9 @@
// modifiers are prefixes such as "*", "&" or "<-" that influence how
// a candidate type relates to the expected type.
modifiers []typeModifier
+
+ // assertableFrom is a type that must be assertable to our candidate type.
+ assertableFrom types.Type
}
// expectedType returns information about the expected type for an expression at
@@ -807,7 +806,7 @@
}
return typeInference{}
case *ast.StarExpr:
- modifiers = append(modifiers, dereference)
+ modifiers = append(modifiers, star)
case *ast.UnaryExpr:
switch node.Op {
case token.AND:
@@ -832,7 +831,7 @@
func (ti typeInference) applyTypeModifiers(typ types.Type) types.Type {
for _, mod := range ti.modifiers {
switch mod {
- case dereference:
+ case star:
// For every "*" deref operator, remove a pointer layer from candidate type.
typ = deref(typ)
case reference:
@@ -848,6 +847,18 @@
return typ
}
+// applyTypeNameModifiers applies the list of type modifiers to a type name.
+func (ti typeInference) applyTypeNameModifiers(typ types.Type) types.Type {
+ for _, mod := range ti.modifiers {
+ switch mod {
+ case star:
+ // For every "*" indicator, add a pointer layer to type name.
+ typ = types.NewPointer(typ)
+ }
+ }
+ return typ
+}
+
// findSwitchStmt returns an *ast.CaseClause's corresponding *ast.SwitchStmt or
// *ast.TypeSwitchStmt. path should start from the case clause's first ancestor.
func findSwitchStmt(path []ast.Node, pos token.Pos, c *ast.CaseClause) ast.Stmt {
@@ -886,7 +897,11 @@
// expectTypeName returns information about the expected type name at position.
func expectTypeName(c *completer) typeInference {
- var wantTypeName bool
+ var (
+ wantTypeName bool
+ modifiers []typeModifier
+ assertableFrom types.Type
+ )
Nodes:
for i, p := range c.path {
@@ -911,7 +926,15 @@
return typeInference{}
case *ast.CaseClause:
// Expect type names in type switch case clauses.
- if _, ok := findSwitchStmt(c.path[i+1:], c.pos, n).(*ast.TypeSwitchStmt); ok {
+ if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, n).(*ast.TypeSwitchStmt); ok {
+ // The case clause types must be assertable from the type switch parameter.
+ ast.Inspect(swtch.Assign, func(n ast.Node) bool {
+ if ta, ok := n.(*ast.TypeAssertExpr); ok {
+ assertableFrom = c.info.TypeOf(ta.X)
+ return false
+ }
+ return true
+ })
wantTypeName = true
break Nodes
}
@@ -919,10 +942,14 @@
case *ast.TypeAssertExpr:
// Expect type names in type assert expressions.
if n.Lparen < c.pos && c.pos <= n.Rparen {
+ // The type in parens must be assertable from the expression type.
+ assertableFrom = c.info.TypeOf(n.X)
wantTypeName = true
break Nodes
}
return typeInference{}
+ case *ast.StarExpr:
+ modifiers = append(modifiers, star)
default:
if breaksExpectedTypeInference(p) {
return typeInference{}
@@ -931,13 +958,19 @@
}
return typeInference{
- wantTypeName: wantTypeName,
+ wantTypeName: wantTypeName,
+ modifiers: modifiers,
+ assertableFrom: assertableFrom,
}
}
// matchingType reports whether an object is a good completion candidate
// in the context of the expected type.
func (c *completer) matchingType(cand *candidate) bool {
+ if isTypeName(cand.obj) {
+ return c.matchingTypeName(cand)
+ }
+
objType := cand.obj.Type()
// Default to invoking *types.Func candidates. This is so function
@@ -976,3 +1009,29 @@
return false
}
+
+func (c *completer) matchingTypeName(cand *candidate) bool {
+ if !c.wantTypeName() {
+ return false
+ }
+
+ // Take into account any type name modifier prefixes.
+ actual := c.expectedType.applyTypeNameModifiers(cand.obj.Type())
+
+ if c.expectedType.assertableFrom != nil {
+ // Don't suggest the starting type in type assertions. For example,
+ // if "foo" is an io.Writer, don't suggest "foo.(io.Writer)".
+ if types.Identical(c.expectedType.assertableFrom, actual) {
+ return false
+ }
+
+ if intf, ok := c.expectedType.assertableFrom.Underlying().(*types.Interface); ok {
+ if !types.AssertableTo(intf, actual) {
+ return false
+ }
+ }
+ }
+
+ // Default to saying any type name is a match.
+ return true
+}
diff --git a/internal/lsp/testdata/good/good1.go b/internal/lsp/testdata/good/good1.go
index f490b4d..b595950 100644
--- a/internal/lsp/testdata/good/good1.go
+++ b/internal/lsp/testdata/good/good1.go
@@ -12,7 +12,7 @@
func random2(y int) int { //@item(good_random2, "random2(y int)", "int", "func"),item(good_y_param, "y", "int", "parameter")
//@complete("", good_y_param, types_import, good_random, good_random2, good_stuff)
var b types.Bob = &types.X{}
- if _, ok := b.(*types.X); ok { //@complete("X", Bob_interface, X_struct, Y_struct)
+ if _, ok := b.(*types.X); ok { //@complete("X", X_struct, Y_struct, Bob_interface)
}
return y
diff --git a/internal/lsp/testdata/typeassert/type_assert.go b/internal/lsp/testdata/typeassert/type_assert.go
new file mode 100644
index 0000000..8b55bf3
--- /dev/null
+++ b/internal/lsp/testdata/typeassert/type_assert.go
@@ -0,0 +1,24 @@
+package typeassert
+
+type abc interface { //@item(abcIntf, "abc", "interface{...}", "interface")
+ abc()
+}
+
+type abcImpl struct{} //@item(abcImpl, "abcImpl", "struct{...}", "struct")
+func (abcImpl) abc()
+
+type abcPtrImpl struct{} //@item(abcPtrImpl, "abcPtrImpl", "struct{...}", "struct")
+func (*abcPtrImpl) abc()
+
+type abcNotImpl struct{} //@item(abcNotImpl, "abcNotImpl", "struct{...}", "struct")
+
+func _() {
+ var a abc
+ switch a.(type) {
+ case ab: //@complete(":", abcImpl, abcIntf, abcNotImpl, abcPtrImpl)
+ case *ab: //@complete(":", abcImpl, abcPtrImpl, abcIntf, abcNotImpl)
+ }
+
+ a.(ab) //@complete(")", abcImpl, abcIntf, abcNotImpl, abcPtrImpl)
+ a.(*ab) //@complete(")", abcImpl, abcPtrImpl, abcIntf, abcNotImpl)
+}
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 5ecc71a..24b026d 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -25,7 +25,7 @@
// We hardcode the expected number of test cases to ensure that all tests
// are being executed. If a test is added, this number must be changed.
const (
- ExpectedCompletionsCount = 128
+ ExpectedCompletionsCount = 132
ExpectedCompletionSnippetCount = 14
ExpectedDiagnosticsCount = 17
ExpectedFormatCount = 5