internal/lsp: improve completion support for type assertions

In type assertion expressions and type switch clauses we now infer the
type from which candidates must be assertable. For example in:

var foo io.Writer
bar := foo.(<>)

When suggesting concrete types we will prefer types that actually
implement io.Writer.

I also added support for the "*" type name modifier. Using the above
example:

bar := foo.(*<>)

we will prefer type T such that *T implements io.Writer.

Change-Id: Ib483bf5e7b339338adc1bfb17b34bc4050d05ad1
GitHub-Last-Rev: 965b028cc00b036019bfdc97561d9e09b7b912ec
GitHub-Pull-Request: golang/tools#123
Reviewed-on: https://go-review.googlesource.com/c/tools/+/183137
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index 81fbd99..fb8fb55 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -211,10 +211,6 @@
 		cand.score *= highScore
 	}
 
-	if c.wantTypeName() && !isTypeName(obj) {
-		cand.score *= lowScore
-	}
-
 	c.items = append(c.items, c.item(cand))
 }
 
@@ -673,9 +669,9 @@
 type typeModifier int
 
 const (
-	dereference typeModifier = iota // dereference ("*") operator
-	reference                       // reference ("&") operator
-	chanRead                        // channel read ("<-") operator
+	star      typeModifier = iota // dereference operator for expressions, pointer indicator for types
+	reference                     // reference ("&") operator
+	chanRead                      // channel read ("<-") operator
 )
 
 // typeInference holds information we have inferred about a type that can be
@@ -690,6 +686,9 @@
 	// modifiers are prefixes such as "*", "&" or "<-" that influence how
 	// a candidate type relates to the expected type.
 	modifiers []typeModifier
+
+	// assertableFrom is a type that must be assertable to our candidate type.
+	assertableFrom types.Type
 }
 
 // expectedType returns information about the expected type for an expression at
@@ -807,7 +806,7 @@
 			}
 			return typeInference{}
 		case *ast.StarExpr:
