blob: 8d1391e2b0087177be8ad90746e5d8225e47c633 [file]
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package yield
// TODO(adonovan): also check for inefficient code using this pattern:
//
// for x := range seq {
// if !yield(x) {
// break
// }
// }
//
// which should be entirely rewritten as
//
// seq(yield)
//
// to avoid unnecessary range desugaring and chains of dynamic calls.
import (
"cmp"
_ "embed"
"fmt"
"go/ast"
"go/constant"
"go/token"
"go/types"
"iter"
"math/bits"
"slices"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/buildssa"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/gopls/internal/util/moremaps"
"golang.org/x/tools/gopls/internal/util/safetoken"
"golang.org/x/tools/internal/analysis/analyzerutil"
"golang.org/x/tools/internal/flow"
"golang.org/x/tools/internal/typesinternal"
)
// This analyzer uses a classical dataflow analysis to track the set
// of program points that may be reached after a specific yield() call
// must have returned false. It uses the [flow] framework to compute a
// fixed point over the SSA control-flow graph. The lattice value,
// [stateSet], represents a set of facts about the conditions under
// which the current program point _may_ be reached after yield
// returns false. The conditions are the known truth or falsehood of
// selected local Boolean SSA values, specifically constants, yield
// calls, negations, and phis. States for the same yield call are
// merged when their conditions are equal, and a state is dropped when
// a weaker (broader) condition already present makes its stronger
// (stricter) condition redundant.
//
// (An earlier implementation used only sparse dataflow analysis but
// had a number of false positives due to loss of precision when
// control flow joins were materialized as boolean values.)
//
// Note that this is a "may" dataflow analysis: it reports when a
// yield function _may_ be called again without a positive intervening
// check, but it is possible that the check is beyond the ability of
// the representation to detect, perhaps involving sophisticated use
// of booleans, indirect state (not in SSA registers), or multiple
// flow paths some of which are infeasible.
//
// A "must" analysis (which would report when a second yield call can
// only be reached after failing the boolean check) would be too
// conservative. In particular, the most common mistake is to forget
// to check the boolean at all.
//
// The analysis ignores 'go' and 'defer' statements.
//go:embed doc.go
var doc string
var Analyzer = &analysis.Analyzer{
Name: "yield",
Doc: analyzerutil.MustExtractDoc(doc, "yield"),
Requires: []*analysis.Analyzer{inspect.Analyzer, buildssa.Analyzer},
Run: run,
URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/yield",
}
// run is the analyzer's entry point. It indexes the syntactic calls to
// functions named yield and, if there are any, examines the SSA form of
// each source function for yield calls that may execute after an
// earlier yield call returned false.
func run(pass *analysis.Pass) (any, error) {
	// Iterators nearly always mention iter.Seq or iter.Seq2, although
	// nothing requires it; packages that do not import "iter" are skipped.
	if !typesinternal.Imports(pass.Pkg, "iter") {
		return nil, nil
	}

	// Index each syntactic call to a function named "yield" (we assume
	// every yield function has that name) whose signature has fewer
	// than three parameters and a single bool result. The index is
	// keyed by CallExpr.Lparen, the position callSyntax looks up below.
	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
	yieldCalls := make(map[token.Pos]*ast.CallExpr)
	for cur := range insp.Root().Preorder((*ast.CallExpr)(nil)) {
		call := cur.Node().(*ast.CallExpr)
		id, isIdent := call.Fun.(*ast.Ident)
		if !isIdent || id.Name != "yield" {
			continue
		}
		sig, isFunc := pass.TypesInfo.TypeOf(id).(*types.Signature)
		if !isFunc || sig.Params().Len() >= 3 || sig.Results().Len() != 1 {
			continue
		}
		if types.Identical(sig.Results().At(0).Type(), types.Typ[types.Bool]) {
			yieldCalls[call.Lparen] = call
		}
	}
	if len(yieldCalls) == 0 {
		return nil, nil // common case: nothing to do
	}

	callSyntax := func(call *ssa.Call) *ast.CallExpr { return yieldCalls[call.Pos()] }

	// The control-flow analysis proper works on the SSA form.
	ssaInfo := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
	for _, fn := range ssaInfo.SrcFuncs {
		run1(pass, callSyntax, fn)
	}
	return nil, nil
}
// run1 analyzes a single SSA function, fn. It collects the SSA calls
// that correspond to syntactic yield calls (those for which callSyntax
// returns non-nil), solves a forward dataflow problem over fn's
// control-flow graph to find where execution may continue after one of
// those calls returned false, and reports each yield call that may be
// followed by another yield call in that situation.
func run1(pass *analysis.Pass, callSyntax func(call *ssa.Call) *ast.CallExpr, fn *ssa.Function) {
// Find SSA instruction of each yield call.
ssaYieldCalls := make(map[*ssa.Call]bool)
for _, b := range fn.Blocks {
for _, instr := range b.Instrs {
if call, ok := instr.(*ssa.Call); ok && callSyntax(call) != nil {
ssaYieldCalls[call] = true
}
}
}
if len(ssaYieldCalls) == 0 {
return
}
isYieldCall := func(v ssa.Value) bool {
call, ok := v.(*ssa.Call)
return ok && ssaYieldCalls[call]
}
// numb assigns the bit indices used by every state computed for fn.
numb := make(numbering)
// Compute the dataflow solution.
initial := map[int]stateSet{0: nil} // on entry, no states of interest
result := flow.Forward[lattice](fnGraph{fn}, initial, func(fromID, toID int, in stateSet) stateSet {
// The transfer function computes the effect on the abstract
// state of flow along the CFG edge from --> to,
// including the effects of the 'from' block itself
// and of the prefix of phis in the 'to' block.
//
// In effect, the block's state in the framework
// corresponds to the point after its prefix of phis.
var (
from = fn.Blocks[fromID]
to = fn.Blocks[toID]
out = in // (do not mutate in)
)
for _, instr := range from.Instrs {
switch instr := instr.(type) {
case *ssa.Call:
if isYieldCall(instr) {
out = out.yieldCall(numb, instr)
}
case *ssa.UnOp:
if instr.Op == token.NOT {
out = out.not(numb, instr)
}
case *ssa.If:
// An ssa.If branches to Succs[0] when Cond is true,
// so that is the sense assumed along the edge to it.
out = out.if_(numb, instr.Cond, to == from.Succs[0])
}
}
// Process phis in 'to' block, using each phi's operand
// for the incoming edge from 'from'.
//
// NOTE(review): phis are applied one at a time, so a phi whose
// edge value is an earlier phi of the same block reads that phi's
// newly updated fact rather than its value on exit from 'from'.
// SSA phis are parallel assignments (consider a loop that swaps
// two booleans), so this may be imprecise in that case; confirm
// whether it matters.
if i := slices.Index(to.Preds, from); i >= 0 {
for _, instr := range to.Instrs {
if phi, ok := instr.(*ssa.Phi); ok {
out = out.phi(numb, phi, phi.Edges[i])
}
}
}
// Opt: avoid renormalizing 'in' if unchanged.
if len(out) > 0 && !sameSlice(out, in) {
out = normalize(out)
}
return out
})
// Gather the problematic calls: each yield call reached while some
// state (an earlier yield call presumed to have returned false) holds.
type problem struct{ first, later *ssa.Call }
var problems []problem
for i, b := range fn.Blocks {
// The solution for block i holds at the point after its phis.
in := result.In(i)
out := slices.Clone(in)
for _, instr := range b.Instrs {
if call, ok := instr.(*ssa.Call); ok && isYieldCall(call) {
for _, s := range out {
problems = append(problems, problem{first: s.yield, later: call})
}
// Apply intra-block transfer function,
// in case there are yield calls in sequence.
out = out.yieldCall(numb, call)
}
}
}
// Sort, since source order differs from block order,
// and for each 'first', we want the earliest 'later' in the source.
slices.SortFunc(problems, func(x, y problem) int {
if d := cmp.Compare(x.first.Pos(), y.first.Pos()); d != 0 {
return d
}
return cmp.Compare(x.later.Pos(), y.later.Pos())
})
// Report a diagnostic for each problematic 'first' call.
// Removing 'first' from ssaYieldCalls ensures at most one
// report per call, citing the earliest 'later' call.
for _, p := range problems {
if !moremaps.Delete(ssaYieldCalls, p.first) {
continue // already reported
}
var where string
var related []analysis.RelatedInformation
// When the later call is the first call itself (reached again
// along a cycle), no other line is cited.
if p.later != p.first {
otherLine := safetoken.StartPosition(pass.Fset, p.later.Pos()).Line
where = fmt.Sprintf("(on L%d) ", otherLine)
laterCall := callSyntax(p.later)
related = []analysis.RelatedInformation{{
Pos: laterCall.Pos(),
End: laterCall.End(),
Message: "other call here",
}}
}
firstCall := callSyntax(p.first)
pass.Report(analysis.Diagnostic{
Pos: firstCall.Pos(),
End: firstCall.End(),
Message: fmt.Sprintf("yield may be called again %safter returning false", where),
Related: related,
})
}
}
// A numbering assigns small, dense, nonnegative integers to SSA values
// in the order in which they are first requested. The numbers serve as
// bit indices in the bitsets of a state.
type numbering map[ssa.Value]int

