internal/lsp/source: improve completion in type assertions

In cases like:

var foo *someType = bar.(some<>)

We will now complete "some" to "*someType". This involved two changes:

1. Properly detect expected type as *someType in above example. To do
   this I just removed *ast.TypeAssertExpr from
   breaksExpectedTypeInference() so we continue searching up the AST for
   the expected type.

2. If the given type name T doesn't match, also try *T. If *T does
   match, we mark the candidate as "makePointer=true" so we know to
   prepend the "*" when formatting the candidate.

Change-Id: I05859c68082a798141755b614673a1483d864e3e
Reviewed-on: https://go-review.googlesource.com/c/tools/+/212717
Run-TryBot: Muir Manders <muir@mnd.rs>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/completion.go b/internal/lsp/completion.go
index 7edc184..560e82c 100644
--- a/internal/lsp/completion.go
+++ b/internal/lsp/completion.go
@@ -127,9 +127,9 @@
 			// https://github.com/Microsoft/language-server-protocol/issues/348.
 			SortText: fmt.Sprintf("%05d", i),
 
-			// Trim address operator (VSCode doesn't like weird characters
-			// in filterText).
-			FilterText: strings.TrimLeft(candidate.InsertText, "&"),
+			// Trim operators (VSCode doesn't like weird characters in
+			// filterText).
+			FilterText: strings.TrimLeft(candidate.InsertText, "&*"),
 
 			Preselect:     i == 0,
 			Documentation: candidate.Documentation,
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index 3d0ed09..4a08655 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -379,6 +379,9 @@
 	// addressable is true if a pointer can be taken to the candidate.
 	addressable bool
 
+	// makePointer is true if the candidate type name T should be made into *T.
+	makePointer bool
+
 	// imp is the import that needs to be added to this package in order
 	// for this candidate to be valid. nil if no import needed.
 	imp *importInfo
