internal/lsp: improve expected type determination

Improve expected type determination for the following cases:

- search back further through ast path to handle cases where the
  position's node is more than two nodes from the ancestor node with
  type information
- generate expected type for return statements
- wrap and unwrap pointerness from expected type when position is
  preceded by "*" (dereference) or "&" (reference) operators,
  respectively
- fix some false positive expected types when completing the "Fun"
  (left) side of a CallExpr

Change-Id: I907ee3e405bd8420031a7b03329de5df1c3493b9
GitHub-Last-Rev: 20a0ac9bf2b5350494c6738f5960676cc50fb454
GitHub-Pull-Request: golang/tools#93
Reviewed-on: https://go-review.googlesource.com/c/tools/+/174477
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index 3761542..37a1ca2 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -196,7 +196,6 @@
 		path:                      path,
 		pos:                       pos,
 		seen:                      make(map[types.Object]bool),
-		expectedType:              expectedType(path, pos, pkg.GetTypesInfo()),
 		enclosingFunction:         enclosingFunction(path, pos, pkg.GetTypesInfo()),
 		preferTypeNames:           preferTypeNames(path, pos),
 		enclosingCompositeLiteral: lit,
@@ -204,6 +203,8 @@
 		inCompositeLiteralField:   inCompositeLiteralField,
 	}
 
+	c.expectedType = expectedType(c)
+
 	// Composite literals are handled entirely separately.
 	if c.enclosingCompositeLiteral != nil {
 		c.expectedType = c.expectedCompositeLiteralType(c.enclosingCompositeLiteral, c.enclosingKeyValue)
@@ -458,14 +459,10 @@
 				// don't show composite literal completions.
 				ok = pos <= kv.Colon
 			}
-		case *ast.FuncType, *ast.CallExpr, *ast.TypeAssertExpr:
-			// These node types break the type link between the leaf node and
-			// the composite literal. The type of the leaf node becomes unrelated
-			// to the type of the composite literal, so we return nil to avoid
-			// inappropriate completions. For example, "Foo{Bar: x.Baz(<>)}"
-			// should complete as a function argument to Baz, not part of the Foo
-			// composite literal.
-			return nil, nil, false
+		default:
+			if breaksExpectedTypeInference(n) {
+				return nil, nil, false
+			}
 		}
 	}
 	return lit, kv, ok
@@ -538,50 +535,111 @@
 }
 
 // expectedType returns the expected type for an expression at the query position.
