internal/lsp/source: offer completion "if err != nil { return err }"

Now we offer an error-check-and-return completion candidate when
appropriate:

    func myFunc() (int, error) {
      f, err := os.Open("foo")
      <>
    }

offers the candidate:

    if err != nil {
      return 0, <err>
    }

where <> denotes a placeholder so you can easily alter "err".

The completion will only be offered when:
1. The position is in a function that returns an error as final result
   value, and
2. The statement preceding position is an assignment whose final LHS
   value is an error.

The completion will contain zero values for the non-error return values
as necessary.

Using the above example, the completion will be offered after the user
has typed:

    i
    if
    if err

Basically the candidate will be offered after every keystroke as the
user types "if err".

I call this new type of completion a statement completion - perfect
for when you want to make a statement!

Change-Id: I0a330e1c1fa81a2757d3afc84c24e853f46f26b0
Reviewed-on: https://go-review.googlesource.com/c/tools/+/221613
Run-TryBot: Muir Manders <muir@mnd.rs>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go
index 4c977da..7eb0e64 100644
--- a/internal/lsp/source/completion.go
+++ b/internal/lsp/source/completion.go
@@ -548,6 +548,10 @@
 		return c.items, c.getSurrounding(), nil
 	}
 
+	// Statement candidates offer an entire statement in certain
+	// contexts, as opposed to a single object.
+	c.addStatementCandidates()
+
 	switch n := path[0].(type) {
 	case *ast.Ident:
 		// Is this the Sel part of a selector?
diff --git a/internal/lsp/source/completion_statements.go b/internal/lsp/source/completion_statements.go
new file mode 100644
index 0000000..7806168
--- /dev/null
+++ b/internal/lsp/source/completion_statements.go
@@ -0,0 +1,163 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package source
+
+import (
+	"fmt"
+	"go/ast"
+	"go/token"
+	"go/types"
+
+	"golang.org/x/tools/internal/lsp/protocol"
+	"golang.org/x/tools/internal/lsp/snippet"
+)
+
+// addStatementCandidates adds full statement completion candidates
+// appropriate for the current context.
+func (c *completer) addStatementCandidates() {
+	c.addErrCheckAndReturn()
+}
+
+// addErrCheckAndReturn offers a completion candidate of the form:
+//
+//     if err != nil {
+//       return nil, err
+//     }
+//
+// The position must be in a function that returns an error, and the
+// statement preceding the position must be an assignment where the
+// final LHS object is an error. addErrCheckAndReturn will synthesize
+// zero values as necessary to make the return statement valid.
+func (c *completer) addErrCheckAndReturn() {
+	if len(c.path) < 2 || c.enclosingFunc == nil || !c.opts.placeholders {
+		return
+	}
+
+	var (
+		errorType = types.Universe.Lookup("error").Type()
+		result    = c.enclosingFunc.sig.Results()
+	)
+	// Make sure our enclosing function returns an error.
+	if result.Len() == 0 || !types.Identical(result.At(result.Len()-1).Type(), errorType) {
+		return
+	}
+
+	prevLine := prevStmt(c.pos, c.path)
+	if prevLine == nil {
+		return
+	}
+
+	// Make sure our preceding statement was as assignment.
+	assign, _ := prevLine.(*ast.AssignStmt)
+	if assign == nil || len(assign.Lhs) == 0 {
+		return
+	}
+
+	lastAssignee := assign.Lhs[len(assign.Lhs)-1]
+
+	// Make sure the final assignee is an error.
+	if !types.Identical(c.pkg.GetTypesInfo().TypeOf(lastAssignee), errorType) {
+		return
+	}
+
+	var (
+		// errText is e.g. "err" in "foo, err := bar()".
+		errText = formatNode(c.snapshot.View().Session().Cache().FileSet(), lastAssignee)
+
+		// Whether we need to include the "if" keyword in our candidate.
+		needsIf = true
+	)
+
+	// "_" isn't a real object.
+	if errText == "_" {
+		return
+	}
+
+	// Below we try to detect if the user has already started typing "if
+	// err" so we can replace what they've typed with our complete
+	// statement.
+	switch n := c.path[0].(type) {
+	case *ast.Ident:
+		switch c.path[1].(type) {
+		case *ast.ExprStmt:
+			// This handles:
+			//
+			//     f, err := os.Open("foo")
+			//     i<>
+
+			// Make sure they are typing "if".
+			if c.matcher.Score("if") <= 0 {
+				return
+			}
+		case *ast.IfStmt:
+			// This handles:
+			//
+			//     f, err := os.Open("foo")
+			//     if er<>
+
+			// Make sure they are typing the error's name.
+			if c.matcher.Score(errText) <= 0 {
+				return
+			}
+
+			needsIf = false
+		default:
+			return
+		}
+	case *ast.IfStmt:
+		// This handles:
+		//
+		//     f, err := os.Open("foo")
+		//     if <>
+
+		// Avoid false positives by ensuring the if's cond is a bad
+		// expression. For example, don't offer the completion in cases
+		// like "if <> somethingElse".
+		if _, bad := n.Cond.(*ast.BadExpr); !bad {
+			return
+		}
+
+		// If "if" is our direct prefix, we need to include it in our
+		// candidate since the existing "if" will be overwritten.
+		needsIf = c.pos == n.Pos()+token.Pos(len("if"))
+	}
+
+	// Build up a snippet that looks like:
+	//
+	//     if err != nil {
+	//       return <zero value>, ..., ${1:err}
+	//     }
+	//
+	// We make the error a placeholder so it is easy to alter the error.
+	var snip snippet.Builder
+	if needsIf {
+		snip.WriteText("if ")
+	}
+	snip.WriteText(fmt.Sprintf("%s != nil {\n\treturn ", errText))
+
+	for i := 0; i < result.Len()-1; i++ {
+		snip.WriteText(formatZeroValue(result.At(i).Type(), c.qf))
+		snip.WriteText(", ")
+	}
+
+	snip.WritePlaceholder(func(b *snippet.Builder) {
+		b.WriteText(errText)
+	})
+
+	snip.WriteText("\n}")
+
+	label := fmt.Sprintf("%[1]s != nil { return %[1]s }", errText)
+	if needsIf {
+		label = "if " + label
+	}
+
+	c.items = append(c.items, CompletionItem{
+		Label: label,
+		// There doesn't seem to be a more appropriate kind.
+		Kind:    protocol.KeywordCompletion,
+		Score:   highScore,
+		snippet: &snip,
+	})
+}
diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go
index 86b017f..b691f7c 100644
--- a/internal/lsp/source/util.go
+++ b/internal/lsp/source/util.go
@@ -719,3 +719,53 @@
 	}
 	return nil, nil, errors.Errorf("no file for %s in package %s", uri, pkg.ID())
 }