// number returns the integer assigned to v. If v has none yet, it is
// given the next unused number, which is the current size of the map.
func (n numbering) number(v ssa.Value) int {
	if num, seen := n[v]; seen {
		return num
	}
	num := len(n)
	n[v] = num
	return num
}
// -- lattice --

// lattice is the semilattice over [stateSet] values used by the
// dataflow analysis. It carries no data of its own.
type lattice struct{}

var _ flow.Semilattice[stateSet] = lattice{}

// Ident returns the identity element for Merge: the empty (nil) stateSet.
func (lattice) Ident() stateSet { return nil }

// Equals reports whether a and b hold equal states, pairwise and in
// order. Because [normalize] puts sets into a canonical order, this is
// set equality for normalized sets.
func (lattice) Equals(a, b stateSet) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if !a[i].equal(b[i]) {
			return false
		}
	}
	return true
}

// Merge returns the normalized union of a and b, in which states made
// redundant by broader ones are dropped. Neither argument is modified,
// and neither needs to be normalized beforehand.
func (lattice) Merge(a, b stateSet) stateSet {
	// Building the union in fresh storage lets normalize work in place.
	var union stateSet
	union = append(union, a...)
	union = append(union, b...)
	return normalize(union)
}
// normalize puts ss into canonical form, reusing (and clobbering) its
// storage, and returns the result.
//
// States are sorted by four keys, in order:
//   - the yield call, by position (cheaper to compare than the
//     number of conditions);
//   - the number of conditions, fewest first, so that for each yield
//     call the broader states precede the narrower ones they subsume;
//   - the mask bit pattern;
//   - the senses bit pattern.
//
// The last two keys are arbitrary, but they make the order total, so
// two normalized sets are equal exactly when their slices are equal.
//
// States with no conditions are then discarded, as is any state that
// is stricter than (and therefore redundant with respect to) one that
// has already been kept. This is quadratic in the number of distinct
// states, which grows with the number of yield calls, the number of
// control paths that treat yield results differently, and the number
// of phis; in practice it is tiny.
func normalize(ss stateSet) stateSet {
	slices.SortFunc(ss, func(a, b state) int {
		if a.yield != b.yield {
			// Distinct yield calls have distinct positions.
			return cmp.Compare(a.yield.Pos(), b.yield.Pos())
		}
		// senses has no bits outside mask (an invariant of state),
		// so its raw bit pattern is a sound final tie-breaker.
		return cmp.Or(
			cmp.Compare(a.mask.len(), b.mask.len()),
			a.mask.cmp(&b.mask),
			a.senses.cmp(&b.senses),
		)
	})

	kept := ss[:0]
	for _, s := range ss {
		if s.mask.empty() || slices.ContainsFunc(kept, s.stricter) {
			continue
		}
		kept = append(kept, s)
	}
	return kept
}
// -- stateSet transfer (pure) functions --

