gopls/internal/analysis/modernize: replace loop with slices.Contains

This CL adds a modernizer pass for slices.Contains{,Func}.

Example:

func assignTrueBreak(slice []int, needle int) {
	found := false
	for _, elem := range slice { // want "Loop can be simplified using strings.Contains"
		if elem == needle {
			found = true
			break
		}
	}
	print(found)
}

=>

func assignTrueBreak(slice []int, needle int) {
	found := slices.Contains(slice, needle)
	print(found)
}

Updates golang/go#70815

Change-Id: I72ad1c099481b6c9ae6f732e2d81674a98b79a9f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/640576
Auto-Submit: Alan Donovan <adonovan@google.com>
Reviewed-by: Robert Findley <rfindley@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Commit-Queue: Alan Donovan <adonovan@google.com>
diff --git a/gopls/internal/analysis/modernize/modernize.go b/gopls/internal/analysis/modernize/modernize.go
index 3734618..6cedc5e 100644
--- a/gopls/internal/analysis/modernize/modernize.go
+++ b/gopls/internal/analysis/modernize/modernize.go
@@ -49,6 +49,9 @@
 		}
 		report := pass.Report
 		pass.Report = func(diag analysis.Diagnostic) {
+			if diag.Category == "" {
+				panic("Diagnostic.Category is unset")
+			}
 			if _, ok := generated[pass.Fset.File(diag.Pos)]; ok {
 				return // skip checking if it's generated code
 			}
@@ -62,6 +65,7 @@
 	fmtappendf(pass)
 	mapsloop(pass)
 	minmax(pass)
+	slicescontains(pass)
 	sortslice(pass)
 	testingContext(pass)
 
@@ -120,7 +124,9 @@
 	builtinAny    = types.Universe.Lookup("any")
 	builtinAppend = types.Universe.Lookup("append")
 	builtinBool   = types.Universe.Lookup("bool")
+	builtinFalse  = types.Universe.Lookup("false")
 	builtinMake   = types.Universe.Lookup("make")
 	builtinNil    = types.Universe.Lookup("nil")
+	builtinTrue   = types.Universe.Lookup("true")
 	byteSliceType = types.NewSlice(types.Typ[types.Byte])
 )
diff --git a/gopls/internal/analysis/modernize/modernize_test.go b/gopls/internal/analysis/modernize/modernize_test.go
index bf3114e..d8d2d9a 100644
--- a/gopls/internal/analysis/modernize/modernize_test.go
+++ b/gopls/internal/analysis/modernize/modernize_test.go
@@ -19,6 +19,7 @@
 		"fmtappendf",
 		"mapsloop",
 		"minmax",
+		"slicescontains",
 		"sortslice",
 		"testingcontext",
 	)
diff --git a/gopls/internal/analysis/modernize/slices.go b/gopls/internal/analysis/modernize/slices.go
index 1389298..cb73f7e 100644
--- a/gopls/internal/analysis/modernize/slices.go
+++ b/gopls/internal/analysis/modernize/slices.go
@@ -5,6 +5,7 @@
 package modernize
 
 // This file defines modernizers that use the "slices" package.
