internal/lsp: improve expected type determination
Improve expected type determination for the following cases:
- search back further through ast path to handle cases where the
position's node is more than two nodes from the ancestor node with
type information
- generate expected type for return statements
- wrap and unwrap pointerness from expected type when position is
preceded by "*" (dereference) or "&" (reference) operators,
respectively
- fix some false positive expected types when completing the "Fun"
(left) side of a CallExpr
Change-Id: I907ee3e405bd8420031a7b03329de5df1c3493b9
GitHub-Last-Rev: 20a0ac9bf2b5350494c6738f5960676cc50fb454
GitHub-Pull-Request: golang/tools#93
Reviewed-on: https://go-review.googlesource.com/c/tools/+/174477
Run-TryBot: Rebecca Stambler <rstambler@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index 3761542..37a1ca2 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -196,7 +196,6 @@
path: path,
pos: pos,
seen: make(map[types.Object]bool),
- expectedType: expectedType(path, pos, pkg.GetTypesInfo()),
enclosingFunction: enclosingFunction(path, pos, pkg.GetTypesInfo()),
preferTypeNames: preferTypeNames(path, pos),
enclosingCompositeLiteral: lit,
@@ -204,6 +203,8 @@
inCompositeLiteralField: inCompositeLiteralField,
}
+ c.expectedType = expectedType(c)
+
// Composite literals are handled entirely separately.
if c.enclosingCompositeLiteral != nil {
c.expectedType = c.expectedCompositeLiteralType(c.enclosingCompositeLiteral, c.enclosingKeyValue)
@@ -458,14 +459,10 @@
// don't show composite literal completions.
ok = pos <= kv.Colon
}
- case *ast.FuncType, *ast.CallExpr, *ast.TypeAssertExpr:
- // These node types break the type link between the leaf node and
- // the composite literal. The type of the leaf node becomes unrelated
- // to the type of the composite literal, so we return nil to avoid
- // inappropriate completions. For example, "Foo{Bar: x.Baz(<>)}"
- // should complete as a function argument to Baz, not part of the Foo
- // composite literal.
- return nil, nil, false
+ default:
+ if breaksExpectedTypeInference(n) {
+ return nil, nil, false
+ }
}
}
return lit, kv, ok
@@ -538,50 +535,111 @@
}
// expectedType returns the expected type for an expression at the query position.
-func expectedType(path []ast.Node, pos token.Pos, info *types.Info) types.Type {
- for i, node := range path {
- if i == 2 {
- break
- }
+func expectedType(c *completer) types.Type {
+ var (
+ derefCount int // count of deref "*" operators
+ refCount int // count of reference "&" operators
+ typ types.Type
+ )
+
+Nodes:
+ for _, node := range c.path {
switch expr := node.(type) {
case *ast.BinaryExpr:
// Determine if query position comes from left or right of op.
e := expr.X
- if pos < expr.OpPos {
+ if c.pos < expr.OpPos {
e = expr.Y
}
- if tv, ok := info.Types[e]; ok {
- return tv.Type
+ if tv, ok := c.info.Types[e]; ok {
+ typ = tv.Type
+ break Nodes
}
case *ast.AssignStmt:
// Only rank completions if you are on the right side of the token.
- if pos <= expr.TokPos {
- break
- }
- i := indexExprAtPos(pos, expr.Rhs)
- if i >= len(expr.Lhs) {
- i = len(expr.Lhs) - 1
- }
- if tv, ok := info.Types[expr.Lhs[i]]; ok {
- return tv.Type
- }
- case *ast.CallExpr:
- if tv, ok := info.Types[expr.Fun]; ok {
- if sig, ok := tv.Type.(*types.Signature); ok {
- if sig.Params().Len() == 0 {
- return nil
- }
- i := indexExprAtPos(pos, expr.Args)
- // Make sure not to run past the end of expected parameters.
- if i >= sig.Params().Len() {
- i = sig.Params().Len() - 1
- }
- return sig.Params().At(i).Type()
+ if c.pos > expr.TokPos {
+ i := indexExprAtPos(c.pos, expr.Rhs)
+ if i >= len(expr.Lhs) {
+ i = len(expr.Lhs) - 1
}
+ if tv, ok := c.info.Types[expr.Lhs[i]]; ok {
+ typ = tv.Type
+ break Nodes
+ }
+ }
+ return nil
+ case *ast.CallExpr:
+ // Only consider CallExpr args if position falls between parens.
+ if expr.Lparen <= c.pos && c.pos <= expr.Rparen {
+ if tv, ok := c.info.Types[expr.Fun]; ok {
+ if sig, ok := tv.Type.(*types.Signature); ok {
+ if sig.Params().Len() == 0 {
+ return nil
+ }
+ i := indexExprAtPos(c.pos, expr.Args)
+ // Make sure not to run past the end of expected parameters.
+ if i >= sig.Params().Len() {
+ i = sig.Params().Len() - 1
+ }
+ typ = sig.Params().At(i).Type()
+ break Nodes
+ }
+ }
+ }
+ return nil
+ case *ast.ReturnStmt:
+ if sig := c.enclosingFunction; sig != nil {
+ // Find signature result that corresponds to our return expression.
+ if resultIdx := indexExprAtPos(c.pos, expr.Results); resultIdx < len(expr.Results) {
+ if resultIdx < sig.Results().Len() {
+ typ = sig.Results().At(resultIdx).Type()
+ break Nodes
+ }
+ }
+ }
+
+ return nil
+ case *ast.StarExpr:
+ derefCount++
+ case *ast.UnaryExpr:
+ if expr.Op == token.AND {
+ refCount++
+ }
+ default:
+ if breaksExpectedTypeInference(node) {
+ return nil
}
}
}
- return nil
+
+ if typ != nil {
+ // For every "*" deref operator, add another pointer layer to expected type.
+ for i := 0; i < derefCount; i++ {
+ typ = types.NewPointer(typ)
+ }
+ // For every "&" ref operator, remove a pointer layer from expected type.
+ for i := 0; i < refCount; i++ {
+ if ptr, ok := typ.(*types.Pointer); ok {
+ typ = ptr.Elem()
+ } else {
+ break
+ }
+ }
+ }
+
+ return typ
+}
+
+// breaksExpectedTypeInference reports if an expression node's type is unrelated
+// to its child expression node types. For example, "Foo{Bar: x.Baz(<>)}" should
+// expect a function argument, not a composite literal value.
+func breaksExpectedTypeInference(n ast.Node) bool {
+ switch n.(type) {
+ case *ast.FuncLit, *ast.CallExpr, *ast.TypeAssertExpr, *ast.IndexExpr, *ast.SliceExpr, *ast.CompositeLit:
+ return true
+ default:
+ return false
+ }
}
// preferTypeNames checks if given token position is inside func receiver,
diff --git a/internal/lsp/testdata/complit/complit.go.in b/internal/lsp/testdata/complit/complit.go.in
index fec5aad..007cb65 100644
--- a/internal/lsp/testdata/complit/complit.go.in
+++ b/internal/lsp/testdata/complit/complit.go.in
@@ -44,6 +44,7 @@
}
_ = map[int]string{1: "" + s.A} //@complete("}", fieldAB, fieldAA)
_ = map[int]string{1: (func(i int) string { return "" })(s.A)} //@complete(")}", fieldAA, fieldAB)
+ _ = map[int]string{1: func() string { s.A }} //@complete(" }", fieldAA, fieldAB)
}
func _() {
diff --git a/internal/lsp/testdata/func_rank/func_rank.go.in b/internal/lsp/testdata/func_rank/func_rank.go.in
index cb5a1b4..d950d3e 100644
--- a/internal/lsp/testdata/func_rank/func_rank.go.in
+++ b/internal/lsp/testdata/func_rank/func_rank.go.in
@@ -1,8 +1,8 @@
package func_rank
-var stringAVar = "var" //@item(stringAVar, "stringAVar", "string", "var")
+var stringAVar = "var" //@item(stringAVar, "stringAVar", "string", "var")
func stringBFunc() string { return "str" } //@item(stringBFunc, "stringBFunc()", "string", "func")
-type stringer struct{} //@item(stringer, "stringer", "struct{...}", "struct")
+type stringer struct{} //@item(stringer, "stringer", "struct{...}", "struct")
func _() stringer //@complete("tr", stringer, stringAVar, stringBFunc)
@@ -10,3 +10,35 @@
func (stringer) _() {} //@complete("tr", stringer, stringAVar, stringBFunc)
+func _() {
+ var s struct {
+ AA int //@item(rankAA, "AA", "int", "field")
+ AB string //@item(rankAB, "AB", "string", "field")
+ AC int //@item(rankAC, "AC", "int", "field")
+ }
+ fnStr := func(string) {}
+ fnStr(s.A) //@complete(")", rankAB, rankAA, rankAC)
+ fnStr("" + s.A) //@complete(")", rankAB, rankAA, rankAC)
+
+ fnInt := func(int) {}
+ fnInt(-s.A) //@complete(")", rankAA, rankAC, rankAB)
+
+ // no expected type
+ fnInt(func() int { s.A }) //@complete(" }", rankAA, rankAB, rankAC)
+ fnInt(s.A()) //@complete("()", rankAA, rankAB, rankAC)
+ fnInt([]int{}[s.A]) //@complete("])", rankAA, rankAB, rankAC)
+ fnInt([]int{}[:s.A]) //@complete("])", rankAA, rankAB, rankAC)
+
+ fnInt(s.A.(int)) //@complete(".(", rankAA, rankAB, rankAC)
+
+ fnPtr := func(*string) {}
+ fnPtr(&s.A) //@complete(")", rankAB, rankAA, rankAC)
+
+ var aaPtr *string //@item(rankAAPtr, "aaPtr", "*string", "var")
+ var abPtr *int //@item(rankABPtr, "abPtr", "*int", "var")
+ fnInt(*a) //@complete(")", rankABPtr, rankAAPtr)
+
+ _ = func() string {
+ return s.A //@complete(" //", rankAB, rankAA, rankAC)
+ }
+}
diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go
index 326c79f..595ebcc 100644
--- a/internal/lsp/tests/tests.go
+++ b/internal/lsp/tests/tests.go
@@ -28,7 +28,7 @@
// We hardcode the expected number of test cases to ensure that all tests
// are being executed. If a test is added, this number must be changed.
const (
- ExpectedCompletionsCount = 85
+ ExpectedCompletionsCount = 97
ExpectedDiagnosticsCount = 17
ExpectedFormatCount = 5
ExpectedDefinitionsCount = 24