blob: 447e30b28d88aaa77f77f9a35eb382b51dd2951c [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package modernize
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/go/types/typeutil"
"golang.org/x/tools/internal/analysisinternal"
"golang.org/x/tools/internal/analysisinternal/generated"
typeindexanalyzer "golang.org/x/tools/internal/analysisinternal/typeindex"
"golang.org/x/tools/internal/typeparams"
"golang.org/x/tools/internal/typesinternal/typeindex"
)
// SlicesContainsAnalyzer reports range loops over a slice that search for
// an element, either by equality with a needle or by a predicate, and
// suggests replacing them with a call to slices.Contains or
// slices.ContainsFunc.
var SlicesContainsAnalyzer = &analysis.Analyzer{
	Name: "slicescontains",
	Doc:  analysisinternal.MustExtractDoc(doc, "slicescontains"),
	URL:  "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#slicescontains",
	Run:  slicescontains,
	Requires: []*analysis.Analyzer{
		generated.Analyzer,
		inspect.Analyzer,
		typeindexanalyzer.Analyzer,
	},
}
// The slicescontains pass identifies loops that can be replaced by a
// call to slices.Contains{,Func}. For example:
//
//	for i, elem := range s {
//		if elem == needle {
//			...
//			break
//		}
//	}
//
// =>
//
//	if slices.Contains(s, needle) { ... }
//
// Variants:
//   - if the if-condition is f(elem), the replacement
//     uses slices.ContainsFunc(s, f).
//   - if the if-body is "return true" and the fallthrough
//     statement is "return false" (or vice versa), the
//     loop becomes "return [!]slices.Contains(...)".
//   - if the if-body is "found = true; break" and the previous
//     statement is "found = false", the loop becomes
//     "found = slices.Contains(...)". (Negation is not yet supported.)
//
// Because slices.Contains{,Func} return bool, the return/assignment
// variants apply only when the target type is exactly bool; a named
// boolean type or a type parameter would make the fix ill-typed.
// For the same reason, ContainsFunc is suggested only for predicates
// whose result type is exactly bool.
//
// It may change the cardinality of effects of the "needle" expression.
// (Mostly this appears to be a desirable optimization, avoiding
// redundantly repeated evaluation.)
//
// TODO(adonovan): Add a check that needle/predicate expression from
// if-statement has no effects. Now the program behavior may change.
func slicescontains(pass *analysis.Pass) (any, error) {
	skipGenerated(pass)

	// Skip the analyzer in packages where its
	// fixes would create an import cycle.
	if within(pass, "slices", "runtime") {
		return nil, nil
	}

	var (
		inspect = pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
		index   = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
		info    = pass.TypesInfo
	)

	// isExactlyBool reports whether the type recorded for e is the
	// predeclared type bool (or an alias of it). For an untyped constant
	// such as "true", go/types records the type it was implicitly
	// converted to, that is, the type of the assignment or return target.
	// The replacement slices.Contains{,Func}(...) has type bool, which is
	// not assignable to a named boolean type or to a type parameter, so
	// such targets must be rejected.
	isExactlyBool := func(e ast.Expr) bool {
		return types.Identical(info.TypeOf(e), types.Typ[types.Bool])
	}

	// check is called for each RangeStmt of this form:
	//   for i, elem := range s { if cond { ... } }
	check := func(file *ast.File, curRange inspector.Cursor) {
		rng := curRange.Node().(*ast.RangeStmt)
		ifStmt := rng.Body.List[0].(*ast.IfStmt)

		// isSliceElem reports whether e denotes the
		// current slice element (elem or s[i]).
		isSliceElem := func(e ast.Expr) bool {
			if rng.Value != nil && equalSyntax(e, rng.Value) {
				return true // "elem"
			}
			if x, ok := e.(*ast.IndexExpr); ok &&
				equalSyntax(x.X, rng.X) &&
				equalSyntax(x.Index, rng.Key) {
				return true // "s[i]"
			}
			return false
		}

		// Examine the condition for one of these forms:
		//
		// - if elem or s[i] == needle  { ... } => Contains
		// - if predicate(s[i] or elem) { ... } => ContainsFunc
		var (
			funcName string   // "Contains" or "ContainsFunc"
			arg2     ast.Expr // second argument to func (needle or predicate)
		)
		switch cond := ifStmt.Cond.(type) {
		case *ast.BinaryExpr:
			if cond.Op == token.EQL {
				var elem ast.Expr
				if isSliceElem(cond.X) {
					funcName = "Contains"
					elem = cond.X
					arg2 = cond.Y // "if elem == needle"
				} else if isSliceElem(cond.Y) {
					funcName = "Contains"
					elem = cond.Y
					arg2 = cond.X // "if needle == elem"
				}

				// Reject if elem and needle have different types.
				if elem != nil {
					tElem := info.TypeOf(elem)
					tNeedle := info.TypeOf(arg2)
					if !types.Identical(tElem, tNeedle) {
						// Avoid ill-typed slices.Contains([]error, any).
						if !types.AssignableTo(tNeedle, tElem) {
							return
						}
						// TODO(adonovan): relax this check to allow
						//   slices.Contains([]error, error(any)),
						// inserting an explicit widening conversion
						// around the needle.
						return
					}
				}
			}

		case *ast.CallExpr:
			if len(cond.Args) == 1 &&
				isSliceElem(cond.Args[0]) &&
				typeutil.Callee(info, cond) != nil { // not a conversion

				// Attempt to get signature
				sig, isSignature := info.TypeOf(cond.Fun).(*types.Signature)
				if isSignature {
					// skip variadic functions
					if sig.Variadic() {
						return
					}

					// Slice element type must match function parameter type.
					var (
						tElem  = typeparams.CoreType(info.TypeOf(rng.X)).(*types.Slice).Elem()
						tParam = sig.Params().At(0).Type()
					)
					if !types.Identical(tElem, tParam) {
						return
					}

					// ContainsFunc requires a func(E) bool: a predicate
					// returning a named boolean type is not assignable
					// to it, so the fix would not compile.
					if sig.Results().Len() != 1 ||
						!types.Identical(sig.Results().At(0).Type(), types.Typ[types.Bool]) {
						return
					}
				}

				funcName = "ContainsFunc"
				arg2 = cond.Fun // "if predicate(elem)"
			}
		}
		if funcName == "" {
			return // not a candidate for Contains{,Func}
		}

		// body is the "true" body.
		body := ifStmt.Body
		if len(body.List) == 0 {
			// (We could perhaps delete the loop entirely.)
			return
		}

		// Reject if the body, needle or predicate references either range variable.
		usesRangeVar := func(n ast.Node) bool {
			cur, ok := curRange.FindNode(n)
			if !ok {
				panic(fmt.Sprintf("FindNode(%T) failed", n))
			}
			return uses(index, cur, info.Defs[rng.Key.(*ast.Ident)]) ||
				rng.Value != nil && uses(index, cur, info.Defs[rng.Value.(*ast.Ident)])
		}
		if usesRangeVar(body) {
			// Body uses range var "i" or "elem".
			//
			// (The check for "i" could be relaxed when we
			// generalize this to support slices.Index;
			// and the check for "elem" could be relaxed
			// if "elem" can safely be replaced in the
			// body by "needle".)
			return
		}
		if usesRangeVar(arg2) {
			return
		}

		// Prepare slices.Contains{,Func} call.
		_, prefix, importEdits := analysisinternal.AddImport(info, file, "slices", "slices", funcName, rng.Pos())
		contains := fmt.Sprintf("%s%s(%s, %s)",
			prefix,
			funcName,
			analysisinternal.Format(pass.Fset, rng.X),
			analysisinternal.Format(pass.Fset, arg2))

		report := func(edits []analysis.TextEdit) {
			pass.Report(analysis.Diagnostic{
				Pos:     rng.Pos(),
				End:     rng.End(),
				Message: fmt.Sprintf("Loop can be simplified using slices.%s", funcName),
				SuggestedFixes: []analysis.SuggestedFix{{
					Message:   "Replace loop by call to slices." + funcName,
					TextEdits: append(edits, importEdits...),
				}},
			})
		}

		// Last statement of body must return/break out of the loop.
		//
		// TODO(adonovan): opt:consider avoiding FindNode with new API of form:
		//    curRange.Get(edge.RangeStmt_Body, -1).
		//             Get(edge.BodyStmt_List, 0).
		//             Get(edge.IfStmt_Body)
		curBody, _ := curRange.FindNode(body)
		curLastStmt, _ := curBody.LastChild()

		// Reject if any statement in the body except the last
		// contains a branch (break, continue, goto, fallthrough)
		// or a return statement, whose meaning might be
		// affected by melting down the loop.
		//
		// TODO(adonovan): relax check by analyzing branch target.
		for curBodyStmt := range curBody.Children() {
			if curBodyStmt != curLastStmt {
				for range curBodyStmt.Preorder((*ast.BranchStmt)(nil), (*ast.ReturnStmt)(nil)) {
					return
				}
			}
		}

		switch lastStmt := curLastStmt.Node().(type) {
		case *ast.ReturnStmt:
			// Have: for ... range seq { if ... { stmts; return x } }

			// Special case:
			// body={ return true } next="return false"   (or negation)
			// => return [!]slices.Contains(...)
			if curNext, ok := curRange.NextSibling(); ok {
				nextStmt := curNext.Node().(ast.Stmt)
				tval := isReturnTrueOrFalse(info, lastStmt)
				fval := isReturnTrueOrFalse(info, nextStmt)
				// tval != 0 implies lastStmt has exactly one result.
				// Require that result to be of type bool: a function
				// returning a named boolean type cannot return the
				// bool result of slices.Contains without a conversion.
				if len(body.List) == 1 && tval*fval < 0 && isExactlyBool(lastStmt.Results[0]) {
					//    for ... { if ... { return true/false } }
					// => return [!]slices.Contains(...)
					report([]analysis.TextEdit{
						// Delete the range statement and following space.
						{
							Pos: rng.Pos(),
							End: nextStmt.Pos(),
						},
						// Change return to [!]slices.Contains(...).
						{
							Pos: nextStmt.Pos(),
							End: nextStmt.End(),
							NewText: fmt.Appendf(nil, "return %s%s",
								cond(tval > 0, "", "!"),
								contains),
						},
					})
					return
				}
			}

			// General case:
			// => if slices.Contains(...) { stmts; return x }
			report([]analysis.TextEdit{
				// Replace "for ... { if ... " with "if slices.Contains(...)".
				{
					Pos:     rng.Pos(),
					End:     ifStmt.Body.Pos(),
					NewText: fmt.Appendf(nil, "if %s ", contains),
				},
				// Delete '}' of range statement and preceding space.
				{
					Pos: ifStmt.Body.End(),
					End: rng.End(),
				},
			})
			return

		case *ast.BranchStmt:
			if lastStmt.Tok == token.BREAK && lastStmt.Label == nil { // unlabeled break
				// Have: for ... { if ... { stmts; break } }

				var prevStmt ast.Stmt // previous statement to range (if any)
				if curPrev, ok := curRange.PrevSibling(); ok {
					// If the RangeStmt's previous sibling is a Stmt,
					// the RangeStmt must be among the Body list of
					// a BlockStmt, CauseClause, or CommClause.
					// In all cases, the prevStmt is the immediate
					// predecessor of the RangeStmt during execution.
					//
					// (This is not true for Stmts in general;
					// see [Cursor.Children] and #71074.)
					prevStmt, _ = curPrev.Node().(ast.Stmt)
				}

				// Special case:
				// prev="lhs = false" body={ lhs = true; break }
				// => lhs = slices.Contains(...)
				if assign, ok := body.List[0].(*ast.AssignStmt); ok &&
					len(body.List) == 2 &&
					assign.Tok == token.ASSIGN &&
					len(assign.Lhs) == 1 &&
					len(assign.Rhs) == 1 {

					// Have: body={ lhs = rhs; break }

					// The previous assignment must store exactly
					// "false": otherwise the loop computes
					// "prev || contains" (and the prev RHS may have
					// effects), and replacing it with plain
					// "slices.Contains(...)" would change behavior.
					//
					// The body's "true" must also have type bool,
					// i.e. lhs is not of a named boolean type, so
					// that assigning the bool result is well-typed.
					if prevAssign, ok := prevStmt.(*ast.AssignStmt); ok &&
						len(prevAssign.Lhs) == 1 &&
						len(prevAssign.Rhs) == 1 &&
						equalSyntax(prevAssign.Lhs[0], assign.Lhs[0]) &&
						is[*ast.Ident](prevAssign.Rhs[0]) &&
						info.Uses[prevAssign.Rhs[0].(*ast.Ident)] == builtinFalse &&
						is[*ast.Ident](assign.Rhs[0]) &&
						info.Uses[assign.Rhs[0].(*ast.Ident)] == builtinTrue &&
						isExactlyBool(assign.Rhs[0]) {

						// Have:
						//     lhs = false
						//     for ... { if ... { lhs = true; break } }
						// =>
						//     lhs = slices.Contains(...)
						//
						// TODO(adonovan):
						// - support "var lhs bool = false" and variants.
						// - support negation.
						//   Both these variants seem quite significant.
						// - allow the break to be omitted.
						report([]analysis.TextEdit{
							// Replace "rhs" of previous assignment by slices.Contains(...)
							{
								Pos:     prevAssign.Rhs[0].Pos(),
								End:     prevAssign.Rhs[0].End(),
								NewText: []byte(contains),
							},
							// Delete the loop and preceding space.
							{
								Pos: prevAssign.Rhs[0].End(),
								End: rng.End(),
							},
						})
						return
					}
				}

				// General case:
				//    for ... { if ...        { stmts; break } }
				// => if slices.Contains(...) { stmts        }
				report([]analysis.TextEdit{
					// Replace "for ... { if ... " with "if slices.Contains(...)".
					{
						Pos:     rng.Pos(),
						End:     ifStmt.Body.Pos(),
						NewText: fmt.Appendf(nil, "if %s ", contains),
					},
					// Delete break statement and preceding space.
					{
						Pos: func() token.Pos {
							if len(body.List) > 1 {
								beforeBreak, _ := curLastStmt.PrevSibling()
								return beforeBreak.Node().End()
							}
							return lastStmt.Pos()
						}(),
						End: lastStmt.End(),
					},
					// Delete '}' of range statement and preceding space.
					{
						Pos: ifStmt.Body.End(),
						End: rng.End(),
					},
				})
				return
			}
		}
	}

	for curFile := range filesUsing(inspect, info, "go1.21") {
		file := curFile.Node().(*ast.File)

		for curRange := range curFile.Preorder((*ast.RangeStmt)(nil)) {
			rng := curRange.Node().(*ast.RangeStmt)

			if is[*ast.Ident](rng.Key) &&
				rng.Tok == token.DEFINE &&
				len(rng.Body.List) == 1 &&
				is[*types.Slice](typeparams.CoreType(info.TypeOf(rng.X))) {

				// Have:
				// - for _, elem := range s { S }
				// - for i       := range s { S }

				if ifStmt, ok := rng.Body.List[0].(*ast.IfStmt); ok &&
					ifStmt.Init == nil && ifStmt.Else == nil {

					// Have: for i, elem := range s { if cond { ... } }
					check(file, curRange)
				}
			}
		}
	}
	return nil, nil
}
// -- helpers --

// isReturnTrueOrFalse classifies stmt: it yields +1 when stmt is a return
// of the single identifier resolving to builtinTrue, -1 when that
// identifier resolves to builtinFalse, and 0 for anything else (not a
// return, a different number of results, or a non-identifier result).
// Identity is decided by the object recorded in info.Uses, not by spelling.
func isReturnTrueOrFalse(info *types.Info, stmt ast.Stmt) int {
	ret, isReturn := stmt.(*ast.ReturnStmt)
	if !isReturn || len(ret.Results) != 1 {
		return 0
	}
	ident, isIdent := ret.Results[0].(*ast.Ident)
	if !isIdent {
		return 0
	}
	obj := info.Uses[ident]
	if obj == builtinTrue {
		return +1
	}
	if obj == builtinFalse {
		return -1
	}
	return 0
}