+// TODO(adonovan): actually let's split them up and rename this file.
 
 import (
 	"fmt"
diff --git a/gopls/internal/analysis/modernize/slicescontains.go b/gopls/internal/analysis/modernize/slicescontains.go
new file mode 100644
index 0000000..062083c
--- /dev/null
+++ b/gopls/internal/analysis/modernize/slicescontains.go
@@ -0,0 +1,365 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package modernize
+
+import (
+	"fmt"
+	"go/ast"
+	"go/token"
+	"go/types"
+
+	"golang.org/x/tools/go/analysis"
+	"golang.org/x/tools/go/analysis/passes/inspect"
+	"golang.org/x/tools/go/ast/inspector"
+	"golang.org/x/tools/go/types/typeutil"
+	"golang.org/x/tools/internal/analysisinternal"
+	"golang.org/x/tools/internal/astutil/cursor"
+	"golang.org/x/tools/internal/typeparams"
+)
+
+// The slicescontains pass identifies loops that can be replaced by a
+// call to slices.Contains{,Func}. For example:
+//
+//	for i, elem := range s {
+//		if elem == needle {
+//			...
+//			break
+//		}
+//	}
+//
+// =>
+//
+//	if slices.Contains(s, needle) { ... }
+//
+// Variants:
+//   - if the if-condition is f(elem), the replacement
+//     uses slices.ContainsFunc(s, f).
+//   - if the if-body is "return true" and the fallthrough
+//     statement is "return false" (or vice versa), the
+//     loop becomes "return [!]slices.Contains(...)".
+//   - if the if-body is "found = true" and the previous
+//     statement is "found = false" (or vice versa), the
+//     loop becomes "found = [!]slices.Contains(...)".
+//
+// It may change cardinality of effects of the "needle" expression.
+// (Mostly this appears to be a desirable optimization, avoiding
+// redundantly repeated evaluation.)
+func slicescontains(pass *analysis.Pass) {
+	// Don't modify the slices package itself.
+	if pass.Pkg.Path() == "slices" {
+		return
+	}
+
+	info := pass.TypesInfo
+
+	// check is called for each RangeStmt of this form:
+	//   for i, elem := range s { if cond { ... } }
+	check := func(file *ast.File, curRange cursor.Cursor) {
+		rng := curRange.Node().(*ast.RangeStmt)
+		ifStmt := rng.Body.List[0].(*ast.IfStmt)
+
+		// isSliceElem reports whether e denotes the
+		// current slice element (elem or s[i]).
+		isSliceElem := func(e ast.Expr) bool {
+			if rng.Value != nil && equalSyntax(e, rng.Value) {
+				return true // "elem"
+			}
+			if x, ok := e.(*ast.IndexExpr); ok &&
+				equalSyntax(x.X, rng.X) &&
+				equalSyntax(x.Index, rng.Key) {
+				return true // "s[i]"
+			}
+			return false
+		}
+
+		// Examine the condition for one of these forms:
+		//
+		// - if elem or s[i] == needle  { ... } => Contains
+		// - if predicate(s[i] or elem) { ... } => ContainsFunc
+		var (
+			funcName string   // "Contains" or "ContainsFunc"
+			arg2     ast.Expr // second argument to func (needle or predicate)
+		)
+		switch cond := ifStmt.Cond.(type) {
+		case *ast.BinaryExpr:
+			if cond.Op == token.EQL {
+				if isSliceElem(cond.X) {
+					funcName = "Contains"
+					arg2 = cond.Y // "if elem == needle"
+				} else if isSliceElem(cond.Y) {
+					funcName = "Contains"
+					arg2 = cond.X // "if needle == elem"
+				}
+			}
+
+		case *ast.CallExpr:
+			if len(cond.Args) == 1 &&
+				isSliceElem(cond.Args[0]) &&
+				typeutil.Callee(info, cond) != nil { // not a conversion
+
+				funcName = "ContainsFunc"
+				arg2 = cond.Fun // "if predicate(elem)"
+			}
+		}
+		if funcName == "" {
+			return // not a candidate for Contains{,Func}
+		}
+
+		// body is the "true" body.
+		body := ifStmt.Body
+		if len(body.List) == 0 {
+			// (We could perhaps delete the loop entirely.)
+			return
+		}
+
+		// Reject if the body, needle or predicate references either range variable.
+		usesRangeVar := func(n ast.Node) bool {
+			cur, ok := curRange.FindNode(n)
+			if !ok {
+				panic(fmt.Sprintf("FindNode(%T) failed", n))
+			}
+			return uses(info, cur, info.Defs[rng.Key.(*ast.Ident)]) ||
+				rng.Value != nil && uses(info, cur, info.Defs[rng.Value.(*ast.Ident)])
+		}
+		if usesRangeVar(body) {
+			// Body uses range var "i" or "elem".
+			//
+			// (The check for "i" could be relaxed when we
+			// generalize this to support slices.Index;
+			// and the check for "elem" could be relaxed
+			// if "elem" can safely be replaced in the
+			// body by "needle".)
+			return
+		}
+		if usesRangeVar(arg2) {
+			return
+		}
+
+		// Prepare slices.Contains{,Func} call.
+		slicesName, importEdits := analysisinternal.AddImport(info, file, rng.Pos(), "slices", "slices")
+		contains := fmt.Sprintf("%s.%s(%s, %s)",
+			slicesName,
+			funcName,
+			analysisinternal.Format(pass.Fset, rng.X),
+			analysisinternal.Format(pass.Fset, arg2))
+
+		report := func(edits []analysis.TextEdit) {
+			pass.Report(analysis.Diagnostic{
+				Pos:      rng.Pos(),
+				End:      rng.End(),
+				Category: "slicescontains",
+				Message:  fmt.Sprintf("Loop can be simplified using slices.%s", funcName),
+				SuggestedFixes: []analysis.SuggestedFix{{
+					Message:   "Replace loop by call to slices." + funcName,
+					TextEdits: append(edits, importEdits...),
+				}},
+			})
+		}
+
+		// Last statement of body must return/break out of the loop.
+		curBody, _ := curRange.FindNode(body)
+		curLastStmt, _ := curBody.LastChild()
+
+		// Reject if any statement in the body except the
+		// last has a free continuation (continue or break)
+		// that might affected by melting down the loop.
+		//
+		// TODO(adonovan): relax check by analyzing branch target.
+		for curBodyStmt := range curBody.Children() {
+			if curBodyStmt != curLastStmt {
+				for range curBodyStmt.Preorder((*ast.BranchStmt)(nil), (*ast.ReturnStmt)(nil)) {
+					return
+				}
+			}
+		}
+
+		switch lastStmt := curLastStmt.Node().(type) {
+		case *ast.ReturnStmt:
+			// Have: for ... range seq { if ... { stmts; return x } }
+
+			// Special case:
+			// body={ return true } next="return false"   (or negation)
+			// => return [!]slices.Contains(...)
+			if curNext, ok := curRange.NextSibling(); ok {
+				nextStmt := curNext.Node().(ast.Stmt)
+				tval := isReturnTrueOrFalse(info, lastStmt)
+				fval := isReturnTrueOrFalse(info, nextStmt)
+				if len(body.List) == 1 && tval*fval < 0 {
+					//    for ... { if ... { return true/false } }
+					// => return [!]slices.Contains(...)
+					report([]analysis.TextEdit{
+						// Delete the range statement and following space.
+						{
+							Pos: rng.Pos(),
+							End: nextStmt.Pos(),
+						},
+						// Change return to [!]slices.Contains(...).
+						{
+							Pos: nextStmt.Pos(),
+							End: nextStmt.End(),
+							NewText: fmt.Appendf(nil, "return %s%s",
+								cond(tval > 0, "", "!"),
+								contains),
+						},
+					})
+					return
+				}
+			}
+
+			// General case:
+			// => if slices.Contains(...) { stmts; return x }
+			report([]analysis.TextEdit{
+				// Replace "for ... { if ... " with "if slices.Contains(...)".
+				{
+					Pos:     rng.Pos(),
+					End:     ifStmt.Body.Pos(),
+					NewText: fmt.Appendf(nil, "if %s ", contains),
+				},
+				// Delete '}' of range statement and preceding space.
+				{
+					Pos: ifStmt.Body.End(),
+					End: rng.End(),
+				},
+			})
+			return
+
+		case *ast.BranchStmt:
+			if lastStmt.Tok == token.BREAK && lastStmt.Label == nil { // unlabeled break
+				// Have: for ... { if ... { stmts; break } }
+
+				var prevStmt ast.Stmt // previous statement to range (if any)
+				if curPrev, ok := curRange.PrevSibling(); ok {
+					// If the RangeStmt's previous sibling is a Stmt,
+					// the RangeStmt must be among the Body list of
+					// a BlockStmt, CauseClause, or CommClause.
+					// In all cases, the prevStmt is the immediate
+					// predecessor of the RangeStmt during execution.
+					//
+					// (This is not true for Stmts in general;
+					// see [Cursor.Children] and #71074.)
+					prevStmt, _ = curPrev.Node().(ast.Stmt)
+				}
+
+				// Special case:
+				// prev="lhs = false" body={ lhs = true; break }
+				// => lhs = slices.Contains(...) (or negation)
+				if assign, ok := body.List[0].(*ast.AssignStmt); ok &&
+					len(body.List) == 2 &&
+					assign.Tok == token.ASSIGN &&
+					len(assign.Lhs) == 1 &&
+					len(assign.Rhs) == 1 {
+
+					// Have: body={ lhs = rhs; break }
+
+					if prevAssign, ok := prevStmt.(*ast.AssignStmt); ok &&
+						len(prevAssign.Lhs) == 1 &&
+						len(prevAssign.Rhs) == 1 &&
+						equalSyntax(prevAssign.Lhs[0], assign.Lhs[0]) &&
+						is[*ast.Ident](assign.Rhs[0]) &&
+						info.Uses[assign.Rhs[0].(*ast.Ident)] == builtinTrue {
+
+						// Have:
+						//    lhs = false
+						//    for ... { if ... { lhs = true; break } }
+						//  =>
+						//    lhs = slices.Contains(...)
+						//
+						// TODO(adonovan):
+						// - support "var lhs bool = false" and variants.
+						// - support negation.
+						// Both these variants seem quite significant.
+						// - allow the break to be omitted.
+						report([]analysis.TextEdit{
+							// Replace "rhs" of previous assignment by slices.Contains(...)
+							{
+								Pos:     prevAssign.Rhs[0].Pos(),
+								End:     prevAssign.Rhs[0].End(),
+								NewText: []byte(contains),
+							},
+							// Delete the loop and preceding space.
+							{
+								Pos: prevAssign.Rhs[0].End(),
+								End: rng.End(),
+							},
+						})
+						return
+					}
+				}
+
+				// General case:
+				//    for ... { if ...        { stmts; break } }
+				// => if slices.Contains(...) { stmts        }
+				report([]analysis.TextEdit{
+					// Replace "for ... { if ... " with "if slices.Contains(...)".
+					{
+						Pos:     rng.Pos(),
+						End:     ifStmt.Body.Pos(),
+						NewText: fmt.Appendf(nil, "if %s ", contains),
+					},
+					// Delete break statement and preceding space.
+					{
+						Pos: func() token.Pos {
+							if len(body.List) > 1 {
+								beforeBreak, _ := curLastStmt.PrevSibling()
+								return beforeBreak.Node().End()
+							}
+							return lastStmt.Pos()
+						}(),
+						End: lastStmt.End(),
+					},
+					// Delete '}' of range statement and preceding space.
+					{
+						Pos: ifStmt.Body.End(),
+						End: rng.End(),
+					},
+				})
+				return
+			}
+		}
+	}
+
+	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
+	for curFile := range filesUsing(inspect, info, "go1.21") {
+		file := curFile.Node().(*ast.File)
+
+		for curRange := range curFile.Preorder((*ast.RangeStmt)(nil)) {
+			rng := curRange.Node().(*ast.RangeStmt)
+
+			if is[*ast.Ident](rng.Key) &&
+				rng.Tok == token.DEFINE &&
+				len(rng.Body.List) == 1 &&
+				is[*types.Slice](typeparams.CoreType(info.TypeOf(rng.X))) {
+
+				// Have:
+				// - for _, elem := range s { S }
+				// - for i       := range s { S }
+
+				if ifStmt, ok := rng.Body.List[0].(*ast.IfStmt); ok &&
+					ifStmt.Init == nil && ifStmt.Else == nil {
+
+					// Have: for i, elem := range s { if cond { ... } }
+					check(file, curRange)
+				}
+			}
+		}
+	}
+}
+
+// -- helpers --
+
+// isReturnTrueOrFalse returns nonzero if stmt returns true (+1) or false (-1).
+func isReturnTrueOrFalse(info *types.Info, stmt ast.Stmt) int {
+	if ret, ok := stmt.(*ast.ReturnStmt); ok && len(ret.Results) == 1 {
+		if id, ok := ret.Results[0].(*ast.Ident); ok {
+			switch info.Uses[id] {
+			case builtinTrue:
+				return +1
+			case builtinFalse:
+				return -1
+			}
+		}
+	}
+	return 0
+}
diff --git a/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go
new file mode 100644
index 0000000..ecb7371
--- /dev/null
+++ b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go
@@ -0,0 +1,129 @@
+package slicescontains
+
+import "slices"
+
+var _ = slices.Contains[[]int] // force import of "slices" to avoid duplicate import edits
+
+func nopeNoBreak(slice []int, needle int) {
+	for i := range slice {
+		if slice[i] == needle {
+			println("found")
+		}
+	}
+}
+
+func rangeIndex(slice []int, needle int) {
+	for i := range slice { // want "Loop can be simplified using slices.Contains"
+		if slice[i] == needle {
+			println("found")
+			break
+		}
+	}
+}
+
+func rangeValue(slice []int, needle int) {
+	for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+		if elem == needle {
+			println("found")
+			break
+		}
+	}
+}
+
+func returns(slice []int, needle int) {
+	for i := range slice { // want "Loop can be simplified using slices.Contains"
+		if slice[i] == needle {
+			println("found")
+			return
+		}
+	}
+}
+
+func assignTrueBreak(slice []int, needle int) {
+	found := false
+	for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+		if elem == needle {
+			found = true
+			break
+		}
+	}
+	print(found)
+}
+
+func assignFalseBreak(slice []int, needle int) { // TODO: treat this specially like booleanTrue
+	found := true
+	for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+		if elem == needle {
+			found = false
+			break
+		}
+	}
+	print(found)
+}
+
+func assignFalseBreakInSelectSwitch(slice []int, needle int) {
+	// Exercise RangeStmt in CommClause, CaseClause.
+	select {
+	default:
+		found := false
+		for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+			if elem == needle {
+				found = true
+				break
+			}
+		}
+		print(found)
+	}
+	switch {
+	default:
+		found := false
+		for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+			if elem == needle {
+				found = true
+				break
+			}
+		}
+		print(found)
+	}
+}
+
+func returnTrue(slice []int, needle int) bool {
+	for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+		if elem == needle {
+			return true
+		}
+	}
+	return false
+}
+
+func returnFalse(slice []int, needle int) bool {
+	for _, elem := range slice { // want "Loop can be simplified using slices.Contains"
+		if elem == needle {
+			return false
+		}
+	}
+	return true
+}
+
+func containsFunc(slice []int, needle int) bool {
+	for _, elem := range slice { // want "Loop can be simplified using slices.ContainsFunc"
+		if predicate(elem) {
+			return true
+		}
+	}
+	return false
+}
+
+func nopeLoopBodyHasFreeContinuation(slice []int, needle int) bool {
+	for _, elem := range slice {
+		if predicate(elem) {
+			if needle == 7 {
+				continue // this statement defeats loop elimination
+			}
+			return true
+		}
+	}
+	return false
+}
+
+func predicate(int) bool
diff --git a/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden
new file mode 100644
index 0000000..561e42f
--- /dev/null
+++ b/gopls/internal/analysis/modernize/testdata/src/slicescontains/slicescontains.go.golden
@@ -0,0 +1,85 @@
+package slicescontains
+
+import "slices"
+
+var _ = slices.Contains[[]int] // force import of "slices" to avoid duplicate import edits
+
+func nopeNoBreak(slice []int, needle int) {
+	for i := range slice {
+		if slice[i] == needle {
+			println("found")
+		}
+	}
+}
+
+func rangeIndex(slice []int, needle int) {
+	if slices.Contains(slice, needle) {
+		println("found")
+	}
+}
+
+func rangeValue(slice []int, needle int) {
+	if slices.Contains(slice, needle) {
+		println("found")
+	}
+}
+
+func returns(slice []int, needle int) {
+	if slices.Contains(slice, needle) {
+		println("found")
+		return
+	}
+}
+
+func assignTrueBreak(slice []int, needle int) {
+	found := slices.Contains(slice, needle)
+	print(found)
+}
+
+func assignFalseBreak(slice []int, needle int) { // TODO: treat this specially like booleanTrue
+	found := true
+	if slices.Contains(slice, needle) {
+		found = false
+	}
+	print(found)
+}
+
+func assignFalseBreakInSelectSwitch(slice []int, needle int) {
+	// Exercise RangeStmt in CommClause, CaseClause.
+	select {
+	default:
+		found := slices.Contains(slice, needle)
+		print(found)
+	}
+	switch {
+	default:
+		found := slices.Contains(slice, needle)
+		print(found)
+	}
+}
+
+func returnTrue(slice []int, needle int) bool {
+	return slices.Contains(slice, needle)
+}
+
+func returnFalse(slice []int, needle int) bool {
+	return !slices.Contains(slice, needle)
+}
+
+func containsFunc(slice []int, needle int) bool {
+	return slices.ContainsFunc(slice, predicate)
+}
+
+func nopeLoopBodyHasFreeContinuation(slice []int, needle int) bool {
+	for _, elem := range slice {
+		if predicate(elem) {
+			if needle == 7 {
+				continue // this statement defeats loop elimination
+			}
+			return true
+		}
+	}
+	return false
+}
+
+func predicate(int) bool
diff --git a/internal/astutil/cursor/cursor.go b/internal/astutil/cursor/cursor.go
index 89dd641..24fec99 100644
--- a/internal/astutil/cursor/cursor.go
+++ b/internal/astutil/cursor/cursor.go
@@ -304,8 +304,10 @@
 // - [ast.AssignStmt] (Lhs, Rhs)
 //
 // So, do not assume that the previous sibling of an ast.Stmt is also
-// an ast.Stmt unless you have established that, say, its parent is a
-// BlockStmt.
+// an ast.Stmt, or if it is, that they are executed sequentially,
+// unless you have established that, say, its parent is a BlockStmt.
+// For example, given "for S1; ; S2 {}", the predecessor of S2 is S1,
+// even though they are not executed in sequence.
 func (c Cursor) Children() iter.Seq[Cursor] {
 	return func(yield func(Cursor) bool) {
 		c, ok := c.FirstChild()