internal/lsp: add goreturns like functionality as quickfix

This change ports the functionality of https://github.com/sqs/goreturns
to be used as code actions on diagnostics that have missing
return values. It improves on the original goreturns functionality by:

- filling out empty return statements
- trying to match existing return values to the required return
  values and then filling in missing parameters

Fixes golang/go#37091

Change-Id: Ifaf9bf571c3bc3c61e672b0a2f725d8d734d432d
Reviewed-on: https://go-review.googlesource.com/c/tools/+/224960
Run-TryBot: Rohan Challa <rohan@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md
index 47a65cf..90cf758 100644
--- a/gopls/doc/analyzers.md
+++ b/gopls/doc/analyzers.md
@@ -321,6 +321,26 @@
 
 Default value: `true`.
 
+### **fillreturns**
+
+suggested fixes for "wrong number of return values (want %d, got %d)"
+
+This checker provides suggested fixes for type errors of the
+type "wrong number of return values (want %d, got %d)". For example:
+```go
+func m() (int, string, *bool, error) {
+  return
+}
+```
+will turn into
+```go
+func m() (int, string, *bool, error) {
+  return 0, "", nil, nil
+}
+```
+
+This functionality is similar to [goreturns](https://github.com/sqs/goreturns).
+
 ### **nonewvars**
 
 suggested fixes for "no new vars on left side of :="
diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go
index 39f4bb6..2658681 100644
--- a/internal/analysisinternal/analysis.go
+++ b/internal/analysisinternal/analysis.go
@@ -7,8 +7,13 @@
 
 import (
 	"bytes"
+	"fmt"
+	"go/ast"
 	"go/token"
 	"go/types"
+	"strings"
+
+	"golang.org/x/tools/go/ast/astutil"
 )
 
 func TypeErrorEndPos(fset *token.FileSet, src []byte, start token.Pos) token.Pos {
@@ -23,6 +28,84 @@
 	return end
 }
 
+func ZeroValue(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Type) ast.Expr {
+	under := typ
+	if n, ok := typ.(*types.Named); ok {
+		under = n.Underlying()
+	}
+	switch u := under.(type) {
+	case *types.Basic:
+		switch {
+		case u.Info()&types.IsNumeric != 0:
+			return &ast.BasicLit{Kind: token.INT, Value: "0"}
+		case u.Info()&types.IsBoolean != 0:
+			return &ast.Ident{Name: "false"}
+		case u.Info()&types.IsString != 0:
+			return &ast.BasicLit{Kind: token.STRING, Value: `""`}
+		default:
+			panic("unknown basic type")
+		}
+	case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice:
+		return ast.NewIdent("nil")
+	case *types.Struct:
+		texpr := typeExpr(fset, f, pkg, typ) // typ because we want the name here.
+		if texpr == nil {
+			return nil
+		}
+		return &ast.CompositeLit{
+			Type: texpr,
+		}
+	case *types.Array:
+		texpr := typeExpr(fset, f, pkg, u.Elem())
+		if texpr == nil {
+			return nil
+		}
+		return &ast.CompositeLit{
+			Type: &ast.ArrayType{
+				Elt: texpr,
+				Len: &ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%v", u.Len())},
+			},
+		}
+	}
+	return nil
+}
+
+func typeExpr(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Type) ast.Expr {
+	switch t := typ.(type) {
+	case *types.Basic:
+		switch t.Kind() {
+		case types.UnsafePointer:
+			return &ast.SelectorExpr{X: ast.NewIdent("unsafe"), Sel: ast.NewIdent("Pointer")}
+		default:
+			return ast.NewIdent(t.Name())
+		}
+	case *types.Named:
+		if t.Obj().Pkg() == pkg {
+			return ast.NewIdent(t.Obj().Name())
+		}
+		pkgName := t.Obj().Pkg().Name()
+		// If the file already imports the package under another name, use that.
+		for _, group := range astutil.Imports(fset, f) {
+			for _, cand := range group {
+				if strings.Trim(cand.Path.Value, `"`) == t.Obj().Pkg().Path() {
+					if cand.Name != nil && cand.Name.Name != "" {
+						pkgName = cand.Name.Name
+					}
+				}
+			}
+		}
+		if pkgName == "." {
+			return ast.NewIdent(t.Obj().Name())
+		}
+		return &ast.SelectorExpr{
+			X:   ast.NewIdent(pkgName),
+			Sel: ast.NewIdent(t.Obj().Name()),
+		}
+	default:
+		return nil // TODO: anonymous structs, but who does that
+	}
+}
+
 var GetTypeErrors = func(p interface{}) []types.Error { return nil }
 var SetTypeErrors = func(p interface{}, errors []types.Error) {}
 
diff --git a/internal/lsp/analysis/fillreturns/fillreturns.go b/internal/lsp/analysis/fillreturns/fillreturns.go
new file mode 100644
index 0000000..a75f645
--- /dev/null
+++ b/internal/lsp/analysis/fillreturns/fillreturns.go
@@ -0,0 +1,205 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fillreturns defines an Analyzer that will attempt to
+// automatically fill in a return statement that has missing
+// values with zero value elements.
+package fillreturns
+
+import (
+	"bytes"
+	"go/ast"
+	"go/format"
+	"go/types"
+	"regexp"
+	"strconv"
+	"strings"
+
+	"golang.org/x/tools/go/analysis"
+	"golang.org/x/tools/go/ast/astutil"
+	"golang.org/x/tools/internal/analysisinternal"
+)
+
+const Doc = `suggested fixes for "wrong number of return values (want %d, got %d)"
+
+This checker provides suggested fixes for type errors of the
+type "wrong number of return values (want %d, got %d)". For example:
+	func m() (int, string, *bool, error) {
+		return
+	}
+will turn into
+	func m() (int, string, *bool, error) {
+		return 0, "", nil, nil
+	}
+
+This functionality is similar to https://github.com/sqs/goreturns.
+`
+
+var Analyzer = &analysis.Analyzer{
+	Name:             "fillreturns",
+	Doc:              Doc,
+	Requires:         []*analysis.Analyzer{},
+	Run:              run,
+	RunDespiteErrors: true,
+}
+
+var wrongReturnNumRegex = regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`)
+
+func run(pass *analysis.Pass) (interface{}, error) {
+	errors := analysisinternal.GetTypeErrors(pass)
+	// Filter out the errors that are not relevant to this analyzer.
+	for _, typeErr := range errors {
+		matches := wrongReturnNumRegex.FindStringSubmatch(strings.TrimSpace(typeErr.Msg))
+		if len(matches) < 3 {
+			continue
+		}
+		wantNum, err := strconv.Atoi(matches[1])
+		if err != nil {
+			continue
+		}
+		gotNum, err := strconv.Atoi(matches[2])
+		if err != nil {
+			continue
+		}
+		// Logic for handling more return values than expected is hard.
+		if wantNum < gotNum {
+			continue
+		}
+
+		var file *ast.File
+		for _, f := range pass.Files {
+			if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() {
+				file = f
+				break
+			}
+		}
+		if file == nil {
+			continue
+		}
+
+		// Get the end position of the error.
+		var buf bytes.Buffer
+		if err := format.Node(&buf, pass.Fset, file); err != nil {
+			continue
+		}
+		typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos)
+
+		// Get the path for the relevant range.
+		path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos)
+		if len(path) == 0 {
+			return nil, nil
+		}
+		// Check to make sure the node of interest is a ReturnStmt.
+		ret, ok := path[0].(*ast.ReturnStmt)
+		if !ok {
+			return nil, nil
+		}
+
+		// Get the function that encloses the ReturnStmt.
+		var enclosingFunc *ast.FuncType
+	Outer:
+		for _, n := range path {
+			switch node := n.(type) {
+			case *ast.FuncLit:
+				enclosingFunc = node.Type
+				break Outer
+			case *ast.FuncDecl:
+				enclosingFunc = node.Type
+				break Outer
+			}
+		}
+		if enclosingFunc == nil {
+			continue
+		}
+		numRetValues := len(ret.Results)
+		typeInfo := pass.TypesInfo
+
+		// skip if return value has a func call (whose multiple returns might be expanded)
+		for _, expr := range ret.Results {
+			e, ok := expr.(*ast.CallExpr)
+			if !ok {
+				continue
+			}
+			ident, ok := e.Fun.(*ast.Ident)
+			if !ok || ident.Obj == nil {
+				continue
+			}
+			fn, ok := ident.Obj.Decl.(*ast.FuncDecl)
+			if !ok {
+				continue
+			}
+			if len(fn.Type.Results.List) != 1 {
+				continue
+			}
+			if typeInfo == nil {
+				continue
+			}
+			if _, ok := typeInfo.TypeOf(e).(*types.Tuple); ok {
+				continue
+			}
+		}
+
+		// Fill in the missing arguments with zero-values.
+		returnCount := 0
+		zvs := make([]ast.Expr, len(enclosingFunc.Results.List))
+		for i, result := range enclosingFunc.Results.List {
+			zv := analysisinternal.ZeroValue(pass.Fset, file, pass.Pkg, typeInfo.TypeOf(result.Type))
+			if zv == nil {
+				return nil, nil
+			}
+			// We do not have any existing return values, fill in with zero-values.
+			if returnCount >= numRetValues {
+				zvs[i] = zv
+				continue
+			}
+			// Compare the types to see if they are the same.
+			current := ret.Results[returnCount]
+			if equalTypes(typeInfo.TypeOf(current), typeInfo.TypeOf(result.Type)) {
+				zvs[i] = current
+				returnCount += 1
+				continue
+			}
+			zvs[i] = zv
+		}
+		newRet := &ast.ReturnStmt{
+			Return:  ret.Pos(),
+			Results: zvs,
+		}
+
+		// Convert the new return statement ast to text.
+		var newBuf bytes.Buffer
+		if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
+			return nil, err
+		}
+
+		pass.Report(analysis.Diagnostic{
+			Pos:     typeErr.Pos,
+			End:     typeErrEndPos,
+			Message: typeErr.Msg,
+			SuggestedFixes: []analysis.SuggestedFix{{
+				Message: "Fill with empty values",
+				TextEdits: []analysis.TextEdit{{
+					Pos:     ret.Pos(),
+					End:     ret.End(),
+					NewText: newBuf.Bytes(),
+				}},
+			}},
+		})
+	}
+	return nil, nil
+}
+
+func equalTypes(t1, t2 types.Type) bool {
+	if t1 == t2 || types.Identical(t1, t2) {
+		return true
+	}
+	// Code segment to help check for untyped equality from (golang/go#32146).
+	if rhs, ok := t1.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 {
+		if lhs, ok := t2.Underlying().(*types.Basic); ok {
+			return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType
+		}
+	}
+	// TODO: Figure out if we want to check for types.AssignableTo(t1, t2) || types.ConvertibleTo(t1, t2)
+	return false
+}
diff --git a/internal/lsp/analysis/fillreturns/fillreturns_test.go b/internal/lsp/analysis/fillreturns/fillreturns_test.go
new file mode 100644
index 0000000..d1ad656
--- /dev/null
+++ b/internal/lsp/analysis/fillreturns/fillreturns_test.go
@@ -0,0 +1,17 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fillreturns_test
+
+import (
+	"testing"
+
+	"golang.org/x/tools/go/analysis/analysistest"
+	"golang.org/x/tools/internal/lsp/analysis/fillreturns"
+)
+
+func Test(t *testing.T) {
+	testdata := analysistest.TestData()
+	analysistest.RunWithSuggestedFixes(t, testdata, fillreturns.Analyzer, "a")
+}
diff --git a/internal/lsp/analysis/fillreturns/testdata/src/a/a.go b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go
new file mode 100644
index 0000000..e65d42a
--- /dev/null
+++ b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go
@@ -0,0 +1,70 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fillreturns
+
+import (
+	"errors"
+	ast2 "go/ast"
+	"io"
+	"net/http"
+	. "net/http"
+	"net/url"
+)
+
+type T struct{}
+type T1 = T
+type I interface{}
+type I1 = I
+type z func(string, http.Handler) error
+
+func x() error {
+	return errors.New("foo")
+}
+
+func b() (string, int, error) {
+	return "", errors.New("foo") // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func c() (string, int, error) {
+	return 7, errors.New("foo") // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func d() (string, int, error) {
+	return "", 7 // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func e() (T, error, *bool) {
+	return (z(http.ListenAndServe))("", nil) // want "wrong number of return values \\(want 3, got 1\\)"
+}
+
+func closure() (string, error) {
+	_ = func() (int, error) {
+		return errors.New("foo") // want "wrong number of return values \\(want 2, got 1\\)"
+	}
+	return // want "wrong number of return values \\(want 2, got 0\\)"
+}
+
+func basic() (uint8, uint16, uint32, uint64, int8, int16, int32, int64, float32, float64, complex64, complex128, byte, rune, uint, int, uintptr, string, bool, error) {
+	return // want "wrong number of return values \\(want 20, got 0\\)"
+}
+
+func complex() (*int, []int, [2]int, map[int]int) {
+	return // want "wrong number of return values \\(want 4, got 0\\)"
+}
+
+func structsAndInterfaces() (T, url.URL, T1, I, I1, io.Reader, Client, ast2.Stmt) {
+	return // want "wrong number of return values \\(want 8, got 0\\)"
+}
+
+func m() (int, error) {
+	if 1 == 2 {
+		return // want "wrong number of return values \\(want 2, got 0\\)"
+	} else if 1 == 3 {
+		return // want "wrong number of return values \\(want 2, got 0\\)"
+	} else {
+		return // want "wrong number of return values \\(want 2, got 0\\)"
+	}
+	return // want "wrong number of return values \\(want 2, got 0\\)"
+}
diff --git a/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden
new file mode 100644
index 0000000..e5a4abb
--- /dev/null
+++ b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden
@@ -0,0 +1,70 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fillreturns
+
+import (
+	"errors"
+	ast2 "go/ast"
+	"io"
+	"net/http"
+	. "net/http"
+	"net/url"
+)
+
+type T struct{}
+type T1 = T
+type I interface{}
+type I1 = I
+type z func(string, http.Handler) error
+
+func x() error {
+	return errors.New("foo")
+}
+
+func b() (string, int, error) {
+	return "", 0, errors.New("foo") // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func c() (string, int, error) {
+	return "", 7, errors.New("foo") // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func d() (string, int, error) {
+	return "", 7, nil // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func e() (T, error, *bool) {
+	return T{}, (z(http.ListenAndServe))("", nil), nil // want "wrong number of return values \\(want 3, got 1\\)"
+}
+
+func closure() (string, error) {
+	_ = func() (int, error) {
+		return 0, errors.New("foo") // want "wrong number of return values \\(want 2, got 1\\)"
+	}
+	return "", nil // want "wrong number of return values \\(want 2, got 0\\)"
+}
+
+func basic() (uint8, uint16, uint32, uint64, int8, int16, int32, int64, float32, float64, complex64, complex128, byte, rune, uint, int, uintptr, string, bool, error) {
+	return 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, "", false, nil // want "wrong number of return values \\(want 20, got 0\\)"
+}
+
+func complex() (*int, []int, [2]int, map[int]int) {
+	return nil, nil, [2]int{}, nil // want "wrong number of return values \\(want 4, got 0\\)"
+}
+
+func structsAndInterfaces() (T, url.URL, T1, I, I1, io.Reader, Client, ast2.Stmt) {
+	return T{}, url.URL{}, T{}, nil, nil, nil, Client{}, nil // want "wrong number of return values \\(want 8, got 0\\)"
+}
+
+func m() (int, error) {
+	if 1 == 2 {
+		return 0, nil // want "wrong number of return values \\(want 2, got 0\\)"
+	} else if 1 == 3 {
+		return 0, nil // want "wrong number of return values \\(want 2, got 0\\)"
+	} else {
+		return 0, nil // want "wrong number of return values \\(want 2, got 0\\)"
+	}
+	return 0, nil // want "wrong number of return values \\(want 2, got 0\\)"
+}
diff --git a/internal/lsp/source/options.go b/internal/lsp/source/options.go
index 568a20c..aa4f66e 100644
--- a/internal/lsp/source/options.go
+++ b/internal/lsp/source/options.go
@@ -36,6 +36,7 @@
 	"golang.org/x/tools/go/analysis/passes/unreachable"
 	"golang.org/x/tools/go/analysis/passes/unsafeptr"
 	"golang.org/x/tools/go/analysis/passes/unusedresult"
+	"golang.org/x/tools/internal/lsp/analysis/fillreturns"
 	"golang.org/x/tools/internal/lsp/analysis/nonewvars"
 	"golang.org/x/tools/internal/lsp/analysis/noresultvalues"
 	"golang.org/x/tools/internal/lsp/analysis/simplifycompositelit"
@@ -489,6 +490,7 @@
 
 func typeErrorAnalyzers() map[string]Analyzer {
 	return map[string]Analyzer{
+		fillreturns.Analyzer.Name:    {Analyzer: fillreturns.Analyzer, Enabled: true},
 		nonewvars.Analyzer.Name:      {Analyzer: nonewvars.Analyzer, Enabled: true},
 		noresultvalues.Analyzer.Name: {Analyzer: noresultvalues.Analyzer, Enabled: true},
 		undeclaredname.Analyzer.Name: {Analyzer: undeclaredname.Analyzer, Enabled: true},
diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go
index 16fa1be..627d7f4 100644
--- a/internal/lsp/source/util.go
+++ b/internal/lsp/source/util.go
@@ -16,7 +16,6 @@
 	"sort"
 	"strings"
 
-	"golang.org/x/tools/go/ast/astutil"
 	"golang.org/x/tools/internal/lsp/protocol"
 	"golang.org/x/tools/internal/span"
 	errors "golang.org/x/xerrors"
@@ -775,81 +774,3 @@
 		return types.TypeString(T, qf) + "{}"
 	}
 }
-
-func zeroValue(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Type) ast.Expr {
-	under := typ
-	if n, ok := typ.(*types.Named); ok {
-		under = n.Underlying()
-	}
-	switch u := under.(type) {
-	case *types.Basic:
-		switch {
-		case u.Info()&types.IsNumeric != 0:
-			return &ast.BasicLit{Kind: token.INT, Value: "0"}
-		case u.Info()&types.IsBoolean != 0:
-			return &ast.Ident{Name: "false"}
-		case u.Info()&types.IsString != 0:
-			return &ast.BasicLit{Kind: token.STRING, Value: `""`}
-		default:
-			panic("unknown basic type")
-		}
-	case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice:
-		return ast.NewIdent("nil")
-	case *types.Struct:
-		texpr := typeExpr(fset, f, pkg, typ) // typ because we want the name here.
-		if texpr == nil {
-			return nil
-		}
-		return &ast.CompositeLit{
-			Type: texpr,
-		}
-	case *types.Array:
-		texpr := typeExpr(fset, f, pkg, u.Elem())
-		if texpr == nil {
-			return nil
-		}
-		return &ast.CompositeLit{
-			Type: &ast.ArrayType{
-				Elt: texpr,
-				Len: &ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%v", u.Len())},
-			},
-		}
-	}
-	return nil
-}
-
-func typeExpr(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Type) ast.Expr {
-	switch t := typ.(type) {
-	case *types.Basic:
-		switch t.Kind() {
-		case types.UnsafePointer:
-			return &ast.SelectorExpr{X: ast.NewIdent("unsafe"), Sel: ast.NewIdent("Pointer")}
-		default:
-			return ast.NewIdent(t.Name())
-		}
-	case *types.Named:
-		if t.Obj().Pkg() == pkg {
-			return ast.NewIdent(t.Obj().Name())
-		}
-		pkgName := t.Obj().Pkg().Name()
-		// If the file already imports the package under another name, use that.
-		for _, group := range astutil.Imports(fset, f) {
-			for _, cand := range group {
-				if strings.Trim(cand.Path.Value, `"`) == t.Obj().Pkg().Path() {
-					if cand.Name != nil && cand.Name.Name != "" {
-						pkgName = cand.Name.Name
-					}
-				}
-			}
-		}
-		if pkgName == "." {
-			return ast.NewIdent(t.Obj().Name())
-		}
-		return &ast.SelectorExpr{
-			X:   ast.NewIdent(pkgName),
-			Sel: ast.NewIdent(t.Obj().Name()),
-		}
-	default:
-		return nil // TODO: anonymous structs, but who does that
-	}
-}