gopls/internal/analysis/modernize: replace loop with slices.Contains
This CL adds a modernizer pass for slices.Contains{,Func}.
Example:
func assignTrueBreak(slice []int, needle int) {
found := false
for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
if elem == needle {
found = true
break
}
}
print(found)
}
=>
func assignTrueBreak(slice []int, needle int) {
found := slices.Contains(slice, needle)
print(found)
}
Updates golang/go#70815
Change-Id: I72ad1c099481b6c9ae6f732e2d81674a98b79a9f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/640576
Auto-Submit: Alan Donovan <adonovan@google.com>
Reviewed-by: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Commit-Queue: Alan Donovan <adonovan@google.com>
diff --git a/gopls/internal/analysis/modernize/modernize.go b/gopls/internal/analysis/modernize/modernize.go
index 3734618..6cedc5e 100644
--- a/gopls/internal/analysis/modernize/modernize.go
+++ b/gopls/internal/analysis/modernize/modernize.go
@@ -49,6 +49,9 @@
}
report := pass.Report
pass.Report = func(diag analysis.Diagnostic) {
+ if diag.Category == "" {
+ panic("Diagnostic.Category is unset")
+ }
if _, ok := generated[pass.Fset.File(diag.Pos)]; ok {
return // skip checking if it's generated code
}
@@ -62,6 +65,7 @@
fmtappendf(pass)
mapsloop(pass)
minmax(pass)
+ slicescontains(pass)
sortslice(pass)
testingContext(pass)
@@ -120,7 +124,9 @@
builtinAny = types.Universe.Lookup("any")
builtinAppend = types.Universe.Lookup("append")
builtinBool = types.Universe.Lookup("bool")
+ builtinFalse = types.Universe.Lookup("false")
builtinMake = types.Universe.Lookup("make")
builtinNil = types.Universe.Lookup("nil")
+ builtinTrue = types.Universe.Lookup("true")
byteSliceType = types.NewSlice(types.Typ[types.Byte])
)
diff --git a/gopls/internal/analysis/modernize/modernize_test.go b/gopls/internal/analysis/modernize/modernize_test.go
index bf3114e..d8d2d9a 100644
--- a/gopls/internal/analysis/modernize/modernize_test.go
+++ b/gopls/internal/analysis/modernize/modernize_test.go
@@ -19,6 +19,7 @@
"fmtappendf",
"mapsloop",
"minmax",
+ "slicescontains",
"sortslice",
"testingcontext",
)
diff --git a/gopls/internal/analysis/modernize/slices.go b/gopls/internal/analysis/modernize/slices.go
index 1389298..cb73f7e 100644
--- a/gopls/internal/analysis/modernize/slices.go
+++ b/gopls/internal/analysis/modernize/slices.go
@@ -5,6 +5,7 @@
package modernize
// This file defines modernizers that use the "slices" package.
+// TODO(adonovan): actually let's split them up and rename this file.
import (
"fmt"
diff --git a/gopls/internal/analysis/modernize/slicescontains.go b/gopls/internal/analysis/modernize/slicescontains.go
new file mode 100644
index 0000000..062083c
--- /dev/null
+++ b/gopls/internal/analysis/modernize/slicescontains.go
@@ -0,0 +1,365 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package modernize
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "go/types"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/inspect"
+ "golang.org/x/tools/go/ast/inspector"
+ "golang.org/x/tools/go/types/typeutil"
+ "golang.org/x/tools/internal/analysisinternal"
+ "golang.org/x/tools/internal/astutil/cursor"
+ "golang.org/x/tools/internal/typeparams"
+)
+
+// The slicescontains pass identifies loops that can be replaced by a
+// call to slices.Contains{,Func}. For example:
+//
+// for i, elem := range s {
+// if elem == needle {
+// ...
+// break
+// }
+// }
+//
+// =>
+//
+// if slices.Contains(s, needle) { ... }
+//
+// Variants:
+// - if the if-condition is f(elem), the replacement
+// uses slices.ContainsFunc(s, f).
+// - if the if-body is "return true" and the fallthrough
+// statement is "return false" (or vice versa), the
+// loop becomes "return [!]slices.Contains(...)".
+// - if the if-body is "found = true" and the previous
+// statement is "found = false" (or vice versa), the
+// loop becomes "found = [!]slices.Contains(...)".
+//
+// It may change cardinality of effects of the "needle" expression.
+// (Mostly this appears to be a desirable optimization, avoiding
+// redundantly repeated evaluation.)
+func slicescontains(pass *analysis.Pass) {
+ // Don't modify the slices package itself.
+ if pass.Pkg.Path() == "slices" {
+ return
+ }
+
+ info := pass.TypesInfo
+
+ // check is called for each RangeStmt of this form:
+ // for i, elem := range s { if cond { ... } }
+ check := func(file *ast.File, curRange cursor.Cursor) {
+ rng := curRange.Node().(*ast.RangeStmt)
+ ifStmt := rng.Body.List[0].(*ast.IfStmt)
+
+ // isSliceElem reports whether e denotes the
+ // current slice element (elem or s[i]).
+ isSliceElem := func(e ast.Expr) bool {
+ if rng.Value != nil && equalSyntax(e, rng.Value) {
+ return true // "elem"
+ }
+ if x, ok := e.(*ast.IndexExpr); ok &&
+ equalSyntax(x.X, rng.X) &&
+ equalSyntax(x.Index, rng.Key) {
+ return true // "s[i]"
+ }
+ return false
+ }
+
+ // Examine the condition for one of these forms:
+ //
+ // - if elem or s[i] == needle { ... } => Contains
+ // - if predicate(s[i] or elem) { ... } => ContainsFunc
+ var (
+ funcName string // "Contains" or "ContainsFunc"
+ arg2 ast.Expr // second argument to func (needle or predicate)
+ )
+ switch cond := ifStmt.Cond.(type) {
+ case *ast.BinaryExpr:
+ if cond.Op == token.EQL {
+ if isSliceElem(cond.X) {
+ funcName = "Contains"
+ arg2 = cond.Y // "if elem == needle"
+ } else if isSliceElem(cond.Y) {
+ funcName = "Contains"
+ arg2 = cond.X // "if needle == elem"
+ }
+ }
+
+ case *ast.CallExpr:
+ if len(cond.Args) == 1 &&
+ isSliceElem(cond.Args[0]) &&
+ typeutil.Callee(info, cond) != nil { // not a conversion
+
+ funcName = "ContainsFunc"
+ arg2 = cond.Fun // "if predicate(elem)"
+ }
+ }
+ if funcName == "" {
+ return // not a candidate for Contains{,Func}
+ }
+
+ // body is the "true" body.
+ body := ifStmt.Body
+ if len(body.List) == 0 {
+ // (We could perhaps delete the loop entirely.)
+ return
+ }
+
+ // Reject if the body, needle or predicate references either range variable.
+ usesRangeVar := func(n ast.Node) bool {
+ cur, ok := curRange.FindNode(n)
+ if !ok {
+ panic(fmt.Sprintf("FindNode(%T) failed", n))
+ }
+ return uses(info, cur, info.Defs[rng.Key.(*ast.Ident)]) ||
+ rng.Value != nil && uses(info, cur, info.Defs[rng.Value.(*ast.Ident)])
+ }
+ if usesRangeVar(body) {
+ // Body uses range var "i" or "elem".
+ //
+ // (The check for "i" could be relaxed when we
+ // generalize this to support slices.Index;
+ // and the check for "elem" could be relaxed
+ // if "elem" can safely be replaced in the
+ // body by "needle".)
+ return
+ }
+ if usesRangeVar(arg2) {
+ return
+ }
+
+ // Prepare slices.Contains{,Func} call.
+ slicesName, importEdits := analysisinternal.AddImport(info, file, rng.Pos(), "slices", "slices")
+ contains := fmt.Sprintf("%s.%s(%s, %s)",
+ slicesName,
+ funcName,
+ analysisinternal.Format(pass.Fset, rng.X),
+ analysisinternal.Format(pass.Fset, arg2))
+
+ report := func(edits []analysis.TextEdit) {
+ pass.Report(analysis.Diagnostic{
+ Pos: rng.Pos(),
+ End: rng.End(),
+ Category: "slicescontains",
+ Message: fmt.Sprintf("Loop can be simplified using slices.%s", funcName),
+ SuggestedFixes: []analysis.SuggestedFix{{
+ Message: "Replace loop by call to slices." + funcName,
+ TextEdits: append(edits, importEdits...),
+ }},
+ })
+ }
+
+ // Last statement of body must return/break out of the loop.
+ curBody, _ := curRange.FindNode(body)
+ curLastStmt, _ := curBody.LastChild()
+
+ // Reject if any statement in the body except the
+ // last has a free continuation (continue or break)
+ // that might be affected by melting down the loop.
+ //
+ // TODO(adonovan): relax check by analyzing branch target.
+ for curBodyStmt := range curBody.Children() {
+ if curBodyStmt != curLastStmt {
+ for range curBodyStmt.Preorder((*ast.BranchStmt)(nil), (*ast.ReturnStmt)(nil)) {
+ return
+ }
+ }
+ }
+
+ switch lastStmt := curLastStmt.Node().(type) {
+ case *ast.ReturnStmt:
+ // Have: for ... range seq { if ... { stmts; return x } }
+
+ // Special case:
+ // body={ return true } next="return false" (or negation)
+ // => return [!]slices.Contains(...)
+ if curNext, ok := curRange.NextSibling(); ok {
+ nextStmt := curNext.Node().(ast.Stmt)
+ tval := isReturnTrueOrFalse(info, lastStmt)
+ fval := isReturnTrueOrFalse(info, nextStmt)
+ if len(body.List) == 1 && tval*fval < 0 {
+ // for ... { if ... { return true/false } }
+ // => return [!]slices.Contains(...)
+ report([]analysis.TextEdit{
+ // Delete the range statement and following space.
+ {
+ Pos: rng.Pos(),
+ End: nextStmt.Pos(),
+ },
+ // Change return to [!]slices.Contains(...).
+ {
+ Pos: nextStmt.Pos(),
+ End: nextStmt.End(),
+ NewText: fmt.Appendf(nil, "return %s%s",
+ cond(tval > 0, "", "!"),
+ contains),
+ },
+ })
+ return
+ }
+ }
+
+ // General case:
+ // => if slices.Contains(...) { stmts; return x }
+ report([]analysis.TextEdit{
+ // Replace "for ... { if ... " with "if slices.Contains(...)".
+ {
+ Pos: rng.Pos(),
+ End: ifStmt.Body.Pos(),
+ NewText: fmt.Appendf(nil, "if %s ", contains),
+ },
+ // Delete '}' of range statement and preceding space.
+ {
+ Pos: ifStmt.Body.End(),
+ End: rng.End(),
+ },
+ })
+ return
+
+ case *ast.BranchStmt:
+ if lastStmt.Tok == token.BREAK && lastStmt.Label == nil { // unlabeled break
+ // Have: for ... { if ... { stmts; break } }
+
+ var prevStmt ast.Stmt // previous statement to range (if any)
+ if curPrev, ok := curRange.PrevSibling(); ok {
+ // If the RangeStmt's previous sibling is a Stmt,
+ // the RangeStmt must be among the Body list of
+ // a BlockStmt, CaseClause, or CommClause.
+ // In all cases, the prevStmt is the immediate
+ // predecessor of the RangeStmt during execution.
+ //
+ // (This is not true for Stmts in general;
+ // see [Cursor.Children] and #71074.)
+ prevStmt, _ = curPrev.Node().(ast.Stmt)
+ }
+
+ // Special case:
+ // prev="lhs = false" body={ lhs = true; break }
+ // => lhs = slices.Contains(...) (or negation)
+ if assign, ok := body.List[0].(*ast.AssignStmt); ok &&
+ len(body.List) == 2 &&
+ assign.Tok == token.ASSIGN &&
+ len(assign.Lhs) == 1 &&
+ len(assign.Rhs) == 1 {
+
+ // Have: body={ lhs = rhs; break }
+
+ if prevAssign, ok := prevStmt.(*ast.AssignStmt); ok &&
+ len(prevAssign.Lhs) == 1 &&
+ len(prevAssign.Rhs) == 1 &&
+ equalSyntax(prevAssign.Lhs[0], assign.Lhs[0]) &&
+ is[*ast.Ident](assign.Rhs[0]) &&
+ info.Uses[assign.Rhs[0].(*ast.Ident)] == builtinTrue {
+
+ // Have:
+ // lhs = false (NOTE: assumed, not verified; prevAssign.Rhs[0] is never inspected)
+ // for ... { if ... { lhs = true; break } }
+ // =>
+ // lhs = slices.Contains(...)
+ //
+ // TODO(adonovan):
+ // - support "var lhs bool = false" and variants.
+ // - support negation.
+ // Both these variants seem quite significant.
+ // - allow the break to be omitted.
+ report([]analysis.TextEdit{
+ // Replace "rhs" of previous assignment by slices.Contains(...)
+ {
+ Pos: prevAssign.Rhs[0].Pos(),
+ End: prevAssign.Rhs[0].End(),
+ NewText: []byte(contains),
+ },
+ // Delete the loop and preceding space.
+ {
+ Pos: prevAssign.Rhs[0].End(),
+ End: rng.End(),
+ },
+ })
+ return
+ }
+ }
+
+ // General case:
+ // for ... { if ... { stmts; break } }
+ // => if slices.Contains(...) { stmts }
+ report([]analysis.TextEdit{
+ // Replace "for ... { if ... " with "if slices.Contains(...)".
+ {
+ Pos: rng.Pos(),
+ End: ifStmt.Body.Pos(),
+ NewText: fmt.Appendf(nil, "if %s ", contains),
+ },
+ // Delete break statement and preceding space.
+ {
+ Pos: func() token.Pos {
+ if len(body.List) > 1 {
+ beforeBreak, _ := curLastStmt.PrevSibling()
+ return beforeBreak.Node().End()
+ }
+ return lastStmt.Pos()
+ }(),
+ End: lastStmt.End(),
+ },
+ // Delete '}' of range statement and preceding space.
+ {
+ Pos: ifStmt.Body.End(),
+ End: rng.End(),
+ },
+ })
+ return
+ }
+ }
+ }
+
+ inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
+ for curFile := range filesUsing(inspect, info, "go1.21") {
+ file := curFile.Node().(*ast.File)
+
+ for curRange := range curFile.Preorder((*ast.RangeStmt)(nil)) {
+ rng := curRange.Node().(*ast.RangeStmt)
+
+ if is[*ast.Ident](rng.Key) &&
+ rng.Tok == token.DEFINE &&
+ len(rng.Body.List) == 1 &&
+ is[*types.Slice](typeparams.CoreType(info.TypeOf(rng.X))) {
+
+ // Have:
+ // - for _, elem := range s { S }
+ // - for i := range s { S }
+
+ if ifStmt, ok := rng.Body.List[0].(*ast.IfStmt); ok &&
+ ifStmt.Init == nil && ifStmt.Else == nil {
+
+ // Have: for i, elem := range s { if cond { ... } }
+ check(file, curRange)
+ }
+ }
+ }
+ }
+}
+
+// -- helpers --
+
+// isReturnTrueOrFalse returns nonzero if stmt returns true (+1) or false (-1).
+func isReturnTrueOrFalse(info *types.Info, stmt ast.Stmt) int {
+ if ret, ok := stmt.(*ast.ReturnStmt); ok && len(ret.Results) == 1 {
+ if id, ok := ret.Results[0].(*ast.Ident); ok {
+ switch info.Uses[id] {
+ case builtinTrue:
+ return +1
+ case builtinFalse:
+ return -1
+ }
+ }
+ }
+ return 0
+}
diff --git a/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go
new file mode 100644
index 0000000..ecb7371
--- /dev/null
+++ b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go
@@ -0,0 +1,129 @@
+package slicescontains
+
+import "slices"
+
+var _ = slices.Contains[[]int] // force import of "slices" to avoid duplicate import edits
+
+func nopeNoBreak(slice []int, needle int) {
+ for i := range slice {
+ if slice[i] == needle {
+ println("found")
+ }
+ }
+}
+
+func rangeIndex(slice []int, needle int) {
+ for i := range slice { // want "Loop can be simplified using slices.Contains"
+ if slice[i] == needle {
+ println("found")
+ break
+ }
+ }
+}
+
+func rangeValue(slice []int, needle int) {
+ for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+ if elem == needle {
+ println("found")
+ break
+ }
+ }
+}
+
+func returns(slice []int, needle int) {
+ for i := range slice { // want "Loop can be simplified using slices.Contains"
+ if slice[i] == needle {
+ println("found")
+ return
+ }
+ }
+}
+
+func assignTrueBreak(slice []int, needle int) {
+ found := false
+ for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+ if elem == needle {
+ found = true
+ break
+ }
+ }
+ print(found)
+}
+
+func assignFalseBreak(slice []int, needle int) { // TODO: treat this specially like booleanTrue
+ found := true
+ for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+ if elem == needle {
+ found = false
+ break
+ }
+ }
+ print(found)
+}
+
+func assignFalseBreakInSelectSwitch(slice []int, needle int) {
+ // Exercise RangeStmt in CommClause, CaseClause.
+ select {
+ default:
+ found := false
+ for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+ if elem == needle {
+ found = true
+ break
+ }
+ }
+ print(found)
+ }
+ switch {
+ default:
+ found := false
+ for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+ if elem == needle {
+ found = true
+ break
+ }
+ }
+ print(found)
+ }
+}
+
+func returnTrue(slice []int, needle int) bool {
+ for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+ if elem == needle {
+ return true
+ }
+ }
+ return false
+}
+
+func returnFalse(slice []int, needle int) bool {
+ for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+ if elem == needle {
+ return false
+ }
+ }
+ return true
+}
+
+func containsFunc(slice []int, needle int) bool {
+ for _, elem := range slice { // want "Loop can be simplified using slices.ContainsFunc"
+ if predicate(elem) {
+ return true
+ }
+ }
+ return false
+}
+
+func nopeLoopBodyHasFreeContinuation(slice []int, needle int) bool {
+ for _, elem := range slice {
+ if predicate(elem) {
+ if needle == 7 {
+ continue // this statement defeats loop elimination
+ }
+ return true
+ }
+ }
+ return false
+}
+
+func predicate(int) bool
diff --git a/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden
new file mode 100644
index 0000000..561e42f
--- /dev/null
+++ b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden
@@ -0,0 +1,85 @@
+package slicescontains
+
+import "slices"
+
+var _ = slices.Contains[[]int] // force import of "slices" to avoid duplicate import edits
+
+func nopeNoBreak(slice []int, needle int) {
+ for i := range slice {
+ if slice[i] == needle {
+ println("found")
+ }
+ }
+}
+
+func rangeIndex(slice []int, needle int) {
+ if slices.Contains(slice, needle) {
+ println("found")
+ }
+}
+
+func rangeValue(slice []int, needle int) {
+ if slices.Contains(slice, needle) {
+ println("found")
+ }
+}
+
+func returns(slice []int, needle int) {
+ if slices.Contains(slice, needle) {
+ println("found")
+ return
+ }
+}
+
+func assignTrueBreak(slice []int, needle int) {
+ found := slices.Contains(slice, needle)
+ print(found)
+}
+
+func assignFalseBreak(slice []int, needle int) { // TODO: treat this specially like booleanTrue
+ found := true
+ if slices.Contains(slice, needle) {
+ found = false
+ }
+ print(found)
+}
+
+func assignFalseBreakInSelectSwitch(slice []int, needle int) {
+ // Exercise RangeStmt in CommClause, CaseClause.
+ select {
+ default:
+ found := slices.Contains(slice, needle)
+ print(found)
+ }
+ switch {
+ default:
+ found := slices.Contains(slice, needle)
+ print(found)
+ }
+}
+
+func returnTrue(slice []int, needle int) bool {
+ return slices.Contains(slice, needle)
+}
+
+func returnFalse(slice []int, needle int) bool {
+ return !slices.Contains(slice, needle)
+}
+
+func containsFunc(slice []int, needle int) bool {
+ return slices.ContainsFunc(slice, predicate)
+}
+
+func nopeLoopBodyHasFreeContinuation(slice []int, needle int) bool {
+ for _, elem := range slice {
+ if predicate(elem) {
+ if needle == 7 {
+ continue // this statement defeats loop elimination
+ }
+ return true
+ }
+ }
+ return false
+}
+
+func predicate(int) bool
diff --git a/internal/astutil/cursor/cursor.go b/internal/astutil/cursor/cursor.go
index 89dd641..24fec99 100644
--- a/internal/astutil/cursor/cursor.go
+++ b/internal/astutil/cursor/cursor.go
@@ -304,8 +304,10 @@
// - [ast.AssignStmt] (Lhs, Rhs)
//
// So, do not assume that the previous sibling of an ast.Stmt is also
-// an ast.Stmt unless you have established that, say, its parent is a
-// BlockStmt.
+// an ast.Stmt, or if it is, that they are executed sequentially,
+// unless you have established that, say, its parent is a BlockStmt.
+// For example, given "for S1; ; S2 {}", the predecessor of S2 is S1,
+// even though they are not executed in sequence.
func (c Cursor) Children() iter.Seq[Cursor] {
return func(yield func(Cursor) bool) {
c, ok := c.FirstChild()