-			modifiers = append(modifiers, dereference)
+			modifiers = append(modifiers, star)
 		case *ast.UnaryExpr:
 			switch node.Op {
 			case token.AND:
@@ -832,7 +831,7 @@
 func (ti typeInference) applyTypeModifiers(typ types.Type) types.Type {
 	for _, mod := range ti.modifiers {
 		switch mod {
-		case dereference:
+		case star:
 			// For every "*" deref operator, remove a pointer layer from candidate type.
 			typ = deref(typ)
 		case reference:
@@ -848,6 +847,18 @@
 	return typ
 }
 
+// applyTypeNameModifiers applies the list of type modifiers to a type name.
+func (ti typeInference) applyTypeNameModifiers(typ types.Type) types.Type {
+	for _, mod := range ti.modifiers {
+		switch mod {
+		case star:
+			// For every "*" indicator, add a pointer layer to type name.
+			typ = types.NewPointer(typ)
+		}
+	}
+	return typ
+}
+
 // findSwitchStmt returns an *ast.CaseClause's corresponding *ast.SwitchStmt or
 // *ast.TypeSwitchStmt. path should start from the case clause's first ancestor.
 func findSwitchStmt(path []ast.Node, pos token.Pos, c *ast.CaseClause) ast.Stmt {
@@ -886,7 +897,11 @@
 
 // expectTypeName returns information about the expected type name at position.
 func expectTypeName(c *completer) typeInference {
-	var wantTypeName bool
+	var (
+		wantTypeName   bool
+		modifiers      []typeModifier
+		assertableFrom types.Type
+	)
 
 Nodes:
 	for i, p := range c.path {
@@ -911,7 +926,15 @@
 			return typeInference{}
 		case *ast.CaseClause:
 			// Expect type names in type switch case clauses.
-			if _, ok := findSwitchStmt(c.path[i+1:], c.pos, n).(*ast.TypeSwitchStmt); ok {
+			if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, n).(*ast.TypeSwitchStmt); ok {
+				// The case clause types must be assertable from the type switch parameter.
+				ast.Inspect(swtch.Assign, func(n ast.Node) bool {
+					if ta, ok := n.(*ast.TypeAssertExpr); ok {
+						assertableFrom = c.info.TypeOf(ta.X)
+						return false
+					}
+					return true
+				})
 				wantTypeName = true
 				break Nodes
 			}
@@ -919,10 +942,14 @@
 		case *ast.TypeAssertExpr:
 			// Expect type names in type assert expressions.
 			if n.Lparen < c.pos && c.pos <= n.Rparen {
+				// The type in parens must be assertable from the expression type.
+				assertableFrom = c.info.TypeOf(n.X)
 				wantTypeName = true
 				break Nodes
 			}
 			return typeInference{}
+		case *ast.StarExpr:
+			modifiers = append(modifiers, star)
 		default:
 			if breaksExpectedTypeInference(p) {
 				return typeInference{}
@@ -931,13 +958,19 @@
 	}
 
 	return typeInference{
-		wantTypeName: wantTypeName,
+		wantTypeName:   wantTypeName,
+		modifiers:      modifiers,
+		assertableFrom: assertableFrom,
 	}
 }
 
 // matchingType reports whether an object is a good completion candidate
 // in the context of the expected type.
 func (c *completer) matchingType(cand *candidate) bool {
+	if isTypeName(cand.obj) {
+		return c.matchingTypeName(cand)
+	}
+
 	objType := cand.obj.Type()
 
 	// Default to invoking *types.Func candidates. This is so function
@@ -976,3 +1009,29 @@
 
 	return false
 }
+
+func (c *completer) matchingTypeName(cand *candidate) bool {
+	if !c.wantTypeName() {
+		return false
+	}
+
+	// Take into account any type name modifier prefixes.
+	actual := c.expectedType.applyTypeNameModifiers(cand.obj.Type())
+
+	if c.expectedType.assertableFrom != nil {
+		// Don't suggest the starting type in type assertions. For example,
+		// if "foo" is an io.Writer, don't suggest "foo.(io.Writer)".
+		if types.Identical(c.expectedType.assertableFrom, actual) {
+			return false
+		}
+
+		if intf, ok := c.expectedType.assertableFrom.Underlying().(*types.Interface); ok {
+			if !types.AssertableTo(intf, actual) {
+				return false
+			}
+		}
+	}
+
+	// Default to saying any type name is a match.
+	return true
+}
diff --git a/internal/lsp/testdata/good/good1.go b/internal/lsp/testdata/good/good1.go
index f490b4d..b595950 100644
--- a/internal/lsp/testdata/good/good1.go
+++ b/internal/lsp/testdata/good/good1.go
@@ -12,7 +12,7 @@
 func random2(y int) int { //@item(good_random2, "random2(y int)", "int", "func"),item(good_y_param, "y", "int", "parameter")
 	//@complete("", good_y_param, types_import, good_random, good_random2, good_stuff)
 	var b types.Bob = &types.X{}
-	if _, ok := b.(*types.X); ok { //@complete("X", Bob_interface, X_struct, Y_struct)
+	if _, ok := b.(*types.X); ok { //@complete("X", X_struct, Y_struct, Bob_interface)
 	}
 
 	return y
diff --git a/internal/lsp/testdata/typeassert/type_assert.go b/internal/lsp/testdata/typeassert/type_assert.go
new file mode 100644
index 0000000..8b55bf3
--- /dev/null
+++ b/internal/lsp/testdata/typeassert/type_assert.go
@@ -0,0 +1,24 @@
+package typeassert
+
+type abc interface { //@item(abcIntf, "abc", "interface{...}", "interface")
+	abc()
+}
+
+type abcImpl struct{} //@item(abcImpl, "abcImpl", "struct{...}", "struct")
+func (abcImpl) abc()
+
+type abcPtrImpl struct{} //@item(abcPtrImpl, "abcPtrImpl", "struct{...}", "struct")
+func (*abcPtrImpl) abc()
+
+type abcNotImpl struct{} //@item(abcNotImpl, "abcNotImpl", "struct{...}", "struct")
+
+func _() {
+	var a abc
+	switch a.(type) {
+	case ab: //@complete(":", abcImpl, abcIntf, abcNotImpl, abcPtrImpl)
+	case *ab: //@complete(":", abcImpl, abcPtrImpl, abcIntf, abcNotImpl)
+	}
+
+	a.(ab)  //@complete(")", abcImpl, abcIntf, abcNotImpl, abcPtrImpl)
+	a.(*ab) //@complete(")", abcImpl, abcPtrImpl, abcIntf, abcNotImpl)
+}
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 5ecc71a..24b026d 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -25,7 +25,7 @@
 // We hardcode the expected number of test cases to ensure that all tests
 // are being executed. If a test is added, this number must be changed.
 const (
-	ExpectedCompletionsCount       = 128
+	ExpectedCompletionsCount       = 132
 	ExpectedCompletionSnippetCount = 14
 	ExpectedDiagnosticsCount       = 17
 	ExpectedFormatCount            = 5