@@ -1535,7 +1538,7 @@
 // expect a function argument, not a composite literal value.
 func breaksExpectedTypeInference(n ast.Node) bool {
 	switch n.(type) {
-	case *ast.FuncLit, *ast.CallExpr, *ast.TypeAssertExpr, *ast.IndexExpr, *ast.SliceExpr, *ast.CompositeLit:
+	case *ast.FuncLit, *ast.CallExpr, *ast.IndexExpr, *ast.SliceExpr, *ast.CompositeLit:
 		return true
 	default:
 		return false
@@ -1758,39 +1761,52 @@
 		return false
 	}
 
-	// Take into account any type name modifier prefixes.
-	actual := c.expectedType.applyTypeNameModifiers(cand.obj.Type())
+	typeMatches := func(candType types.Type) bool {
+		// Take into account any type name modifier prefixes.
+		candType = c.expectedType.applyTypeNameModifiers(candType)
 
-	if c.expectedType.typeName.assertableFrom != nil {
-		// Don't suggest the starting type in type assertions. For example,
-		// if "foo" is an io.Writer, don't suggest "foo.(io.Writer)".
-		if types.Identical(c.expectedType.typeName.assertableFrom, actual) {
+		if from := c.expectedType.typeName.assertableFrom; from != nil {
+			// Don't suggest the starting type in type assertions. For example,
+			// if "foo" is an io.Writer, don't suggest "foo.(io.Writer)".
+			if types.Identical(from, candType) {
+				return false
+			}
+
+			if intf, ok := from.Underlying().(*types.Interface); ok {
+				if !types.AssertableTo(intf, candType) {
+					return false
+				}
+			}
+		}
+
+		if c.expectedType.typeName.wantComparable && !types.Comparable(candType) {
 			return false
 		}
 
-		if intf, ok := c.expectedType.typeName.assertableFrom.Underlying().(*types.Interface); ok {
-			if !types.AssertableTo(intf, actual) {
-				return false
-			}
+		// We can expect a type name and have an expected type in cases like:
+		//
+		//   var foo []int
+		//   foo = []i<>
+		//
+		// Where our expected type is "[]int", and we expect a type name.
+		if c.expectedType.objType != nil {
+			return types.AssignableTo(candType, c.expectedType.objType)
 		}
+
+		// Default to saying any type name is a match.
+		return true
 	}
 
-	if c.expectedType.typeName.wantComparable && !types.Comparable(actual) {
-		return false
+	if typeMatches(cand.obj.Type()) {
+		return true
 	}
 
-	// We can expect a type name and have an expected type in cases like:
-	//
-	//   var foo []int
-	//   foo = []i<>
-	//
-	// Where our expected type is "[]int", and we expect a type name.
-	if c.expectedType.objType != nil {
-		return types.AssignableTo(actual, c.expectedType.objType)
+	if typeMatches(types.NewPointer(cand.obj.Type())) {
+		cand.makePointer = true
+		return true
 	}
 
-	// Default to saying any type name is a match.
-	return true
+	return false
 }
 
 // candKind returns the objKind of candType, if any.
diff --git a/internal/lsp/source/completion_format.go b/internal/lsp/source/completion_format.go
index 950e06b..a7a8d6c 100644
--- a/internal/lsp/source/completion_format.go
+++ b/internal/lsp/source/completion_format.go
@@ -130,30 +130,29 @@
 		}
 	}
 
-	// Prepend "&" operator if our candidate needs address taken.
+	// Prepend "&" or "*" operator as appropriate.
+	var prefixOp string
 	if cand.takeAddress {
-		var (
-			sel *ast.SelectorExpr
-			ok  bool
-		)
-		if sel, ok = c.path[0].(*ast.SelectorExpr); !ok && len(c.path) > 1 {
-			sel, _ = c.path[1].(*ast.SelectorExpr)
-		}
+		prefixOp = "&"
+	} else if cand.makePointer {
+		prefixOp = "*"
+	}
 
-		// If we are in a selector, add an edit to place "&" before selector node.
-		if sel != nil {
-			edits, err := referenceEdit(c.snapshot.View().Session().Cache().FileSet(), c.mapper, sel)
+	if prefixOp != "" {
+		// If we are in a selector, add an edit to place prefix before selector.
+		if sel := enclosingSelector(c.path, c.pos); sel != nil {
+			edits, err := prependEdit(c.snapshot.View().Session().Cache().FileSet(), c.mapper, sel, prefixOp)
 			if err != nil {
-				log.Error(c.ctx, "error generating reference edit", err)
+				log.Error(c.ctx, "error generating prefix edit", err)
 			} else {
 				protocolEdits = append(protocolEdits, edits...)
 			}
 		} else {
-			// If there is no selector, just stick the "&" at the start.
-			insert = "&" + insert
+			// If there is no selector, just stick the prefix at the start.
+			insert = prefixOp + insert
 		}
 
-		label = "&" + label
+		label = prefixOp + label
 	}
 
 	detail = strings.TrimPrefix(detail, "untyped ")
diff --git a/internal/lsp/source/completion_literal.go b/internal/lsp/source/completion_literal.go
index 89ecda6..f7bdc39 100644
--- a/internal/lsp/source/completion_literal.go
+++ b/internal/lsp/source/completion_literal.go
@@ -116,7 +116,7 @@
 				// If we are in a selector we must place the "&" before the selector.
 				// For example, "foo.B<>" must complete to "&foo.Bar{}", not
 				// "foo.&Bar{}".
-				edits, err := referenceEdit(c.snapshot.View().Session().Cache().FileSet(), c.mapper, sel)
+				edits, err := prependEdit(c.snapshot.View().Session().Cache().FileSet(), c.mapper, sel, "&")
 				if err != nil {
 					log.Error(c.ctx, "error making edit for literal pointer completion", err)
 					return
@@ -173,9 +173,9 @@
 	}
 }
 
-// referenceEdit produces text edits that prepend a "&" operator to the
-// specified node.
-func referenceEdit(fset *token.FileSet, m *protocol.ColumnMapper, node ast.Node) ([]protocol.TextEdit, error) {
+// prependEdit produces text edits that preprend the specified prefix
+// to the specified node.
+func prependEdit(fset *token.FileSet, m *protocol.ColumnMapper, node ast.Node, prefix string) ([]protocol.TextEdit, error) {
 	rng := newMappedRange(fset, m, node.Pos(), node.Pos())
 	spn, err := rng.Span()
 	if err != nil {
@@ -183,7 +183,7 @@
 	}
 	return ToProtocolEdits(m, []diff.TextEdit{{
 		Span:    spn,
-		NewText: "&",
+		NewText: prefix,
 	}})
 }
 
diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go
index bc95fc6..d649673 100644
--- a/internal/lsp/source/source_test.go
+++ b/internal/lsp/source/source_test.go
@@ -102,20 +102,13 @@
 	for _, pos := range test.CompletionItems {
 		want = append(want, tests.ToProtocolCompletionItem(*items[pos]))
 	}
-	prefix, list := r.callCompletion(t, src, func(opts *source.Options) {
-		opts.Matcher = source.Fuzzy
+	_, got := r.callCompletion(t, src, func(opts *source.Options) {
+		opts.Matcher = source.CaseInsensitive
 		opts.Literal = strings.Contains(string(src.URI()), "literal")
 		opts.DeepCompletion = false
 	})
 	if !strings.Contains(string(src.URI()), "builtins") {
-		list = tests.FilterBuiltins(list)
-	}
-	var got []protocol.CompletionItem
-	for _, item := range list {
-		if !strings.HasPrefix(strings.ToLower(item.Label), prefix) {
-			continue
-		}
-		got = append(got, item)
+		got = tests.FilterBuiltins(got)
 	}
 	if diff := tests.DiffCompletionItems(want, got); diff != "" {
 		t.Errorf("%s: %s", src, diff)
diff --git a/internal/lsp/testdata/func_rank/func_rank.go.in b/internal/lsp/testdata/func_rank/func_rank.go.in
index 0d0feeb..61ad6e9 100644
--- a/internal/lsp/testdata/func_rank/func_rank.go.in
+++ b/internal/lsp/testdata/func_rank/func_rank.go.in
@@ -29,7 +29,7 @@
 	fnInt([]int{}[s.A])       //@complete("])", rankAA, rankAC, rankAB)
 	fnInt([]int{}[:s.A])      //@complete("])", rankAA, rankAC, rankAB)
 
-	fnInt(s.A.(int)) //@complete(".(", rankAA, rankAB, rankAC)
+	fnInt(s.A.(int)) //@complete(".(", rankAA, rankAC, rankAB)
 
 	fnPtr := func(*string) {}
 	fnPtr(&s.A) //@complete(")", rankAB, rankAA, rankAC)
diff --git a/internal/lsp/testdata/maps/maps.go.in b/internal/lsp/testdata/maps/maps.go.in
index 36451ea..5c9dedd 100644
--- a/internal/lsp/testdata/maps/maps.go.in
+++ b/internal/lsp/testdata/maps/maps.go.in
@@ -6,13 +6,13 @@
 	// not comparabale
 	type aSlice []int     //@item(mapSliceType, "aSlice", "[]int", "type")
 
+	*aSlice     //@item(mapSliceTypePtr, "*aSlice", "[]int", "type")
+
 	// comparable
 	type aStruct struct{} //@item(mapStructType, "aStruct", "struct{...}", "struct")
 
-	map[]a{} //@complete("]", mapStructType, mapSliceType, mapVar)
+	map[]a{} //@complete("]", mapSliceTypePtr, mapStructType, mapVar)
 
-	map[a]a{} //@complete("]", mapStructType, mapSliceType, mapVar)
+	map[a]a{} //@complete("]", mapSliceTypePtr, mapStructType, mapVar)
 	map[a]a{} //@complete("{", mapSliceType, mapStructType, mapVar)
-
-	map[]a{} //@rank("]", int, mapSliceType)
 }
diff --git a/internal/lsp/testdata/summary.txt.golden b/internal/lsp/testdata/summary.txt.golden
index 65eebf3..132779e 100644
--- a/internal/lsp/testdata/summary.txt.golden
+++ b/internal/lsp/testdata/summary.txt.golden
@@ -4,7 +4,7 @@
 UnimportedCompletionsCount = 9
 DeepCompletionsCount = 5
 FuzzyCompletionsCount = 8
-RankedCompletionsCount = 56
+RankedCompletionsCount = 55
 CaseSensitiveCompletionsCount = 4
 DiagnosticsCount = 35
 FoldingRangesCount = 2
diff --git a/internal/lsp/testdata/typeassert/type_assert.go b/internal/lsp/testdata/typeassert/type_assert.go
index 8b55bf3..0dfd3a1 100644
--- a/internal/lsp/testdata/typeassert/type_assert.go
+++ b/internal/lsp/testdata/typeassert/type_assert.go
@@ -13,12 +13,14 @@
 type abcNotImpl struct{} //@item(abcNotImpl, "abcNotImpl", "struct{...}", "struct")
 
 func _() {
+	*abcPtrImpl //@item(abcPtrImplPtr, "*abcPtrImpl", "struct{...}", "struct")
+
 	var a abc
 	switch a.(type) {
-	case ab: //@complete(":", abcImpl, abcIntf, abcNotImpl, abcPtrImpl)
+	case ab: //@complete(":", abcPtrImplPtr, abcImpl, abcIntf, abcNotImpl)
 	case *ab: //@complete(":", abcImpl, abcPtrImpl, abcIntf, abcNotImpl)
 	}
 
-	a.(ab)  //@complete(")", abcImpl, abcIntf, abcNotImpl, abcPtrImpl)
+	a.(ab)  //@complete(")", abcPtrImplPtr, abcImpl, abcIntf, abcNotImpl)
 	a.(*ab) //@complete(")", abcImpl, abcPtrImpl, abcIntf, abcNotImpl)
 }