// stateSet is the dataflow fact associated with each block edge.
//
// Conceptually it maps each yield call to an "antichain": a set of
// partially ordered states none of which subsumes another (much like
// disjunctive normal form). Specialized representations such as BDDs
// exist, but a plain slice is adequate in practice.
//
// [normalize] orders the states totally (somewhat arbitrarily), so
// that equality of normalized sets is slice equality.
type stateSet []state

// yieldCall is the transfer function for a call to yield. It returns a
// fresh copy of in, extended with a new state recording that this call
// returned false, which the analysis presumes of every yield call.
func (in stateSet) yieldCall(n numbering, call *ssa.Call) (out stateSet) {
	num := n.number(call)
	fresh := state{yield: call}
	fresh.mask.set(num, true)
	fresh.senses.set(num, false) // presume this call returned false
	out = make(stateSet, 0, len(in)+1)
	out = append(out, in...)
	return append(out, fresh)
}
// not is the transfer function for a boolean negation !X. In each state
// that records the sense of X, the negation acquires the opposite
// sense. In the remaining states, any stale fact about the negation is
// cleared.
func (in stateSet) not(n numbering, not *ssa.UnOp) (out stateSet) {
	operand := n.number(not.X)
	result := n.number(not)
	for _, s := range in {
		known := s.mask.get(operand)
		s = s.update(result, known, known && !s.senses.get(operand))
		out = append(out, s)
	}
	return out
}
// phi is the transfer function for a phi node along the incoming edge
// whose value is val. If val is a boolean constant, the phi takes that
// sense in every state. Otherwise the phi inherits whatever a state
// records about val, and a stale fact about the phi is cleared in
// states that record nothing about val.
func (in stateSet) phi(n numbering, phi *ssa.Phi, val ssa.Value) (out stateSet) {
	phinum := n.number(phi)

	if c, ok := val.(*ssa.Const); ok && c.Value != nil && c.Value.Kind() == constant.Bool {
		sense := constant.BoolVal(c.Value)
		for _, s := range in {
			out = append(out, s.update(phinum, true, sense))
		}
		return out
	}

	valnum := n.number(val)
	for _, s := range in {
		known := s.mask.get(valnum)
		out = append(out, s.update(phinum, known, known && s.senses.get(valnum)))
	}
	return out
}
// if_ is the transfer function for a conditional branch on cond, taken
// in the direction given by sense. States whose recorded facts
// contradict that direction cannot follow this edge, so they are
// dropped.
func (in stateSet) if_(n numbering, cond ssa.Value, sense bool) (out stateSet) {
	// Peel off negations, flipping the sense for each one, to reach
	// the underlying value on which facts are recorded.
	for {
		neg, isNot := cond.(*ssa.UnOp)
		if !isNot || neg.Op != token.NOT {
			break
		}
		cond, sense = neg.X, !sense
	}

	num := n.number(cond)
	for _, s := range in {
		if infeasible := s.mask.get(num) && s.senses.get(num) != sense; !infeasible {
			out = append(out, s)
		}
	}
	return out
}
// -- state --

