internal/lsp/source: fix composite literal type name completion

Fix completion in the following cases:

    type foo struct{}

    // now we offer "&foo" instead of "foo"
    var _ *foo = fo<>{}

    struct { f *foo }{
      // now we offer "&foo" instead of "*foo"
      f: fo<>{},
    }

Composite literal type names are a bit special because they are part
of an arbitrary value expression rather than just a standalone type
name expression. In particular, they can be preceded by "&", which
affects how they relate to the surrounding context. The "&" doesn't
technically apply to the type name, but we must take it into account.

I made three changes to fix the behavior:
1. When we want to make a composite literal type name into a pointer,
   we use "&" instead of "*".
2. Record if a composite literal type is already has a "&" so we don't
   add it again.
3. Fix "var _ *foo = fo<>{}" to properly infer expected type of "*foo"
   by not stopping at *ast.CompositeLit searching up AST path when the
   position is in the type name (as opposed to within the curlies).

Change-Id: Iee828f259eb939646b68f5066614ea3a262585c2
Reviewed-on: https://go-review.googlesource.com/c/tools/+/247525
Run-TryBot: Muir Manders <muir@mnd.rs>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index 639e6a6..670865f 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -1445,7 +1445,7 @@
 
 			return &clInfo
 		default:
-			if breaksExpectedTypeInference(n) {
+			if breaksExpectedTypeInference(n, pos) {
 				return nil
 			}
 		}
@@ -1535,11 +1535,11 @@
 type typeMod int
 
 const (
-	star     typeMod = iota // pointer indirection for expressions, pointer indicator for types
-	address                 // address operator ("&")
-	chanRead                // channel read operator ("<-")
-	slice                   // make a slice type ("[]" in "[]int")
-	array                   // make an array type ("[2]" in "[2]int")
+	dereference typeMod = iota // pointer indirection: "*"
+	reference                  // adds level of pointer: "&" for values, "*" for type names
+	chanRead                   // channel read operator ("<-")
+	slice                      // make a slice type ("[]" in "[]int")
+	array                      // make an array type ("[2]" in "[2]int")
 )
 
 type objKind int
@@ -1651,6 +1651,10 @@
 	// seenTypeSwitchCases tracks types that have already been used by
 	// the containing type switch.
 	seenTypeSwitchCases []types.Type
+
+	// compLitType is true if we are completing a composite literal type
+	// name, e.g "foo<>{}".
+	compLitType bool
 }
 
 // expectedCandidate returns information about the expected candidate
@@ -1862,11 +1866,11 @@
 			}
 			return inf
 		case *ast.StarExpr:
-			inf.modifiers = append(inf.modifiers, typeModifier{mod: star})
+			inf.modifiers = append(inf.modifiers, typeModifier{mod: dereference})
 		case *ast.UnaryExpr:
 			switch node.Op {
 			case token.AND:
-				inf.modifiers = append(inf.modifiers, typeModifier{mod: address})
+				inf.modifiers = append(inf.modifiers, typeModifier{mod: reference})
 			case token.ARROW:
 				inf.modifiers = append(inf.modifiers, typeModifier{mod: chanRead})
 			}
@@ -1874,7 +1878,7 @@
 			inf.objKind |= kindFunc
 			return inf
 		default:
-			if breaksExpectedTypeInference(node) {
+			if breaksExpectedTypeInference(node, c.pos) {
 				return inf
 			}
 		}