+
+// prevStmt returns the statement that precedes the statement containing pos.
+// For example:
+//
+//     foo := 1
+//     bar(1 + 2<>)
+//
+// If "<>" is pos, prevStmt returns "foo := 1"
+func prevStmt(pos token.Pos, path []ast.Node) ast.Stmt {
+	var blockLines []ast.Stmt
+	for i := 0; i < len(path) && blockLines == nil; i++ {
+		switch n := path[i].(type) {
+		case *ast.BlockStmt:
+			blockLines = n.List
+		case *ast.CommClause:
+			blockLines = n.Body
+		case *ast.CaseClause:
+			blockLines = n.Body
+		}
+	}
+
+	for i := len(blockLines) - 1; i >= 0; i-- {
+		if blockLines[i].End() < pos {
+			return blockLines[i]
+		}
+	}
+
+	return nil
+}
+
+// formatZeroValue produces Go code representing the zero value of T.
+func formatZeroValue(T types.Type, qf types.Qualifier) string {
+	switch u := T.Underlying().(type) {
+	case *types.Basic:
+		switch {
+		case u.Info()&types.IsNumeric > 0:
+			return "0"
+		case u.Info()&types.IsString > 0:
+			return `""`
+		case u.Info()&types.IsBoolean > 0:
+			return "false"
+		default:
+			panic(fmt.Sprintf("unhandled basic type: %v", u))
+		}
+	case *types.Pointer, *types.Interface, *types.Chan, *types.Map, *types.Slice, *types.Signature:
+		return "nil"
+	default:
+		return types.TypeString(T, qf) + "{}"
+	}
+}
diff --git a/internal/lsp/testdata/lsp/primarymod/statements/if_err_check_return.go b/internal/lsp/testdata/lsp/primarymod/statements/if_err_check_return.go
new file mode 100644
index 0000000..e82b783
--- /dev/null
+++ b/internal/lsp/testdata/lsp/primarymod/statements/if_err_check_return.go
@@ -0,0 +1,27 @@
+package statements
+
+import (
+	"bytes"
+	"io"
+	"os"
+)
+
+func one() (int, float32, io.Writer, *int, []int, bytes.Buffer, error) {
+	/* if err != nil { return err } */ //@item(stmtOneIfErrReturn, "if err != nil { return err }", "", "")
+	/* err != nil { return err } */ //@item(stmtOneErrReturn, "err != nil { return err }", "", "")
+
+	_, err := os.Open("foo")
+	//@snippet("", stmtOneIfErrReturn, "", "if err != nil {\n\treturn 0, 0, nil, nil, nil, bytes.Buffer{\\}, ${1:err}\n\\}")
+
+	_, err = os.Open("foo")
+	i //@snippet(" //", stmtOneIfErrReturn, "", "if err != nil {\n\treturn 0, 0, nil, nil, nil, bytes.Buffer{\\}, ${1:err}\n\\}")
+
+	_, err = os.Open("foo")
+	if er //@snippet(" //", stmtOneErrReturn, "", "err != nil {\n\treturn 0, 0, nil, nil, nil, bytes.Buffer{\\}, ${1:err}\n\\}")
+
+	_, err = os.Open("foo")
+	if //@snippet(" //", stmtOneIfErrReturn, "", "if err != nil {\n\treturn 0, 0, nil, nil, nil, bytes.Buffer{\\}, ${1:err}\n\\}")
+
+	_, err = os.Open("foo")
+	if //@snippet("//", stmtOneIfErrReturn, "", "if err != nil {\n\treturn 0, 0, nil, nil, nil, bytes.Buffer{\\}, ${1:err}\n\\}")
+}
diff --git a/internal/lsp/testdata/lsp/primarymod/statements/if_err_check_return_2.go b/internal/lsp/testdata/lsp/primarymod/statements/if_err_check_return_2.go
new file mode 100644
index 0000000..e2dce80
--- /dev/null
+++ b/internal/lsp/testdata/lsp/primarymod/statements/if_err_check_return_2.go
@@ -0,0 +1,12 @@
+package statements
+
+import "os"
+
+func two() error {
+	var s struct{ err error }
+
+	/* if s.err != nil { return s.err } */ //@item(stmtTwoIfErrReturn, "if s.err != nil { return s.err }", "", "")
+
+	_, s.err = os.Open("foo")
+	//@snippet("", stmtTwoIfErrReturn, "", "if s.err != nil {\n\treturn ${1:s.err}\n\\}")
+}
diff --git a/internal/lsp/testdata/lsp/summary.txt.golden b/internal/lsp/testdata/lsp/summary.txt.golden
index e56b30f..114f94c 100644
--- a/internal/lsp/testdata/lsp/summary.txt.golden
+++ b/internal/lsp/testdata/lsp/summary.txt.golden
@@ -1,7 +1,7 @@
 -- summary --
 CodeLensCount = 0
 CompletionsCount = 231
-CompletionSnippetCount = 68
+CompletionSnippetCount = 74
 UnimportedCompletionsCount = 11
 DeepCompletionsCount = 5
 FuzzyCompletionsCount = 8