-func expectedType(path []ast.Node, pos token.Pos, info *types.Info) types.Type {
-	for i, node := range path {
-		if i == 2 {
-			break
-		}
+func expectedType(c *completer) types.Type {
+	var (
+		derefCount int // count of deref "*" operators
+		refCount   int // count of reference "&" operators
+		typ        types.Type
+	)
+
+Nodes:
+	for _, node := range c.path {
 		switch expr := node.(type) {
 		case *ast.BinaryExpr:
 			// Determine if query position comes from left or right of op.
 			e := expr.X
-			if pos < expr.OpPos {
+			if c.pos < expr.OpPos {
 				e = expr.Y
 			}
-			if tv, ok := info.Types[e]; ok {
-				return tv.Type
+			if tv, ok := c.info.Types[e]; ok {
+				typ = tv.Type
+				break Nodes
 			}
 		case *ast.AssignStmt:
 			// Only rank completions if you are on the right side of the token.
-			if pos <= expr.TokPos {
-				break
-			}
-			i := indexExprAtPos(pos, expr.Rhs)
-			if i >= len(expr.Lhs) {
-				i = len(expr.Lhs) - 1
-			}
-			if tv, ok := info.Types[expr.Lhs[i]]; ok {
-				return tv.Type
-			}
-		case *ast.CallExpr:
-			if tv, ok := info.Types[expr.Fun]; ok {
-				if sig, ok := tv.Type.(*types.Signature); ok {
-					if sig.Params().Len() == 0 {
-						return nil
-					}
-					i := indexExprAtPos(pos, expr.Args)
-					// Make sure not to run past the end of expected parameters.
-					if i >= sig.Params().Len() {
-						i = sig.Params().Len() - 1
-					}
-					return sig.Params().At(i).Type()
+			if c.pos > expr.TokPos {
+				i := indexExprAtPos(c.pos, expr.Rhs)
+				if i >= len(expr.Lhs) {
+					i = len(expr.Lhs) - 1
 				}
+				if tv, ok := c.info.Types[expr.Lhs[i]]; ok {
+					typ = tv.Type
+					break Nodes
+				}
+			}
+			return nil
+		case *ast.CallExpr:
+			// Only consider CallExpr args if position falls between parens.
+			if expr.Lparen <= c.pos && c.pos <= expr.Rparen {
+				if tv, ok := c.info.Types[expr.Fun]; ok {
+					if sig, ok := tv.Type.(*types.Signature); ok {
+						if sig.Params().Len() == 0 {
+							return nil
+						}
+						i := indexExprAtPos(c.pos, expr.Args)
+						// Make sure not to run past the end of expected parameters.
+						if i >= sig.Params().Len() {
+							i = sig.Params().Len() - 1
+						}
+						typ = sig.Params().At(i).Type()
+						break Nodes
+					}
+				}
+			}
+			return nil
+		case *ast.ReturnStmt:
+			if sig := c.enclosingFunction; sig != nil {
+				// Find signature result that corresponds to our return expression.
+				if resultIdx := indexExprAtPos(c.pos, expr.Results); resultIdx < len(expr.Results) {
+					if resultIdx < sig.Results().Len() {
+						typ = sig.Results().At(resultIdx).Type()
+						break Nodes
+					}
+				}
+			}
+
+			return nil
+		case *ast.StarExpr:
+			derefCount++
+		case *ast.UnaryExpr:
+			if expr.Op == token.AND {
+				refCount++
+			}
+		default:
+			if breaksExpectedTypeInference(node) {
+				return nil
 			}
 		}
 	}
-	return nil
+
+	if typ != nil {
+		// For every "*" deref operator, add another pointer layer to expected type.
+		for i := 0; i < derefCount; i++ {
+			typ = types.NewPointer(typ)
+		}
+		// For every "&" ref operator, remove a pointer layer from expected type.
+		for i := 0; i < refCount; i++ {
+			if ptr, ok := typ.(*types.Pointer); ok {
+				typ = ptr.Elem()
+			} else {
+				break
+			}
+		}
+	}
+
+	return typ
+}
+
+// breaksExpectedTypeInference reports if an expression node's type is unrelated
+// to its child expression node types. For example, "Foo{Bar: x.Baz(<>)}" should
+// expect a function argument, not a composite literal value.
+func breaksExpectedTypeInference(n ast.Node) bool {
+	switch n.(type) {
+	case *ast.FuncLit, *ast.CallExpr, *ast.TypeAssertExpr, *ast.IndexExpr, *ast.SliceExpr, *ast.CompositeLit:
+		return true
+	default:
+		return false
+	}
 }
 
 // preferTypeNames checks if given token position is inside func receiver,
diff --git a/internal/lsp/testdata/complit/complit.go.in b/internal/lsp/testdata/complit/complit.go.in
index fec5aad..007cb65 100644
--- a/internal/lsp/testdata/complit/complit.go.in
+++ b/internal/lsp/testdata/complit/complit.go.in
@@ -44,6 +44,7 @@
 	}
 	_ = map[int]string{1: "" + s.A} //@complete("}", fieldAB, fieldAA)
 	_ = map[int]string{1: (func(i int) string { return "" })(s.A)} //@complete(")}", fieldAA, fieldAB)
+	_ = map[int]string{1: func() string { s.A }} //@complete(" }", fieldAA, fieldAB)
 }
 
 func _() {
diff --git a/internal/lsp/testdata/func_rank/func_rank.go.in b/internal/lsp/testdata/func_rank/func_rank.go.in
index cb5a1b4..d950d3e 100644
--- a/internal/lsp/testdata/func_rank/func_rank.go.in
+++ b/internal/lsp/testdata/func_rank/func_rank.go.in
@@ -1,8 +1,8 @@
 package func_rank
 
-var stringAVar = "var" //@item(stringAVar, "stringAVar", "string", "var")
+var stringAVar = "var"    //@item(stringAVar, "stringAVar", "string", "var")
 func stringBFunc() string { return "str" } //@item(stringBFunc, "stringBFunc()", "string", "func")
-type stringer struct{}   //@item(stringer, "stringer", "struct{...}", "struct")
+type stringer struct{}    //@item(stringer, "stringer", "struct{...}", "struct")
 
 func _() stringer //@complete("tr", stringer, stringAVar, stringBFunc)
 
@@ -10,3 +10,35 @@
 
 func (stringer) _() {} //@complete("tr", stringer, stringAVar, stringBFunc)
 
+func _() {
+	var s struct {
+		AA int    //@item(rankAA, "AA", "int", "field")
+		AB string //@item(rankAB, "AB", "string", "field")
+		AC int    //@item(rankAC, "AC", "int", "field")
+	}
+	fnStr := func(string) {}
+	fnStr(s.A)      //@complete(")", rankAB, rankAA, rankAC)
+	fnStr("" + s.A) //@complete(")", rankAB, rankAA, rankAC)
+
+	fnInt := func(int) {}
+	fnInt(-s.A) //@complete(")", rankAA, rankAC, rankAB)
+
+	// no expected type
+	fnInt(func() int { s.A }) //@complete(" }", rankAA, rankAB, rankAC)
+	fnInt(s.A())              //@complete("()", rankAA, rankAB, rankAC)
+	fnInt([]int{}[s.A])       //@complete("])", rankAA, rankAB, rankAC)
+	fnInt([]int{}[:s.A])      //@complete("])", rankAA, rankAB, rankAC)
+
+	fnInt(s.A.(int)) //@complete(".(", rankAA, rankAB, rankAC)
+
+	fnPtr := func(*string) {}
+	fnPtr(&s.A) //@complete(")", rankAB, rankAA, rankAC)
+
+	var aaPtr *string //@item(rankAAPtr, "aaPtr", "*string", "var")
+	var abPtr *int    //@item(rankABPtr, "abPtr", "*int", "var")
+	fnInt(*a)         //@complete(")", rankABPtr, rankAAPtr)
+
+	_ = func() string {
+		return s.A //@complete(" //", rankAB, rankAA, rankAC)
+	}
+}
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 326c79f..595ebcc 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -28,7 +28,7 @@
 // We hardcode the expected number of test cases to ensure that all tests
 // are being executed. If a test is added, this number must be changed.
 const (
-	ExpectedCompletionsCount       = 85
+	ExpectedCompletionsCount       = 97
 	ExpectedDiagnosticsCount       = 17
 	ExpectedFormatCount            = 5
 	ExpectedDefinitionsCount       = 24