@@ -1928,7 +1932,7 @@
 func (ci candidateInference) applyTypeModifiers(typ types.Type, addressable bool) types.Type {
 	for _, mod := range ci.modifiers {
 		switch mod.mod {
-		case star:
+		case dereference:
 			// For every "*" indirection operator, remove a pointer layer
 			// from candidate type.
 			if ptr, ok := typ.Underlying().(*types.Pointer); ok {
@@ -1936,7 +1940,7 @@
 			} else {
 				return nil
 			}
-		case address:
+		case reference:
 			// For every "&" address operator, add another pointer layer to
 			// candidate type, if the candidate is addressable.
 			if addressable {
@@ -1961,8 +1965,7 @@
 func (ci candidateInference) applyTypeNameModifiers(typ types.Type) types.Type {
 	for _, mod := range ci.typeName.modifiers {
 		switch mod.mod {
-		case star:
-			// For every "*" indicator, add a pointer layer to type name.
+		case reference:
 			typ = types.NewPointer(typ)
 		case array:
 			typ = types.NewArray(typ, mod.arrayLen)
@@ -2006,9 +2009,17 @@
 // breaksExpectedTypeInference reports if an expression node's type is unrelated
 // to its child expression node types. For example, "Foo{Bar: x.Baz(<>)}" should
 // expect a function argument, not a composite literal value.
-func breaksExpectedTypeInference(n ast.Node) bool {
-	switch n.(type) {
-	case *ast.FuncLit, *ast.CallExpr, *ast.IndexExpr, *ast.SliceExpr, *ast.CompositeLit:
+func breaksExpectedTypeInference(n ast.Node, pos token.Pos) bool {
+	switch n := n.(type) {
+	case *ast.CompositeLit:
+		// Doesn't break inference if pos is in type name.
+		// For example: "Foo<>{Bar: 123}"
+		return !nodeContains(n.Type, pos)
+	case *ast.CallExpr:
+		// Doesn't break inference if pos is in func name.
+		// For example: "Foo<>(123)"
+		return !nodeContains(n.Fun, pos)
+	case *ast.FuncLit, *ast.IndexExpr, *ast.SliceExpr:
 		return true
 	default:
 		return false
@@ -2017,13 +2028,7 @@
 
 // expectTypeName returns information about the expected type name at position.
 func expectTypeName(c *completer) typeNameInference {
-	var (
-		wantTypeName        bool
-		wantComparable      bool
-		modifiers           []typeModifier
-		assertableFrom      types.Type
-		seenTypeSwitchCases []types.Type
-	)
+	var inf typeNameInference
 
 Nodes:
 	for i, p := range c.path {
@@ -2034,7 +2039,7 @@
 			// InterfaceType. We don't need to worry about the field name
 			// because completion bails out early if pos is in an *ast.Ident
 			// that defines an object.
-			wantTypeName = true
+			inf.wantTypeName = true
 			break Nodes
 		case *ast.CaseClause:
 			// Expect type names in type switch case clauses.
@@ -2042,12 +2047,12 @@
 				// The case clause types must be assertable from the type switch parameter.
 				ast.Inspect(swtch.Assign, func(n ast.Node) bool {
 					if ta, ok := n.(*ast.TypeAssertExpr); ok {
-						assertableFrom = c.pkg.GetTypesInfo().TypeOf(ta.X)
+						inf.assertableFrom = c.pkg.GetTypesInfo().TypeOf(ta.X)
 						return false
 					}
 					return true
 				})
-				wantTypeName = true
+				inf.wantTypeName = true
 
 				// Track the types that have already been used in this
 				// switch's case statements so we don't recommend them.
@@ -2060,7 +2065,7 @@
 						}
 
 						if t := c.pkg.GetTypesInfo().TypeOf(typeExpr); t != nil {
-							seenTypeSwitchCases = append(seenTypeSwitchCases, t)
+							inf.seenTypeSwitchCases = append(inf.seenTypeSwitchCases, t)
 						}
 					}
 				}
@@ -2072,33 +2077,43 @@
 			// Expect type names in type assert expressions.
 			if n.Lparen < c.pos && c.pos <= n.Rparen {
 				// The type in parens must be assertable from the expression type.
-				assertableFrom = c.pkg.GetTypesInfo().TypeOf(n.X)
-				wantTypeName = true
+				inf.assertableFrom = c.pkg.GetTypesInfo().TypeOf(n.X)
+				inf.wantTypeName = true
 				break Nodes
 			}
 			return typeNameInference{}
 		case *ast.StarExpr:
-			modifiers = append(modifiers, typeModifier{mod: star})
+			inf.modifiers = append(inf.modifiers, typeModifier{mod: reference})
 		case *ast.CompositeLit:
 			// We want a type name if position is in the "Type" part of a
 			// composite literal (e.g. "Foo<>{}").
 			if n.Type != nil && n.Type.Pos() <= c.pos && c.pos <= n.Type.End() {
-				wantTypeName = true
+				inf.wantTypeName = true
+				inf.compLitType = true
+
+				if i < len(c.path)-1 {
+					// Track preceding "&" operator. Technically it applies to
+					// the composite literal and not the type name, but if
+					// affects our type completion nonetheless.
+					if u, ok := c.path[i+1].(*ast.UnaryExpr); ok && u.Op == token.AND {
+						inf.modifiers = append(inf.modifiers, typeModifier{mod: reference})
+					}
+				}
 			}
 			break Nodes
 		case *ast.ArrayType:
 			// If we are inside the "Elt" part of an array type, we want a type name.
 			if n.Elt.Pos() <= c.pos && c.pos <= n.Elt.End() {
-				wantTypeName = true
+				inf.wantTypeName = true
 				if n.Len == nil {
 					// No "Len" expression means a slice type.
-					modifiers = append(modifiers, typeModifier{mod: slice})
+					inf.modifiers = append(inf.modifiers, typeModifier{mod: slice})
 				} else {
 					// Try to get the array type using the constant value of "Len".
 					tv, ok := c.pkg.GetTypesInfo().Types[n.Len]
 					if ok && tv.Value != nil && tv.Value.Kind() == constant.Int {
 						if arrayLen, ok := constant.Int64Val(tv.Value); ok {
-							modifiers = append(modifiers, typeModifier{mod: array, arrayLen: arrayLen})
+							inf.modifiers = append(inf.modifiers, typeModifier{mod: array, arrayLen: arrayLen})
 						}
 					}
 				}
@@ -2114,34 +2129,28 @@
 				break Nodes
 			}
 		case *ast.MapType:
-			wantTypeName = true
+			inf.wantTypeName = true
 			if n.Key != nil {
-				wantComparable = nodeContains(n.Key, c.pos)
+				inf.wantComparable = nodeContains(n.Key, c.pos)
 			} else {
 				// If the key is empty, assume we are completing the key if
 				// pos is directly after the "map[".
-				wantComparable = c.pos == n.Pos()+token.Pos(len("map["))
+				inf.wantComparable = c.pos == n.Pos()+token.Pos(len("map["))
 			}
 			break Nodes
 		case *ast.ValueSpec:
-			wantTypeName = nodeContains(n.Type, c.pos)
+			inf.wantTypeName = nodeContains(n.Type, c.pos)
 			break Nodes
 		case *ast.TypeSpec:
-			wantTypeName = nodeContains(n.Type, c.pos)
+			inf.wantTypeName = nodeContains(n.Type, c.pos)
 		default:
-			if breaksExpectedTypeInference(p) {
+			if breaksExpectedTypeInference(p, c.pos) {
 				return typeNameInference{}
 			}
 		}
 	}
 
-	return typeNameInference{
-		wantTypeName:        wantTypeName,
-		wantComparable:      wantComparable,
-		modifiers:           modifiers,
-		assertableFrom:      assertableFrom,
-		seenTypeSwitchCases: seenTypeSwitchCases,
-	}
+	return inf
 }
 
 func (c *completer) fakeObj(T types.Type) *types.Var {
@@ -2519,7 +2528,15 @@
 	}
 
 	if !isInterface(t) && typeMatches(types.NewPointer(t)) {
-		cand.makePointer = true
+		if c.inference.typeName.compLitType {
+			// If we are completing a composite literal type as in
+			// "foo<>{}", to make a pointer we must prepend "&".
+			cand.takeAddress = true
+		} else {
+			// If we are completing a normal type name such as "foo<>", to
+			// make a pointer we must prepend "*".
+			cand.makePointer = true
+		}
 		return true
 	}
 
diff --git a/internal/lsp/testdata/lsp/primarymod/complit/complit.go.in b/internal/lsp/testdata/lsp/primarymod/complit/complit.go.in
index ec6544e..465a72c 100644
--- a/internal/lsp/testdata/lsp/primarymod/complit/complit.go.in
+++ b/internal/lsp/testdata/lsp/primarymod/complit/complit.go.in
@@ -95,6 +95,20 @@
 }
 
 func _() {
+	type foo struct{} //@item(complitFoo, "foo", "struct{...}", "struct")
+
+	"&foo" //@item(complitAndFoo, "&foo", "struct{...}", "struct")
+
+	var _ *foo = &fo{} //@rank("{", complitFoo)
+	var _ *foo = fo{} //@rank("{", complitAndFoo)
+
+	struct { a, b *foo }{
+		a: &fo{}, //@rank("{", complitFoo)
+		b: fo{}, //@rank("{", complitAndFoo)
+	}
+}
+
+func _() {
 	_ := position{
 		X: 1, //@complete("X", fieldX),complete(" 1", exportedFunc, multilineWithPrefix, structPosition, cVar, exportedConst, exportedType)
 		Y: ,  //@complete(":", fieldY),complete(" ,", exportedFunc, multilineWithPrefix, structPosition, cVar, exportedConst, exportedType)
diff --git a/internal/lsp/testdata/lsp/primarymod/snippets/literal_snippets.go.in b/internal/lsp/testdata/lsp/primarymod/snippets/literal_snippets.go.in
index d970bf1..4a505e3 100644
--- a/internal/lsp/testdata/lsp/primarymod/snippets/literal_snippets.go.in
+++ b/internal/lsp/testdata/lsp/primarymod/snippets/literal_snippets.go.in
@@ -199,6 +199,12 @@
 	ptrStruct{
 		p: &ptrSt, //@rank(",", litPtrStruct)
 	}
+
+	&ptrStruct{} //@item(litPtrStructPtr, "&ptrStruct{}", "", "var")
+
+	&ptrStruct{
+		p: ptrSt, //@rank(",", litPtrStructPtr)
+	}
 }
 
 func _() {
diff --git a/internal/lsp/testdata/lsp/summary.txt.golden b/internal/lsp/testdata/lsp/summary.txt.golden
index 8ff02af..5d0f432 100644
--- a/internal/lsp/testdata/lsp/summary.txt.golden
+++ b/internal/lsp/testdata/lsp/summary.txt.golden
@@ -6,7 +6,7 @@
 UnimportedCompletionsCount = 6
 DeepCompletionsCount = 5
 FuzzyCompletionsCount = 8
-RankedCompletionsCount = 152
+RankedCompletionsCount = 157
 CaseSensitiveCompletionsCount = 4
 DiagnosticsCount = 44
 FoldingRangesCount = 2