// Source blob: 5fc7ded76a7a7331245b67d3f1403cccbd1d9ca0 (code-browser header residue, kept as a comment so the file parses).
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fillreturns
import (
"bytes"
_ "embed"
"fmt"
"go/ast"
"go/format"
"go/token"
"go/types"
"regexp"
"slices"
"strings"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/gopls/internal/fuzzy"
"golang.org/x/tools/gopls/internal/util/cursorutil"
"golang.org/x/tools/internal/analysis/analyzerutil"
"golang.org/x/tools/internal/moreiters"
"golang.org/x/tools/internal/typesinternal"
)
// doc holds the contents of doc.go, embedded at build time. The
// Analyzer's documentation string is extracted from it.
//
//go:embed doc.go
var doc string
// Analyzer is the fillreturns analyzer. It depends on the inspect pass
// and is configured to run even when the package has type errors, since
// type errors are exactly what it inspects.
var Analyzer = &analysis.Analyzer{
	// Identity and documentation.
	Name: "fillreturns",
	URL:  "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/fillreturns",
	Doc:  analyzerutil.MustExtractDoc(doc, "fillreturns"),

	// Execution.
	Run:              run,
	RunDespiteErrors: true,
	Requires:         []*analysis.Analyzer{inspect.Analyzer},
}
// run inspects the package's type errors and, for each one that reports
// a wrong number of return values, reports a diagnostic (with the
// original error message) carrying a "Fill in return values" fix.
//
// The fix replaces the offending return statement with one that has an
// operand for every declared result of the innermost enclosing
// function. For each result, in order, the operand is chosen as:
//
//  1. an operand of the existing return statement whose type matches
//     (a non-"zero" operand is preferred over a syntactic zero value);
//  2. otherwise an in-scope identifier of a suitable type whose name
//     fuzzy-matches the result type;
//  3. otherwise the zero value of the result type.
//
// Existing operands that matched no result are appended at the end,
// unless they are syntactic zero values, which are dropped.
//
// A type error that cannot be handled (missing type information, no
// enclosing function, generic function declaration, tuple-valued call
// among the operands, ...) is skipped; it does not prevent fixes from
// being offered for the remaining errors in the package.
//
// The returned result is always nil. A non-nil error is returned only
// for internal failures (an inconsistent match table or a failure to
// format the new statement).
func run(pass *analysis.Pass) (any, error) {
	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
	info := pass.TypesInfo

outer:
	for _, typeErr := range pass.TypeErrors {
		if !fixesError(typeErr) {
			continue // irrelevant type error
		}
		_, start, end, ok := typesinternal.ErrorCodeStartEnd(typeErr)
		if !ok {
			continue // no position information
		}
		curErr, ok := inspect.Root().FindByPos(start, end)
		if !ok {
			continue // can't find node
		}

		// Find cursor for enclosing return statement (which may be curErr itself).
		ret, curRet := cursorutil.FirstEnclosing[*ast.ReturnStmt](curErr)
		if ret == nil {
			continue // no enclosing return
		}

		// Skip if any return argument is a tuple-valued function call:
		// its values cannot be rearranged individually.
		for _, expr := range ret.Results {
			call, ok := expr.(*ast.CallExpr)
			if !ok {
				continue
			}
			if tup, ok := info.TypeOf(call).(*types.Tuple); ok && tup.Len() > 1 {
				continue outer
			}
		}

		// Get type of innermost enclosing function.
		curFunc, ok := enclosingFunc(curRet)
		if !ok {
			continue // defensive: return statement outside any function
		}
		var funcType *ast.FuncType
		switch fn := curFunc.Node().(type) {
		case *ast.FuncLit:
			funcType = fn.Type
		case *ast.FuncDecl:
			funcType = fn.Type

			// Skip generic functions since type parameters don't have zero values.
			// TODO(rfindley): We should be able to handle this if the return
			// values are all concrete types.
			if funcType.TypeParams.NumFields() > 0 {
				continue outer
			}
		}
		if funcType == nil || funcType.Results == nil {
			continue
		}

		// Collect the type of every declared result.
		//
		// A single ast.Field may declare several results that share a
		// type, as in "(a, b int, err error)", so the type is recorded
		// once per name (or once for an unnamed result).
		// Beware that type information may be incomplete.
		var retTyps []types.Type
		for _, field := range funcType.Results.List {
			retTyp := info.TypeOf(field.Type)
			if retTyp == nil {
				continue outer // incomplete type information; skip this error only
			}
			n := max(1, len(field.Names))
			for k := 0; k < n; k++ {
				retTyps = append(retTyps, retTyp)
			}
		}

		file, _ := cursorutil.FirstEnclosing[*ast.File](curRet)
		matches := MatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg)
		qual := typesinternal.FileQualifier(file, pass.Pkg)

		// Duplicate the return values to track which values have been matched.
		remaining := slices.Clone(ret.Results)

		// For each declared result, find the leftmost element in the
		// return statement that has the desired type. If no such element
		// exists, fill in the missing value with a similarly named
		// identifier or the appropriate "zero" value.
		fixed := make([]ast.Expr, len(retTyps))
		for i, retTyp := range retTyps {
			var match ast.Expr
			var idx int
			for j, val := range remaining {
				if t := info.TypeOf(val); t == nil || !matchingTypes(t, retTyp) {
					continue
				}
				match, idx = val, j
				if !isZeroExpr(val) {
					break
				}
				// The current match is a "zero" value: keep searching in
				// case we find a non-"zero" value match. If we do not,
				// the (last) "zero" value found is used.
			}

			if match != nil {
				fixed[i] = match
				remaining = slices.Delete(remaining, idx, idx+1)
				continue
			}

			names, ok := matches[retTyp]
			if !ok {
				return nil, fmt.Errorf("invalid return type: %v", retTyp)
			}
			// Find the identifier most similar to the return type.
			// If no identifier matches the pattern, generate a zero value.
			if best := fuzzy.BestMatch(retTyp.String(), names); best != "" {
				fixed[i] = ast.NewIdent(best)
			} else if zero, isValid := typesinternal.ZeroExpr(retTyp, qual); isValid {
				fixed[i] = zero
			} else {
				continue outer // no usable zero value; skip this error only
			}
		}

		// Append leftover return values to the end of the new return
		// statement, dropping any non-matching "zero values".
		for _, expr := range remaining {
			if !isZeroExpr(expr) {
				fixed = append(fixed, expr)
			}
		}

		newRet := &ast.ReturnStmt{
			Return:  ret.Pos(),
			Results: fixed,
		}

		// Convert the new return statement AST to text.
		var newBuf bytes.Buffer
		if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
			return nil, err
		}

		pass.Report(analysis.Diagnostic{
			Pos:     start,
			End:     end,
			Message: typeErr.Msg,
			SuggestedFixes: []analysis.SuggestedFix{{
				Message: "Fill in return values",
				TextEdits: []analysis.TextEdit{{
					Pos:     ret.Pos(),
					End:     ret.End(),
					NewText: newBuf.Bytes(),
				}},
			}},
		})
	}
	return nil, nil
}
// matchingTypes reports whether a value of type want may be used where
// type got is required, using a deliberately permissive notion of
// compatibility: the types are identical, or want is assignable or
// convertible to got.
//
// When want is an untyped basic type and got's underlying type is
// basic, the answer is decided solely by whether both belong to the
// same constant kind (boolean, integer, float, complex, string); see
// golang/go#32146.
func matchingTypes(want, got types.Type) bool {
	switch {
	case want == got:
		return true
	case types.Identical(want, got):
		return true
	}

	// Untyped operand: compare constant kinds only.
	untyped, isBasic := want.(*types.Basic)
	if isBasic && untyped.Info()&types.IsUntyped > 0 {
		if target, ok := got.Underlying().(*types.Basic); ok {
			const kindMask = types.IsConstType
			return untyped.Info()&kindMask == target.Info()&kindMask
		}
	}

	if types.AssignableTo(want, got) {
		return true
	}
	return types.ConvertibleTo(want, got)
}
// wrongReturnNumRegexes lists the patterns of type-checker messages
// that report a wrong number of return values. The wording has changed
// across Go versions; these capture recent incarnations.
//
// TODO(rfindley): once error codes are exported and exposed via go/packages,
// use error codes rather than string matching here.
var wrongReturnNumRegexes = []*regexp.Regexp{
	regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`),
	regexp.MustCompile(`too many return values`),
	regexp.MustCompile(`not enough return values`),
}

// fixesError reports whether the message of err (ignoring surrounding
// white space) matches any pattern in wrongReturnNumRegexes, that is,
// whether err is an error this analyzer can offer a fix for.
func fixesError(err types.Error) bool {
	trimmed := strings.TrimSpace(err.Msg)
	matches := func(rx *regexp.Regexp) bool { return rx.MatchString(trimmed) }
	return slices.ContainsFunc(wrongReturnNumRegexes, matches)
}
// enclosingFunc returns the cursor for the innermost function
// declaration or function literal among the nodes yielded by
// c.Enclosing. The boolean result is false if there is none.
func enclosingFunc(c inspector.Cursor) (inspector.Cursor, bool) {
	funcs := c.Enclosing(
		(*ast.FuncDecl)(nil),
		(*ast.FuncLit)(nil),
	)
	return moreiters.First(funcs)
}
// isZeroExpr reports whether expr is, syntactically, an obvious zero
// value: the literal 0 or "", or the identifier nil or false.
//
// It is a purely textual heuristic and consults no type information,
// so other spellings of zero (0.0, 0x0, T{}, ...) are not recognized.
func isZeroExpr(expr ast.Expr) bool {
	if lit, ok := expr.(*ast.BasicLit); ok {
		switch lit.Value {
		case "0", `""`:
			return true
		}
		return false
	}
	if id, ok := expr.(*ast.Ident); ok {
		switch id.Name {
		case "nil", "false":
			return true
		}
	}
	return false
}
// MatchingIdents finds the names of all identifiers defined in 'node'
// whose type matches any of the given types. 'pos' is the position at
// which a reference to a selected identifier may be inserted, so only
// identifiers that resolve to their own definition at 'pos' are
// selected; otherwise we would insert a name that is unrecognized (or
// refers to something else) at 'pos'.
//
// The result has an entry for every non-nil type in typs, holding the
// (possibly empty) list of matching names. A name whose type is exactly
// one of the requested types is recorded under that type only;
// otherwise it is recorded under every requested type that
// equivalentTypes deems compatible.
//
// This function is shared with the 'fillstruct' analyzer.
func MatchingIdents(typs []types.Type, node ast.Node, pos token.Pos, info *types.Info, pkg *types.Package) map[types.Type][]string {
	// Seed the result with one entry per requested type, so that map
	// membership identifies the types being searched for.
	matches := make(map[types.Type][]string)
	for _, typ := range typs {
		if typ != nil { // TODO(adonovan): is nil reachable?
			matches[typ] = nil
		}
	}

	visited := make(map[types.Object]struct{}) // objects already considered

	visit := func(n ast.Node) bool {
		if n == nil {
			return false
		}

		// Prevent circular definitions. If 'pos' is within an assignment
		// statement, do not allow any identifiers in that assignment
		// statement to be selected. Otherwise, we could do the following,
		// where 'x' satisfies the type of 'f0':
		//
		//	x := fakeStruct{f0: x}
		//
		if stmt, isAssign := n.(*ast.AssignStmt); isAssign && stmt.Pos() < pos && pos <= stmt.End() {
			return false
		}

		// A node that extends beyond pos is not itself a candidate;
		// descend into it only if it contains pos.
		if pos < n.End() {
			return n.Pos() <= pos
		}

		id, isIdent := n.(*ast.Ident)
		if !isIdent || id.Name == "_" {
			return true
		}

		// Only defining identifiers of typed, non-type objects qualify.
		obj := info.Defs[id]
		if obj == nil || obj.Type() == nil {
			return true
		}
		if _, isTypeName := obj.(*types.TypeName); isTypeName {
			return true
		}

		// Prevent duplicates in matches' values.
		if _, dup := visited[obj]; dup {
			return true
		}
		visited[obj] = struct{}{}

		// The name must resolve to this very object at pos.
		scope := pkg.Scope().Innermost(pos)
		if scope == nil {
			return true
		}
		if _, resolved := scope.LookupParent(id.Name, pos); resolved != obj {
			return true
		}

		// The object must match one of the types that we are searching for.
		// TODO(adonovan): opt: use typeutil.Map?
		objType := obj.Type()
		if _, exact := matches[objType]; exact {
			matches[objType] = append(matches[objType], id.Name)
			return true
		}
		// No exact match: record the name under every target type
		// that the object's type can satisfy.
		for target := range matches {
			if equivalentTypes(objType, target) {
				matches[target] = append(matches[target], id.Name)
			}
		}
		return true
	}

	ast.Inspect(node, visit)
	return matches
}
// equivalentTypes reports whether a value of type want may stand in for
// type got: the types are identical, or want is assignable to got.
//
// When want is an untyped basic type and got's underlying type is
// basic, the answer is decided solely by whether both belong to the
// same constant kind (boolean, integer, float, complex, string); see
// golang/go#32146.
func equivalentTypes(want, got types.Type) bool {
	if types.Identical(want, got) {
		return true
	}

	// Untyped operand: compare constant kinds only.
	untyped, isBasic := want.(*types.Basic)
	if isBasic && untyped.Info()&types.IsUntyped > 0 {
		if target, ok := got.Underlying().(*types.Basic); ok {
			const kindMask = types.IsConstType
			return untyped.Info()&kindMask == target.Info()&kindMask
		}
	}

	return types.AssignableTo(want, got)
}