// state describes the execution paths, characterized by a set of
// boolean conditions, along which a particular yield call (yield) is
// presumed to have returned false. A later yield call reached while
// the conditions may hold is a potential violation.
//
// Conceptually, the conditions are a mapping from ssa.Value to a
// boolean sense. Concretely, SSA values are sequentially numbered
// (see [numbering]) as they are encountered, and these numbers
// identify values.
//
// mask is the set of map keys; senses is the boolean sense of each value.
// It is an invariant that senses is a subset of mask.
type state struct {
	yield  *ssa.Call // yield call presumed to have returned false on these paths
	mask   bitset    // values whose sense is known
	senses bitset    // for each value in mask, whether it is known to be true
}
// equal reports whether x and y describe the same state: the same
// yield call, and the same set of conditions with the same senses.
func (x state) equal(y state) bool {
	if x.yield != y.yield || !x.mask.equal(&y.mask) {
		return false
	}
	return x.senses.equalMasked(&y.senses, &x.mask)
}
// stricter reports whether x is at least as strict (specific) as y.
// That holds when both concern the same yield call and every condition
// of y also appears, with the same sense, in x. Equal states are
// therefore each stricter than the other.
func (x state) stricter(y state) bool {
	switch {
	case x.yield != y.yield:
		return false
	case !y.mask.subsetOf(&x.mask):
		return false
	default:
		return y.senses.equalMasked(&x.senses, &y.mask)
	}
}
// update returns a copy of s in which the condition on value num is
// replaced. mask reports whether the value's sense is known, and sense
// is that sense. Callers pass false for sense when mask is false, which
// preserves the invariant that senses is a subset of mask.
//
// The bitsets of s are never modified. If s already records exactly
// the requested condition, s is returned as is, sharing its bitsets.
// This avoids cloning both bitsets for nothing, for example when
// clearing a stale fact that was never recorded.
func (s state) update(num int, mask, sense bool) state {
	if s.mask.get(num) == mask && s.senses.get(num) == sense {
		return s // no change; shared bitsets are never mutated in place
	}
	s.mask = s.mask.clone()
	s.senses = s.senses.clone()
	s.mask.set(num, mask)
	s.senses.set(num, sense)
	return s
}
// -- SSA CFG as graph.Graph --

// fnGraph presents the control-flow graph of an [ssa.Function] through
// the [graph.Graph] interface expected by the flow analysis framework.
// Each node is the index of a basic block, and edges follow the
// successor relation.
//
// TODO(adonovan): move to ssa package?
type fnGraph struct {
	fn *ssa.Function
}

// Nodes yields the index of every basic block, in increasing order.
func (g fnGraph) Nodes() iter.Seq[int] {
	return func(yield func(int) bool) {
		for i := range len(g.fn.Blocks) {
			if !yield(i) {
				return
			}
		}
	}
}

