internal/lsp: handle more expected type cases

Calculate expected type in the following cases:

- switch case statements
- index expressions (e.g. []int{}[<>] or map[string]int{}[<>])
- slice expressions (e.g. []int{}[1:<>])
- channel send statements
- channel receive expression

We now also prefer type names in type switch clauses and type asserts.

Change-Id: Iff8c317a9116868b36701d931c802d9147f962d8
GitHub-Last-Rev: e039a45aebe1c6aa9b2011cad67ddaa5e4ed4d77
GitHub-Pull-Request: golang/tools#97
Reviewed-on: https://go-review.googlesource.com/c/tools/+/176941
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index d654397..8b193b7 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -619,6 +619,15 @@
 	return nil
 }
 
+// typeModifier represents an operator that changes the expected type.
+type typeModifier int
+
+const (
+	dereference typeModifier = iota // dereference ("*") operator
+	reference                       // reference ("&") operator
+	chanRead                        // channel read ("<-") operator
+)
+
 // expectedType returns the expected type for an expression at the query position.
 func expectedType(c *completer) types.Type {
 	if c.enclosingCompositeLiteral != nil {
@@ -626,19 +635,18 @@
 	}
 
 	var (
-		derefCount int // count of deref "*" operators
-		refCount   int // count of reference "&" operators
-		typ        types.Type
+		modifiers []typeModifier
+		typ       types.Type
 	)
 
 Nodes:
-	for _, node := range c.path {
-		switch expr := node.(type) {
+	for i, node := range c.path {
+		switch node := node.(type) {
 		case *ast.BinaryExpr:
 			// Determine if query position comes from left or right of op.
-			e := expr.X
-			if c.pos < expr.OpPos {
-				e = expr.Y
+			e := node.X
+			if c.pos < node.OpPos {
+				e = node.Y
 			}
 			if tv, ok := c.info.Types[e]; ok {
 				typ = tv.Type
@@ -646,12 +654,12 @@
 			}
 		case *ast.AssignStmt:
 			// Only rank completions if you are on the right side of the token.
-			if c.pos > expr.TokPos {
-				i := indexExprAtPos(c.pos, expr.Rhs)
-				if i >= len(expr.Lhs) {
-					i = len(expr.Lhs) - 1
+			if c.pos > node.TokPos {
+				i := indexExprAtPos(c.pos, node.Rhs)
+				if i >= len(node.Lhs) {
+					i = len(node.Lhs) - 1
 				}
-				if tv, ok := c.info.Types[expr.Lhs[i]]; ok {
+				if tv, ok := c.info.Types[node.Lhs[i]]; ok {
 					typ = tv.Type
 					break Nodes
 				}
@@ -659,13 +667,13 @@
 			return nil
 		case *ast.CallExpr:
 			// Only consider CallExpr args if position falls between parens.
-			if expr.Lparen <= c.pos && c.pos <= expr.Rparen {
-				if tv, ok := c.info.Types[expr.Fun]; ok {
+			if node.Lparen <= c.pos && c.pos <= node.Rparen {
+				if tv, ok := c.info.Types[node.Fun]; ok {
 					if sig, ok := tv.Type.(*types.Signature); ok {
 						if sig.Params().Len() == 0 {
 							return nil
 						}
-						i := indexExprAtPos(c.pos, expr.Args)
+						i := indexExprAtPos(c.pos, node.Args)
 						// Make sure not to run past the end of expected parameters.
 						if i >= sig.Params().Len() {
 							i = sig.Params().Len() - 1
@@ -678,21 +686,65 @@
 			return nil
 		case *ast.ReturnStmt:
 			if sig := c.enclosingFunction; sig != nil {
-				// Find signature result that corresponds to our return expression.
-				if resultIdx := indexExprAtPos(c.pos, expr.Results); resultIdx < len(expr.Results) {
+				// Find signature result that corresponds to our return statement.
+				if resultIdx := indexExprAtPos(c.pos, node.Results); resultIdx < len(node.Results) {
 					if resultIdx < sig.Results().Len() {
 						typ = sig.Results().At(resultIdx).Type()
 						break Nodes
 					}
 				}
 			}
-
+			return nil
+		case *ast.CaseClause:
+			if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, node).(*ast.SwitchStmt); ok {
+				if tv, ok := c.info.Types[swtch.Tag]; ok {
+					typ = tv.Type
+					break Nodes
+				}
+			}
+			return nil
+		case *ast.SliceExpr:
+			// Make sure position falls within the brackets (e.g. "foo[a:<>]").
+			if node.Lbrack < c.pos && c.pos <= node.Rbrack {
+				typ = types.Typ[types.Int]
+				break Nodes
+			}
+			return nil
+		case *ast.IndexExpr:
+			// Make sure position falls within the brackets (e.g. "foo[<>]").
+			if node.Lbrack < c.pos && c.pos <= node.Rbrack {
+				if tv, ok := c.info.Types[node.X]; ok {
+					switch t := tv.Type.Underlying().(type) {
+					case *types.Map:
+						typ = t.Key()
+					case *types.Slice, *types.Array:
+						typ = types.Typ[types.Int]
+					default:
+						return nil
+					}
+					break Nodes
+				}
+			}
+			return nil
+		case *ast.SendStmt:
+			// Make sure we are on right side of arrow (e.g. "foo <- <>").
+			if c.pos > node.Arrow+1 {
+				if tv, ok := c.info.Types[node.Chan]; ok {
+					if ch, ok := tv.Type.Underlying().(*types.Chan); ok {
+						typ = ch.Elem()
+						break Nodes
+					}
+				}
+			}
 			return nil
 		case *ast.StarExpr:
-			derefCount++
+			modifiers = append(modifiers, dereference)
 		case *ast.UnaryExpr:
-			if expr.Op == token.AND {
-				refCount++
+			switch node.Op {
+			case token.AND:
+				modifiers = append(modifiers, reference)
+			case token.ARROW:
+				modifiers = append(modifiers, chanRead)
 			}
 		default:
 			if breaksExpectedTypeInference(node) {
@@ -702,16 +754,17 @@
 	}
 
 	if typ != nil {
-		// For every "*" deref operator, add another pointer layer to expected type.
-		for i := 0; i < derefCount; i++ {
-			typ = types.NewPointer(typ)
-		}
-		// For every "&" ref operator, remove a pointer layer from expected type.
-		for i := 0; i < refCount; i++ {
-			if ptr, ok := typ.(*types.Pointer); ok {
-				typ = ptr.Elem()
-			} else {
-				break
+		for _, mod := range modifiers {
+			switch mod {
+			case dereference:
+				// For every "*" deref operator, add another pointer layer to expected type.
+				typ = types.NewPointer(typ)
+			case reference:
+				// For every "&" ref operator, remove a pointer layer from expected type.
+				typ = deref(typ)
+			case chanRead:
+				// For every "<-" operator, add another layer of channelness.
+				typ = types.NewChan(types.SendRecv, typ)
 			}
 		}
 	}
@@ -719,6 +772,30 @@
 	return typ
 }
 
+// findSwitchStmt returns an *ast.CaseClause's corresponding *ast.SwitchStmt or
+// *ast.TypeSwitchStmt. path should start from the case clause's first ancestor.
+func findSwitchStmt(path []ast.Node, pos token.Pos, c *ast.CaseClause) ast.Stmt {
+	// Make sure position falls within a "case <>:" clause.
+	if exprAtPos(pos, c.List) == nil {
+		return nil
+	}
+	// A case clause is always nested within a block statement in a switch statement.
+	if len(path) < 2 {
+		return nil
+	}
+	if _, ok := path[0].(*ast.BlockStmt); !ok {
+		return nil
+	}
+	switch s := path[1].(type) {
+	case *ast.SwitchStmt:
+		return s
+	case *ast.TypeSwitchStmt:
+		return s
+	default:
+		return nil
+	}
+}
+
 // breaksExpectedTypeInference reports if an expression node's type is unrelated
 // to its child expression node types. For example, "Foo{Bar: x.Baz(<>)}" should
 // expect a function argument, not a composite literal value.
@@ -737,7 +814,7 @@
 // func (<>) foo(<>) (<>) {}
 //
 func preferTypeNames(path []ast.Node, pos token.Pos) bool {
-	for _, p := range path {
+	for i, p := range path {
 		switch n := p.(type) {
 		case *ast.FuncDecl:
 			if r := n.Recv; r != nil && r.Pos() <= pos && pos <= r.End() {
@@ -752,6 +829,13 @@
 				}
 			}
 			return false
+		case *ast.CaseClause:
+			_, isTypeSwitch := findSwitchStmt(path[i+1:], pos, n).(*ast.TypeSwitchStmt)
+			return isTypeSwitch
+		case *ast.TypeAssertExpr:
+			if n.Lparen < pos && pos <= n.Rparen {
+				return true
+			}
 		}
 	}
 	return false
diff --git a/internal/lsp/testdata/channel/channel.go b/internal/lsp/testdata/channel/channel.go
new file mode 100644
index 0000000..a83b895
--- /dev/null
+++ b/internal/lsp/testdata/channel/channel.go
@@ -0,0 +1,25 @@
+package channel
+
+func _() {
+	var (
+		aa = "123" //@item(channelAA, "aa", "string", "var")
+		ab = 123   //@item(channelAB, "ab", "int", "var")
+	)
+
+	{
+		type myChan chan int
+		var mc myChan
+		mc <- a //@complete(" //", channelAB, channelAA)
+	}
+
+	{
+		var ac chan int //@item(channelAC, "ac", "chan int", "var")
+		a <- a //@complete(" <-", channelAC, channelAA, channelAB)
+	}
+
+	{
+		var foo chan int //@item(channelFoo, "foo", "chan int", "var")
+		wantsInt := func(int) {} //@item(channelWantsInt, "wantsInt", "func(int)", "var")
+		wantsInt(<-) //@complete(")", channelFoo, channelWantsInt, channelAA, channelAB)
+	}
+}
diff --git a/internal/lsp/testdata/func_rank/func_rank.go.in b/internal/lsp/testdata/func_rank/func_rank.go.in
index d950d3e..ca98324 100644
--- a/internal/lsp/testdata/func_rank/func_rank.go.in
+++ b/internal/lsp/testdata/func_rank/func_rank.go.in
@@ -26,8 +26,8 @@
 	// no expected type
 	fnInt(func() int { s.A }) //@complete(" }", rankAA, rankAB, rankAC)
 	fnInt(s.A())              //@complete("()", rankAA, rankAB, rankAC)
-	fnInt([]int{}[s.A])       //@complete("])", rankAA, rankAB, rankAC)
-	fnInt([]int{}[:s.A])      //@complete("])", rankAA, rankAB, rankAC)
+	fnInt([]int{}[s.A])       //@complete("])", rankAA, rankAC, rankAB)
+	fnInt([]int{}[:s.A])      //@complete("])", rankAA, rankAC, rankAB)
 
 	fnInt(s.A.(int)) //@complete(".(", rankAA, rankAB, rankAC)
 
diff --git a/internal/lsp/testdata/index/index.go b/internal/lsp/testdata/index/index.go
new file mode 100644
index 0000000..7e56b51
--- /dev/null
+++ b/internal/lsp/testdata/index/index.go
@@ -0,0 +1,21 @@
+package index
+
+func _() {
+	var (
+		aa = "123" //@item(indexAA, "aa", "string", "var")
+		ab = 123   //@item(indexAB, "ab", "int", "var")
+	)
+
+	var foo [1]int
+	foo[a]  //@complete("]", indexAB, indexAA)
+	foo[:a] //@complete("]", indexAB, indexAA)
+	a[:a]   //@complete("[", indexAA, indexAB)
+	a[a]    //@complete("[", indexAA, indexAB)
+
+	var bar map[string]int
+	bar[a] //@complete("]", indexAA, indexAB)
+
+	type myMap map[string]int
+	var baz myMap
+	baz[a] //@complete("]", indexAA, indexAB)
+}
diff --git a/internal/lsp/testdata/rank/switch_rank.go.in b/internal/lsp/testdata/rank/switch_rank.go.in
new file mode 100644
index 0000000..9e23f6b
--- /dev/null
+++ b/internal/lsp/testdata/rank/switch_rank.go.in
@@ -0,0 +1,12 @@
+package rank
+
+func _() {
+	switch pear {
+	case : //@complete(":", pear, apple)
+	}
+
+	switch pear {
+	case "hi":
+		//@complete("", apple, pear)
+	}
+}
diff --git a/internal/lsp/testdata/rank/type_assert_rank.go.in b/internal/lsp/testdata/rank/type_assert_rank.go.in
new file mode 100644
index 0000000..3490c85
--- /dev/null
+++ b/internal/lsp/testdata/rank/type_assert_rank.go.in
@@ -0,0 +1,8 @@
+package rank
+
+func _() {
+	type flower int //@item(flower, "flower", "int", "type")
+	var fig string  //@item(fig, "fig", "string", "var")
+
+	_ = interface{}(nil).(f) //@complete(") //", flower, fig)
+}
diff --git a/internal/lsp/testdata/rank/type_switch_rank.go.in b/internal/lsp/testdata/rank/type_switch_rank.go.in
new file mode 100644
index 0000000..457c64b
--- /dev/null
+++ b/internal/lsp/testdata/rank/type_switch_rank.go.in
@@ -0,0 +1,11 @@
+package rank
+
+func _() {
+	type basket int   //@item(basket, "basket", "int", "type")
+	var banana string //@item(banana, "banana", "string", "var")
+
+	switch interface{}(pear).(type) {
+	case b: //@complete(":", basket, banana)
+		b //@complete(" //", banana, basket)
+	}
+}
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 7b06478..d374d70 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -28,7 +28,7 @@
 // We hardcode the expected number of test cases to ensure that all tests
 // are being executed. If a test is added, this number must be changed.
 const (
-	ExpectedCompletionsCount       = 107
+	ExpectedCompletionsCount       = 121
 	ExpectedCompletionSnippetCount = 13
 	ExpectedDiagnosticsCount       = 17
 	ExpectedFormatCount            = 5