internal/lsp: add goreturns like functionality as quickfix
This change ports the functionality of https://github.com/sqs/goreturns
to be used as code actions on diagnostics that have missing
return values. It improves on the original goreturns functionality by:
- filling out empty return statements
- trying to match existing return values to the required return
values and then filling in missing parameters
Fixes golang/go#37091
Change-Id: Ifaf9bf571c3bc3c61e672b0a2f725d8d734d432d
Reviewed-on: https://go-review.googlesource.com/c/tools/+/224960
Run-TryBot: Rohan Challa <rohan@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/gopls/doc/analyzers.md b/gopls/doc/analyzers.md
index 47a65cf..90cf758 100644
--- a/gopls/doc/analyzers.md
+++ b/gopls/doc/analyzers.md
@@ -321,6 +321,26 @@
Default value: `true`.
+### **fillreturns**
+
+suggested fixes for "wrong number of return values (want %d, got %d)"
+
+This checker provides suggested fixes for type errors of the
+type "wrong number of return values (want %d, got %d)". For example:
+```go
+func m() (int, string, *bool, error) {
+ return
+}
+```
+will turn into
+```go
+func m() (int, string, *bool, error) {
+ return 0, "", nil, nil
+}
+```
+
+This functionality is similar to [goreturns](https://github.com/sqs/goreturns).
+
### **nonewvars**
suggested fixes for "no new vars on left side of :="
diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go
index 39f4bb6..2658681 100644
--- a/internal/analysisinternal/analysis.go
+++ b/internal/analysisinternal/analysis.go
@@ -7,8 +7,13 @@
import (
"bytes"
+ "fmt"
+ "go/ast"
"go/token"
"go/types"
+ "strings"
+
+ "golang.org/x/tools/go/ast/astutil"
)
func TypeErrorEndPos(fset *token.FileSet, src []byte, start token.Pos) token.Pos {
@@ -23,6 +28,84 @@
return end
}
+func ZeroValue(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Type) ast.Expr {
+ under := typ
+ if n, ok := typ.(*types.Named); ok {
+ under = n.Underlying()
+ }
+ switch u := under.(type) {
+ case *types.Basic:
+ switch {
+ case u.Info()&types.IsNumeric != 0:
+ return &ast.BasicLit{Kind: token.INT, Value: "0"}
+ case u.Info()&types.IsBoolean != 0:
+ return &ast.Ident{Name: "false"}
+ case u.Info()&types.IsString != 0:
+ return &ast.BasicLit{Kind: token.STRING, Value: `""`}
+ default:
+ panic("unknown basic type")
+ }
+ case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice:
+ return ast.NewIdent("nil")
+ case *types.Struct:
+ texpr := typeExpr(fset, f, pkg, typ) // typ because we want the name here.
+ if texpr == nil {
+ return nil
+ }
+ return &ast.CompositeLit{
+ Type: texpr,
+ }
+ case *types.Array:
+ texpr := typeExpr(fset, f, pkg, u.Elem())
+ if texpr == nil {
+ return nil
+ }
+ return &ast.CompositeLit{
+ Type: &ast.ArrayType{
+ Elt: texpr,
+ Len: &ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%v", u.Len())},
+ },
+ }
+ }
+ return nil
+}
+
+func typeExpr(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Type) ast.Expr {
+ switch t := typ.(type) {
+ case *types.Basic:
+ switch t.Kind() {
+ case types.UnsafePointer:
+ return &ast.SelectorExpr{X: ast.NewIdent("unsafe"), Sel: ast.NewIdent("Pointer")}
+ default:
+ return ast.NewIdent(t.Name())
+ }
+ case *types.Named:
+ if t.Obj().Pkg() == pkg {
+ return ast.NewIdent(t.Obj().Name())
+ }
+ pkgName := t.Obj().Pkg().Name()
+ // If the file already imports the package under another name, use that.
+ for _, group := range astutil.Imports(fset, f) {
+ for _, cand := range group {
+ if strings.Trim(cand.Path.Value, `"`) == t.Obj().Pkg().Path() {
+ if cand.Name != nil && cand.Name.Name != "" {
+ pkgName = cand.Name.Name
+ }
+ }
+ }
+ }
+ if pkgName == "." {
+ return ast.NewIdent(t.Obj().Name())
+ }
+ return &ast.SelectorExpr{
+ X: ast.NewIdent(pkgName),
+ Sel: ast.NewIdent(t.Obj().Name()),
+ }
+ default:
+ return nil // TODO: anonymous structs, but who does that
+ }
+}
+
var GetTypeErrors = func(p interface{}) []types.Error { return nil }
var SetTypeErrors = func(p interface{}, errors []types.Error) {}
diff --git a/internal/lsp/analysis/fillreturns/fillreturns.go b/internal/lsp/analysis/fillreturns/fillreturns.go
new file mode 100644
index 0000000..a75f645
--- /dev/null
+++ b/internal/lsp/analysis/fillreturns/fillreturns.go
@@ -0,0 +1,205 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fillreturns defines an Analyzer that will attempt to
+// automatically fill in a return statement that has missing
+// values with zero value elements.
+package fillreturns
+
+import (
+ "bytes"
+ "go/ast"
+ "go/format"
+ "go/types"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/ast/astutil"
+ "golang.org/x/tools/internal/analysisinternal"
+)
+
+const Doc = `suggested fixes for "wrong number of return values (want %d, got %d)"
+
+This checker provides suggested fixes for type errors of the
+type "wrong number of return values (want %d, got %d)". For example:
+ func m() (int, string, *bool, error) {
+ return
+ }
+will turn into
+ func m() (int, string, *bool, error) {
+ return 0, "", nil, nil
+ }
+
+This functionality is similar to https://github.com/sqs/goreturns.
+`
+
+var Analyzer = &analysis.Analyzer{
+ Name: "fillreturns",
+ Doc: Doc,
+ Requires: []*analysis.Analyzer{},
+ Run: run,
+ RunDespiteErrors: true,
+}
+
+var wrongReturnNumRegex = regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`)
+
+func run(pass *analysis.Pass) (interface{}, error) {
+ errors := analysisinternal.GetTypeErrors(pass)
+ // Filter out the errors that are not relevant to this analyzer.
+ for _, typeErr := range errors {
+ matches := wrongReturnNumRegex.FindStringSubmatch(strings.TrimSpace(typeErr.Msg))
+ if len(matches) < 3 {
+ continue
+ }
+ wantNum, err := strconv.Atoi(matches[1])
+ if err != nil {
+ continue
+ }
+ gotNum, err := strconv.Atoi(matches[2])
+ if err != nil {
+ continue
+ }
+ // Logic for handling more return values than expected is hard.
+ if wantNum < gotNum {
+ continue
+ }
+
+ var file *ast.File
+ for _, f := range pass.Files {
+ if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() {
+ file = f
+ break
+ }
+ }
+ if file == nil {
+ continue
+ }
+
+ // Get the end position of the error.
+ var buf bytes.Buffer
+ if err := format.Node(&buf, pass.Fset, file); err != nil {
+ continue
+ }
+ typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos)
+
+ // Get the path for the relevant range.
+ path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos)
+ if len(path) == 0 {
+ return nil, nil
+ }
+ // Check to make sure the node of interest is a ReturnStmt.
+ ret, ok := path[0].(*ast.ReturnStmt)
+ if !ok {
+ return nil, nil
+ }
+
+ // Get the function that encloses the ReturnStmt.
+ var enclosingFunc *ast.FuncType
+ Outer:
+ for _, n := range path {
+ switch node := n.(type) {
+ case *ast.FuncLit:
+ enclosingFunc = node.Type
+ break Outer
+ case *ast.FuncDecl:
+ enclosingFunc = node.Type
+ break Outer
+ }
+ }
+ if enclosingFunc == nil {
+ continue
+ }
+ numRetValues := len(ret.Results)
+ typeInfo := pass.TypesInfo
+
+ // skip if return value has a func call (whose multiple returns might be expanded)
+ for _, expr := range ret.Results {
+ e, ok := expr.(*ast.CallExpr)
+ if !ok {
+ continue
+ }
+ ident, ok := e.Fun.(*ast.Ident)
+ if !ok || ident.Obj == nil {
+ continue
+ }
+ fn, ok := ident.Obj.Decl.(*ast.FuncDecl)
+ if !ok {
+ continue
+ }
+ if len(fn.Type.Results.List) != 1 {
+ continue
+ }
+ if typeInfo == nil {
+ continue
+ }
+ if _, ok := typeInfo.TypeOf(e).(*types.Tuple); ok {
+ continue
+ }
+ }
+
+ // Fill in the missing arguments with zero-values.
+ returnCount := 0
+ zvs := make([]ast.Expr, len(enclosingFunc.Results.List))
+ for i, result := range enclosingFunc.Results.List {
+ zv := analysisinternal.ZeroValue(pass.Fset, file, pass.Pkg, typeInfo.TypeOf(result.Type))
+ if zv == nil {
+ return nil, nil
+ }
+ // We do not have any existing return values, fill in with zero-values.
+ if returnCount >= numRetValues {
+ zvs[i] = zv
+ continue
+ }
+ // Compare the types to see if they are the same.
+ current := ret.Results[returnCount]
+ if equalTypes(typeInfo.TypeOf(current), typeInfo.TypeOf(result.Type)) {
+ zvs[i] = current
+ returnCount += 1
+ continue
+ }
+ zvs[i] = zv
+ }
+ newRet := &ast.ReturnStmt{
+ Return: ret.Pos(),
+ Results: zvs,
+ }
+
+ // Convert the new return statement ast to text.
+ var newBuf bytes.Buffer
+ if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
+ return nil, err
+ }
+
+ pass.Report(analysis.Diagnostic{
+ Pos: typeErr.Pos,
+ End: typeErrEndPos,
+ Message: typeErr.Msg,
+ SuggestedFixes: []analysis.SuggestedFix{{
+ Message: "Fill with empty values",
+ TextEdits: []analysis.TextEdit{{
+ Pos: ret.Pos(),
+ End: ret.End(),
+ NewText: newBuf.Bytes(),
+ }},
+ }},
+ })
+ }
+ return nil, nil
+}
+
+func equalTypes(t1, t2 types.Type) bool {
+ if t1 == t2 || types.Identical(t1, t2) {
+ return true
+ }
+ // Code segment to help check for untyped equality from (golang/go#32146).
+ if rhs, ok := t1.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 {
+ if lhs, ok := t2.Underlying().(*types.Basic); ok {
+ return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType
+ }
+ }
+ // TODO: Figure out if we want to check for types.AssignableTo(t1, t2) || types.ConvertibleTo(t1, t2)
+ return false
+}
diff --git a/internal/lsp/analysis/fillreturns/fillreturns_test.go b/internal/lsp/analysis/fillreturns/fillreturns_test.go
new file mode 100644
index 0000000..d1ad656
--- /dev/null
+++ b/internal/lsp/analysis/fillreturns/fillreturns_test.go
@@ -0,0 +1,17 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fillreturns_test
+
+import (
+ "testing"
+
+ "golang.org/x/tools/go/analysis/analysistest"
+ "golang.org/x/tools/internal/lsp/analysis/fillreturns"
+)
+
+func Test(t *testing.T) {
+ testdata := analysistest.TestData()
+ analysistest.RunWithSuggestedFixes(t, testdata, fillreturns.Analyzer, "a")
+}
diff --git a/internal/lsp/analysis/fillreturns/testdata/src/a/a.go b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go
new file mode 100644
index 0000000..e65d42a
--- /dev/null
+++ b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go
@@ -0,0 +1,70 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fillreturns
+
+import (
+ "errors"
+ ast2 "go/ast"
+ "io"
+ "net/http"
+ . "net/http"
+ "net/url"
+)
+
+type T struct{}
+type T1 = T
+type I interface{}
+type I1 = I
+type z func(string, http.Handler) error
+
+func x() error {
+ return errors.New("foo")
+}
+
+func b() (string, int, error) {
+ return "", errors.New("foo") // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func c() (string, int, error) {
+ return 7, errors.New("foo") // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func d() (string, int, error) {
+ return "", 7 // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func e() (T, error, *bool) {
+ return (z(http.ListenAndServe))("", nil) // want "wrong number of return values \\(want 3, got 1\\)"
+}
+
+func closure() (string, error) {
+ _ = func() (int, error) {
+ return errors.New("foo") // want "wrong number of return values \\(want 2, got 1\\)"
+ }
+ return // want "wrong number of return values \\(want 2, got 0\\)"
+}
+
+func basic() (uint8, uint16, uint32, uint64, int8, int16, int32, int64, float32, float64, complex64, complex128, byte, rune, uint, int, uintptr, string, bool, error) {
+ return // want "wrong number of return values \\(want 20, got 0\\)"
+}
+
+func complex() (*int, []int, [2]int, map[int]int) {
+ return // want "wrong number of return values \\(want 4, got 0\\)"
+}
+
+func structsAndInterfaces() (T, url.URL, T1, I, I1, io.Reader, Client, ast2.Stmt) {
+ return // want "wrong number of return values \\(want 8, got 0\\)"
+}
+
+func m() (int, error) {
+ if 1 == 2 {
+ return // want "wrong number of return values \\(want 2, got 0\\)"
+ } else if 1 == 3 {
+ return // want "wrong number of return values \\(want 2, got 0\\)"
+ } else {
+ return // want "wrong number of return values \\(want 2, got 0\\)"
+ }
+ return // want "wrong number of return values \\(want 2, got 0\\)"
+}
diff --git a/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden
new file mode 100644
index 0000000..e5a4abb
--- /dev/null
+++ b/internal/lsp/analysis/fillreturns/testdata/src/a/a.go.golden
@@ -0,0 +1,70 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fillreturns
+
+import (
+ "errors"
+ ast2 "go/ast"
+ "io"
+ "net/http"
+ . "net/http"
+ "net/url"
+)
+
+type T struct{}
+type T1 = T
+type I interface{}
+type I1 = I
+type z func(string, http.Handler) error
+
+func x() error {
+ return errors.New("foo")
+}
+
+func b() (string, int, error) {
+ return "", 0, errors.New("foo") // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func c() (string, int, error) {
+ return "", 7, errors.New("foo") // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func d() (string, int, error) {
+ return "", 7, nil // want "wrong number of return values \\(want 3, got 2\\)"
+}
+
+func e() (T, error, *bool) {
+ return T{}, (z(http.ListenAndServe))("", nil), nil // want "wrong number of return values \\(want 3, got 1\\)"
+}
+
+func closure() (string, error) {
+ _ = func() (int, error) {
+ return 0, errors.New("foo") // want "wrong number of return values \\(want 2, got 1\\)"
+ }
+ return "", nil // want "wrong number of return values \\(want 2, got 0\\)"
+}
+
+func basic() (uint8, uint16, uint32, uint64, int8, int16, int32, int64, float32, float64, complex64, complex128, byte, rune, uint, int, uintptr, string, bool, error) {
+ return 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, "", false, nil // want "wrong number of return values \\(want 20, got 0\\)"
+}
+
+func complex() (*int, []int, [2]int, map[int]int) {
+ return nil, nil, [2]int{}, nil // want "wrong number of return values \\(want 4, got 0\\)"
+}
+
+func structsAndInterfaces() (T, url.URL, T1, I, I1, io.Reader, Client, ast2.Stmt) {
+ return T{}, url.URL{}, T{}, nil, nil, nil, Client{}, nil // want "wrong number of return values \\(want 8, got 0\\)"
+}
+
+func m() (int, error) {
+ if 1 == 2 {
+ return 0, nil // want "wrong number of return values \\(want 2, got 0\\)"
+ } else if 1 == 3 {
+ return 0, nil // want "wrong number of return values \\(want 2, got 0\\)"
+ } else {
+ return 0, nil // want "wrong number of return values \\(want 2, got 0\\)"
+ }
+ return 0, nil // want "wrong number of return values \\(want 2, got 0\\)"
+}
diff --git a/internal/lsp/source/options.go b/internal/lsp/source/options.go
index 568a20c..aa4f66e 100644
--- a/internal/lsp/source/options.go
+++ b/internal/lsp/source/options.go
@@ -36,6 +36,7 @@
"golang.org/x/tools/go/analysis/passes/unreachable"
"golang.org/x/tools/go/analysis/passes/unsafeptr"
"golang.org/x/tools/go/analysis/passes/unusedresult"
+ "golang.org/x/tools/internal/lsp/analysis/fillreturns"
"golang.org/x/tools/internal/lsp/analysis/nonewvars"
"golang.org/x/tools/internal/lsp/analysis/noresultvalues"
"golang.org/x/tools/internal/lsp/analysis/simplifycompositelit"
@@ -489,6 +490,7 @@
func typeErrorAnalyzers() map[string]Analyzer {
return map[string]Analyzer{
+ fillreturns.Analyzer.Name: {Analyzer: fillreturns.Analyzer, Enabled: true},
nonewvars.Analyzer.Name: {Analyzer: nonewvars.Analyzer, Enabled: true},
noresultvalues.Analyzer.Name: {Analyzer: noresultvalues.Analyzer, Enabled: true},
undeclaredname.Analyzer.Name: {Analyzer: undeclaredname.Analyzer, Enabled: true},
diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go
index 16fa1be..627d7f4 100644
--- a/internal/lsp/source/util.go
+++ b/internal/lsp/source/util.go
@@ -16,7 +16,6 @@
"sort"
"strings"
- "golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/internal/lsp/protocol"
"golang.org/x/tools/internal/span"
errors "golang.org/x/xerrors"
@@ -775,81 +774,3 @@
return types.TypeString(T, qf) + "{}"
}
}
-
-func zeroValue(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Type) ast.Expr {
- under := typ
- if n, ok := typ.(*types.Named); ok {
- under = n.Underlying()
- }
- switch u := under.(type) {
- case *types.Basic:
- switch {
- case u.Info()&types.IsNumeric != 0:
- return &ast.BasicLit{Kind: token.INT, Value: "0"}
- case u.Info()&types.IsBoolean != 0:
- return &ast.Ident{Name: "false"}
- case u.Info()&types.IsString != 0:
- return &ast.BasicLit{Kind: token.STRING, Value: `""`}
- default:
- panic("unknown basic type")
- }
- case *types.Chan, *types.Interface, *types.Map, *types.Pointer, *types.Signature, *types.Slice:
- return ast.NewIdent("nil")
- case *types.Struct:
- texpr := typeExpr(fset, f, pkg, typ) // typ because we want the name here.
- if texpr == nil {
- return nil
- }
- return &ast.CompositeLit{
- Type: texpr,
- }
- case *types.Array:
- texpr := typeExpr(fset, f, pkg, u.Elem())
- if texpr == nil {
- return nil
- }
- return &ast.CompositeLit{
- Type: &ast.ArrayType{
- Elt: texpr,
- Len: &ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%v", u.Len())},
- },
- }
- }
- return nil
-}
-
-func typeExpr(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Type) ast.Expr {
- switch t := typ.(type) {
- case *types.Basic:
- switch t.Kind() {
- case types.UnsafePointer:
- return &ast.SelectorExpr{X: ast.NewIdent("unsafe"), Sel: ast.NewIdent("Pointer")}
- default:
- return ast.NewIdent(t.Name())
- }
- case *types.Named:
- if t.Obj().Pkg() == pkg {
- return ast.NewIdent(t.Obj().Name())
- }
- pkgName := t.Obj().Pkg().Name()
- // If the file already imports the package under another name, use that.
- for _, group := range astutil.Imports(fset, f) {
- for _, cand := range group {
- if strings.Trim(cand.Path.Value, `"`) == t.Obj().Pkg().Path() {
- if cand.Name != nil && cand.Name.Name != "" {
- pkgName = cand.Name.Name
- }
- }
- }
- }
- if pkgName == "." {
- return ast.NewIdent(t.Obj().Name())
- }
- return &ast.SelectorExpr{
- X: ast.NewIdent(pkgName),
- Sel: ast.NewIdent(t.Obj().Name()),
- }
- default:
- return nil // TODO: anonymous structs, but who does that
- }
-}