// NumNodes returns the number of basic blocks in the function.
func (g fnGraph) NumNodes() int {
	return len(g.fn.Blocks)
}

// Out yields the indices of the successors of block node, in order.
func (g fnGraph) Out(node int) iter.Seq[int] {
	return func(yield func(int) bool) {
		succs := g.fn.Blocks[node].Succs
		for _, s := range succs {
			if !yield(s.Index) {
				return
			}
		}
	}
}
// -- bitset --

// bitset is a set of non-negative integers.
// It uses space proportional to its largest element.
//
// The zero value is a ready-to-use empty set.
// Bitsets, like slices, have hybrid value/reference semantics.
// Do not mutate copies; use [bitset.clone] before [bitset.set].
//
// bitsets are comparable (see [bitset.equal]) and totally ordered (see [bitset.cmp]).
type bitset struct {
	limbs []uint64 // bit vector; last limb is nonzero
}

// empty reports whether the set is empty.
func (b *bitset) empty() bool {
	return len(b.limbs) == 0
}

// set inserts i into the set if sense is true, or removes it if sense
// is false. It preserves the invariant that the last limb is nonzero.
func (b *bitset) set(i int, sense bool) {
	idx := i / 64
	if idx >= len(b.limbs) {
		if !sense {
			return // clearing nonexistent bit
		}
		// Grow, writing zeros into every new limb. Merely reslicing
		// into spare capacity (as slices.Grow followed by [:idx+1]
		// does when the capacity already suffices) would expose any
		// words already lying beyond len, turning stale bits into
		// members of the set.
		b.limbs = append(b.limbs, make([]uint64, idx+1-len(b.limbs))...)
	}
	bit := uint64(1) << (i % 64)
	if sense {
		b.limbs[idx] |= bit
		return
	}
	b.limbs[idx] &^= bit
	// Remove any trailing zero limbs.
	if b.limbs[idx] == 0 {
		n := len(b.limbs)
		for n > 0 && b.limbs[n-1] == 0 {
			n--
		}
		b.limbs = b.limbs[:n]
	}
}

// limb returns the i'th 64-bit word of the set, or zero if i lies
// beyond the stored words.
func (b *bitset) limb(i int) uint64 {
	if i < len(b.limbs) {
		return b.limbs[i]
	}
	return 0
}

// get reports whether the set contains i.
func (b *bitset) get(i int) bool {
	return b.limb(i/64)&(1<<(i%64)) != 0
}

// clone returns a copy of the bitset that shares no storage with b.
func (b *bitset) clone() bitset {
	return bitset{limbs: slices.Clone(b.limbs)}
}

// equal reports whether two sets contain the same elements.
// It relies on the invariant that the last limb is nonzero.
func (b *bitset) equal(other *bitset) bool {
	return slices.Equal(b.limbs, other.limbs)
}

// equalMasked reports whether b&mask equals other&mask.
func (b *bitset) equalMasked(other, mask *bitset) bool {
	// Above n words, both operands when masked are effectively zero.
	n := min(len(mask.limbs), max(len(b.limbs), len(other.limbs)))
	for i, m := range mask.limbs[:n] {
		if b.limb(i)&m != other.limb(i)&m {
			return false
		}
	}
	return true
}

// cmp returns the signum of the comparison of b against other.
// Sets with more limbs order after those with fewer; otherwise limbs
// are compared from the most significant down.
func (b *bitset) cmp(other *bitset) int {
	if d := cmp.Compare(len(b.limbs), len(other.limbs)); d != 0 {
		return d
	}
	for i := len(b.limbs) - 1; i >= 0; i-- {
		if d := cmp.Compare(b.limbs[i], other.limbs[i]); d != 0 {
			return d
		}
	}
	return 0
}

// subsetOf reports whether other contains all of b's elements.
func (b *bitset) subsetOf(other *bitset) bool {
	for i, w1 := range b.limbs {
		if w1&other.limb(i) != w1 {
			return false
		}
	}
	return true
}

// len returns the number of elements of the set.
func (b *bitset) len() int {
	var n int
	for _, v := range b.limbs {
		n += bits.OnesCount64(v)
	}
	return n
}
// sameSlice reports whether x and y have the same length and, when
// non-empty, begin at the same element variable. Because their lengths
// are equal, every corresponding pair of elements is then the same
// variable. Two empty slices are always considered the same.
func sameSlice[T any](x, y []T) bool {
	if len(x) != len(y) {
		return false
	}
	if len(x) == 0 {
		return true
	}
	return &x[0] == &y[0]
}