blob: ecc0f94e5b96b674550b2672fa7ac2d10a2602bc [file] [log] [blame] [edit]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssa
import (
"fmt"
)
// ----------------------------------------------------------------------------
// Sparse Conditional Constant Propagation
//
// Described in
// Mark N. Wegman, F. Kenneth Zadeck: Constant Propagation with Conditional Branches.
// TOPLAS 1991.
//
// This algorithm uses a three-level lattice for each SSA value
//
// Top undefined
// / | \
// .. 1 2 3 .. constant
// \ | /
// Bottom not constant
//
// It starts with optimistically assuming that all SSA values are initially Top
// and then propagates constant facts only along reachable control flow paths.
// Since some basic blocks are not visited yet, corresponding inputs of phi become
// Top, we use the meet(phi) to compute its lattice.
//
// Top ∩ any = any
// Bottom ∩ any = Bottom
// ConstantA ∩ ConstantA = ConstantA
// ConstantA ∩ ConstantB = Bottom
//
// Each lattice value is lowered at most twice (Top to Constant, Constant to Bottom)
// due to lattice depth, resulting in a fast convergence speed of the algorithm.
// In this way, sccp can discover optimization opportunities that cannot be found
// by just combining constant folding and constant propagation and dead code
// elimination separately.
// Lattice tags. A three-level lattice holds the compile-time knowledge about an
// SSA value. The tags are declared in lowering order (top < constant < bottom);
// visitValue compares them numerically to check that a lattice is only ever
// lowered.
const (
top int8 = iota // undefined: nothing is known about the value yet
constant // the value is a known compile-time constant
bottom // the value is not a constant
)
// lattice is the lattice cell attached to a single SSA value.
type lattice struct {
tag int8 // lattice type: one of top, constant or bottom
val *Value // constant value; set when tag is constant, nil otherwise
}
// worklist carries all the state of one sccp run over a single function: the
// two work queues (CFG edges and SSA values to re-visit) and the facts
// discovered so far.
type worklist struct {
f *Func // the target function to be optimized out
edges []Edge // propagate constant facts through edges
uses []*Value // re-visiting set
visited map[Edge]bool // visited edges
latticeCells map[*Value]lattice // constant lattices
defUse map[*Value][]*Value // def-use chains for some values
defBlock map[*Value][]*Block // blocks that use the value as a control value
visitedBlock []bool // visited block, indexed by block ID
}
// sccp performs sparse conditional constant propagation on f. Constants are
// propagated only along CFG edges that are proven reachable, and the discovered
// facts are then used for constant folding, constant replacement and for
// rewiring branches whose condition is constant (the now-dead successors are
// left for a later dead code pass to remove).
func sccp(f *Func) {
	var t worklist
	t.f = f
	t.visited = make(map[Edge]bool)
	t.defUse = make(map[*Value][]*Value)
	t.defBlock = make(map[*Value][]*Block)
	t.latticeCells = make(map[*Value]lattice)
	t.visitedBlock = f.Cache.allocBoolSlice(f.NumBlocks())
	defer f.Cache.freeBoolSlice(t.visitedBlock)

	// The analysis starts from the edge entering the entry block.
	t.edges = []Edge{{f.Entry, 0}}

	// The def-use chains are needed by every later step, so build them first.
	t.buildDefUses()

	// Drain both queues; pending edges always take priority over pending values.
drain:
	for {
		switch {
		case len(t.edges) > 0:
			e := t.edges[0]
			t.edges = t.edges[1:]
			if _, seen := t.visited[e]; seen {
				continue
			}
			target := e.b
			firstVisit := !t.visitedBlock[target.ID]

			// Record the edge and its destination as reached.
			t.visited[e] = true
			t.visitedBlock[target.ID] = true

			// On the first visit every value of the block is evaluated; on
			// later visits only phis can change, since one more of their
			// incoming edges became reachable.
			for _, v := range target.Values {
				if firstVisit || v.Op == OpPhi {
					t.visitValue(v)
				}
			}

			// Continue through the CFG, honoring constant branch conditions.
			if firstVisit {
				t.propagate(target)
			}
		case len(t.uses) > 0:
			v := t.uses[0]
			t.uses = t.uses[1:]
			t.visitValue(v)
		default:
			break drain
		}
	}

	// Apply the optimizations justified by the discovered constants.
	constCnt, rewireCnt := t.replaceConst()
	if f.pass.debug > 0 && (constCnt > 0 || rewireCnt > 0) {
		fmt.Printf("Phase SCCP for %v : %v constants, %v dce\n", f.Name, constCnt, rewireCnt)
	}
}
// equals reports whether two lattice cells carry the same information. Cells
// with different tags are never equal. Two constant cells are equal when their
// values have the same Op and AuxInt, even if they are distinct *Value objects.
// Cells sharing any other tag are always equal.
func equals(a, b lattice) bool {
	switch {
	case a == b:
		// Fast path: identical tag and identical value pointer.
		return true
	case a.tag != b.tag:
		return false
	case a.tag != constant:
		return true
	}
	// The same constant may be represented by different Values, so compare
	// what they encode rather than their identity.
	x, y := a.val, b.val
	return x.Op == y.Op && x.AuxInt == y.AuxInt
}
// possibleConst reports whether val could ever be folded to a constant: it is
// either already a constant, a copy, a phi, or one of the foldable one- and
// two-input operations listed below. Everything else (e.g. StaticCall) can
// never become a constant, so no effort is spent on it.
func possibleConst(val *Value) bool {
	if isConst(val) {
		return true
	}
	switch val.Op {
	case
		// values that forward the lattice of their inputs
		OpCopy, OpPhi,

		// ---- 1-input operations ----
		// negate
		OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
		OpCom8, OpCom16, OpCom32, OpCom64,
		// math
		OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
		// conversion
		OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
		OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
		OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
		OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
		OpCvtBoolToUint8,
		OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
		OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
		OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// bit
		OpCtz8, OpCtz16, OpCtz32, OpCtz64,
		// mask
		OpSlicemask,
		// safety check
		OpIsNonNil,
		// not
		OpNot,

		// ---- 2-input operations ----
		// add
		OpAdd64, OpAdd32, OpAdd16, OpAdd8,
		OpAdd32F, OpAdd64F,
		// sub
		OpSub64, OpSub32, OpSub16, OpSub8,
		OpSub32F, OpSub64F,
		// mul
		OpMul64, OpMul32, OpMul16, OpMul8,
		OpMul32F, OpMul64F,
		// div
		OpDiv32F, OpDiv64F,
		OpDiv8, OpDiv16, OpDiv32, OpDiv64,
		OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u,
		// mod
		OpMod8, OpMod16, OpMod32, OpMod64,
		OpMod8u, OpMod16u, OpMod32u, OpMod64u,
		// compare
		OpEq64, OpEq32, OpEq16, OpEq8,
		OpEq32F, OpEq64F,
		OpLess64, OpLess32, OpLess16, OpLess8,
		OpLess64U, OpLess32U, OpLess16U, OpLess8U,
		OpLess32F, OpLess64F,
		OpLeq64, OpLeq32, OpLeq16, OpLeq8,
		OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
		OpLeq32F, OpLeq64F,
		OpEqB, OpNeqB,
		// shift
		OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
		OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
		OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
		// safety check
		OpIsInBounds, OpIsSliceInBounds,
		// bit
		OpAnd8, OpAnd16, OpAnd32, OpAnd64,
		OpOr8, OpOr16, OpOr32, OpOr64,
		OpXor8, OpXor16, OpXor32, OpXor64:
		return true
	}
	return false
}
// getLatticeCell returns the current lattice cell of val. Values that can never
// be constants are always Bottom; values that can, but have no recorded cell
// yet, are optimistically reported as Top.
func (t *worklist) getLatticeCell(val *Value) lattice {
	if possibleConst(val) {
		if cell, ok := t.latticeCells[val]; ok {
			return cell
		}
		// Not visited yet: stay optimistic.
		return lattice{top, nil}
	}
	// Such values are always worst.
	return lattice{bottom, nil}
}
// isConst reports whether val is one of the constant ops handled by this pass:
// an integer, boolean or floating-point constant.
func isConst(val *Value) bool {
	switch val.Op {
	case OpConst8, OpConst16, OpConst32, OpConst64,
		OpConst32F, OpConst64F,
		OpConstBool:
		return true
	}
	return false
}
// buildDefUses records, ahead of the analysis, which values and blocks must be
// re-examined when the lattice of a value changes. Only pairs where both the
// definition and the use can become constants are recorded in defUse: a use
// that can never be constant stays Bottom no matter how often it is revisited.
// Likewise, defBlock records the blocks controlled by a value only when that
// control value can become a constant.
func (t *worklist) buildDefUses() {
	for _, b := range t.f.Blocks {
		for _, user := range b.Values {
			if !possibleConst(user) {
				continue
			}
			for _, def := range user.Args {
				if !possibleConst(def) {
					continue
				}
				chain, ok := t.defUse[def]
				if !ok {
					// Size the chain by the use count of the definition.
					chain = make([]*Value, 0, def.Uses)
				}
				t.defUse[def] = append(chain, user)
			}
		}
		for _, ctl := range b.ControlValues() {
			if !possibleConst(ctl) {
				continue
			}
			t.defBlock[ctl] = append(t.defBlock[ctl], b)
		}
	}
}
// addUses schedules everything that depends on val for another look: its
// recorded uses are queued for re-visiting, and every already-visited block
// controlled by val is propagated again.
func (t *worklist) addUses(val *Value) {
	for _, user := range t.defUse[val] {
		// A phi may list itself among its uses; re-visiting it would be
		// wasted work, so it is skipped.
		if user != val {
			t.uses = append(t.uses, user)
		}
	}
	for _, b := range t.defBlock[val] {
		if !t.visitedBlock[b.ID] {
			continue
		}
		t.propagate(b)
	}
}
// meet computes the lattice of phi val as the meet of its arguments, using
// these rules:
//
//	Top ∩ any = any
//	Bottom ∩ any = Bottom
//	ConstantA ∩ ConstantA = ConstantA
//	ConstantA ∩ ConstantB = Bottom
//
// Arguments that arrive over an edge not yet visited are optimistically treated
// as Top. Since Top never changes the result of a meet, such arguments (and
// arguments whose own lattice is Top) are simply skipped.
func (t *worklist) meet(val *Value) lattice {
	result := lattice{top, nil}
	for idx, arg := range val.Args {
		if _, seen := t.visited[Edge{val.Block, idx}]; !seen {
			// Unvisited incoming edge: Top ∩ any = any.
			continue
		}
		switch cell := t.getLatticeCell(arg); cell.tag {
		case bottom:
			// Bottom ∩ any = Bottom
			return lattice{bottom, nil}
		case constant:
			if result.tag == top {
				result = cell
			} else if !equals(result, cell) {
				// ConstantA ∩ ConstantB = Bottom
				return lattice{bottom, nil}
			}
		}
	}
	// Either all contributing arguments agreed on one constant, or none
	// contributed and the result is still Top.
	return result
}
// computeLattice evaluates val's operation on the constant arguments args and
// returns the resulting lattice: a constant cell if the operation folds to a
// constant, Bottom otherwise.
//
// A direct implementation would need a huge switch that evaluates every
// foldable opcode by hand, e.g.
//
//	case OpAdd16:
//		res.val = newConst(argLt1.val.AuxInt16() + argLt2.val.AuxInt16())
//	case OpDiv8:
//		if !isDivideByZero(argLt2.val.AuxInt8()) {
//			res.val = newConst(argLt1.val.AuxInt8() / argLt2.val.AuxInt8())
//		}
//
// including every side condition (such as division by zero), which is fragile
// and error-prone. Instead the generic rewrite rules are reused for the
// evaluation. Because those rules rewrite the value they are applied to, and
// the lattice of val may still change several times (for instance a phi whose
// incoming edges have not all been visited), they must not be applied to val
// itself. A temporary value with the same op and type is created in the entry
// block and rewritten in its place.
func computeLattice(f *Func, val *Value, args ...*Value) lattice {
	tmp := f.newValue(val.Op, val.Type, f.Entry, val.Pos)
	tmp.AddArgs(args...)
	if rewriteValuegeneric(tmp) && isConst(tmp) {
		return lattice{constant, tmp}
	}
	// No generic rule matched, or an additional constraint (e.g. division by
	// zero) was not satisfied. Discard the temporary value right away, since
	// it may not be dominated by its arguments.
	tmp.reset(OpInvalid)
	return lattice{bottom, nil}
}
// visitValue (re)computes the lattice cell of val from the current lattices of
// its arguments and stores it in t.latticeCells. If the cell differs from the
// one val had on entry, everything depending on val is rescheduled via addUses.
// Values that can never be constants are skipped entirely.
func (t *worklist) visitValue(val *Value) {
if !possibleConst(val) {
// fast fail for always worst Values, i.e. there is no lowering happen
// on them, their lattices must be initially worse Bottom.
return
}
// Remember the lattice before this visit so that the deferred function
// below can detect a change.
oldLt := t.getLatticeCell(val)
defer func() {
// re-visit all uses of value if its lattice is changed
newLt := t.getLatticeCell(val)
if !equals(newLt, oldLt) {
// Tags are ordered top < constant < bottom, so a smaller new tag
// would mean the lattice was raised, which must never happen.
if int8(oldLt.tag) > int8(newLt.tag) {
t.f.Fatalf("Must lower lattice\n")
}
t.addUses(val)
}
}()
switch val.Op {
// they are constant values, aren't they?
case OpConst64, OpConst32, OpConst16, OpConst8,
OpConstBool, OpConst32F, OpConst64F: //TODO: support ConstNil ConstString etc
t.latticeCells[val] = lattice{constant, val}
// lattice value of copy(x) actually means lattice value of (x)
case OpCopy:
t.latticeCells[val] = t.getLatticeCell(val.Args[0])
// phi should be processed specially
case OpPhi:
t.latticeCells[val] = t.meet(val)
// fold 1-input operations:
case
// negate
OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
OpCom8, OpCom16, OpCom32, OpCom64,
// math
OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
// conversion
OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
OpCvtBoolToUint8,
OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
// bit
OpCtz8, OpCtz16, OpCtz32, OpCtz64,
// mask
OpSlicemask,
// safety check
OpIsNonNil,
// not
OpNot:
lt1 := t.getLatticeCell(val.Args[0])
if lt1.tag == constant {
// here we take a shortcut by reusing generic rules to fold constants
t.latticeCells[val] = computeLattice(t.f, val, lt1.val)
} else {
// the argument is Top or Bottom; the result takes the same tag
t.latticeCells[val] = lattice{lt1.tag, nil}
}
// fold 2-input operations
case
// add
OpAdd64, OpAdd32, OpAdd16, OpAdd8,
OpAdd32F, OpAdd64F,
// sub
OpSub64, OpSub32, OpSub16, OpSub8,
OpSub32F, OpSub64F,
// mul
OpMul64, OpMul32, OpMul16, OpMul8,
OpMul32F, OpMul64F,
// div
OpDiv32F, OpDiv64F,
OpDiv8, OpDiv16, OpDiv32, OpDiv64,
OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u, //TODO: support div128u
// mod
OpMod8, OpMod16, OpMod32, OpMod64,
OpMod8u, OpMod16u, OpMod32u, OpMod64u,
// compare
OpEq64, OpEq32, OpEq16, OpEq8,
OpEq32F, OpEq64F,
OpLess64, OpLess32, OpLess16, OpLess8,
OpLess64U, OpLess32U, OpLess16U, OpLess8U,
OpLess32F, OpLess64F,
OpLeq64, OpLeq32, OpLeq16, OpLeq8,
OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
OpLeq32F, OpLeq64F,
OpEqB, OpNeqB,
// shift
OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
// safety check
OpIsInBounds, OpIsSliceInBounds,
// bit
OpAnd8, OpAnd16, OpAnd32, OpAnd64,
OpOr8, OpOr16, OpOr32, OpOr64,
OpXor8, OpXor16, OpXor32, OpXor64:
lt1 := t.getLatticeCell(val.Args[0])
lt2 := t.getLatticeCell(val.Args[1])
if lt1.tag == constant && lt2.tag == constant {
// here we take a shortcut by reusing generic rules to fold constants
t.latticeCells[val] = computeLattice(t.f, val, lt1.val, lt2.val)
} else {
// at least one argument is not a constant: the result is Bottom as
// soon as either argument is Bottom, otherwise it stays Top
if lt1.tag == bottom || lt2.tag == bottom {
t.latticeCells[val] = lattice{bottom, nil}
} else {
t.latticeCells[val] = lattice{top, nil}
}
}
default:
// Any other type of value cannot be a constant, they are always worst(Bottom)
}
}
// propagate pushes the outgoing CFG edges of block that may be taken onto the
// edge worklist:
//
//   - blocks that end control flow add nothing;
//   - BlockDefer adds all successors, since nothing is known about its control flow;
//   - BlockFirst and BlockPlain add their first successor;
//   - BlockIf and BlockJumpTable add all successors when the condition is
//     Bottom, only the taken successor when it is a constant, and nothing
//     while it is still Top (the block is propagated again by addUses once
//     the condition is lowered).
//
// A constant jump table index that is out of range adds nothing. This can only
// happen in unreachable code, because in-range indices are an invariant of
// jump tables; rewireSuccessor applies the same guard (see issue 64826).
func (t *worklist) propagate(block *Block) {
	switch block.Kind {
	case BlockExit, BlockRet, BlockRetJmp, BlockInvalid:
		// control flow ends, do nothing then
	case BlockDefer:
		// we know nothing about control flow, add all branch destinations
		t.edges = append(t.edges, block.Succs...)
	case BlockFirst, BlockPlain:
		// BlockFirst always takes the first branch, BlockPlain has only one
		t.edges = append(t.edges, block.Succs[0])
	case BlockIf, BlockJumpTable:
		cond := block.ControlValues()[0]
		condLattice := t.getLatticeCell(cond)
		switch condLattice.tag {
		case bottom:
			// we know nothing about control flow, add all branch destinations
			t.edges = append(t.edges, block.Succs...)
		case constant:
			// add only the branch destination selected by the condition
			var branchIdx int64
			if block.Kind == BlockIf {
				// a true condition (AuxInt 1) takes Succs[0], false takes Succs[1]
				branchIdx = 1 - condLattice.val.AuxInt
			} else {
				branchIdx = condLattice.val.AuxInt
				if branchIdx < 0 || branchIdx >= int64(len(block.Succs)) {
					// Out-of-range jump table index: this is unreachable
					// code, so there is no successor to propagate to.
					return
				}
			}
			t.edges = append(t.edges, block.Succs[branchIdx])
		default:
			// condition value is not visited yet, don't propagate it now
		}
	default:
		t.f.Fatalf("All kind of block should be processed above.")
	}
}
// rewireSuccessor turns a conditional block whose control value was proven to
// be the constant constVal into a plain block that keeps only the taken
// successor. The dropped successors may become unreachable and can then be
// removed by a later deadcode phase. It reports whether block was rewired;
// blocks of any kind other than BlockIf and BlockJumpTable are left untouched.
func rewireSuccessor(block *Block, constVal *Value) bool {
	switch block.Kind {
	case BlockIf:
		// Drop the edge indexed by the constant value.
		block.removeEdge(int(constVal.AuxInt))
	case BlockJumpTable:
		taken := int(constVal.AuxInt)
		if taken < 0 || taken >= len(block.Succs) {
			// This can only happen in unreachable code,
			// as an invariant of jump tables is that their
			// input index is in range.
			// See issue 64826.
			return false
		}
		// Move the taken branch to the front, then drop all the others.
		block.swapSuccessorsByIdx(0, taken)
		for len(block.Succs) > 1 {
			block.removeEdge(1)
		}
	default:
		return false
	}
	// The block now has one unconditional successor.
	block.Kind = BlockPlain
	block.Likely = BranchUnknown
	block.ResetControls()
	return true
}
// replaceConst applies the results of the analysis. Every value whose lattice
// is a constant, but which is not yet a constant op, is rewritten in place into
// that constant. Every block controlled by such a value has its successors
// rewired accordingly. It returns the number of values replaced and the number
// of blocks rewired.
func (t *worklist) replaceConst() (int, int) {
	replaced, rewired := 0, 0
	debug := t.f.pass.debug > 0
	for v, cell := range t.latticeCells {
		if cell.tag != constant {
			continue
		}
		if !isConst(v) {
			if debug {
				fmt.Printf("Replace %v with %v\n", v.LongString(), cell.val.LongString())
			}
			v.reset(cell.val.Op)
			v.AuxInt = cell.val.AuxInt
			replaced++
		}
		// If the constant controls any block, keep only the taken branch.
		for _, b := range t.defBlock[v] {
			if !rewireSuccessor(b, cell.val) {
				continue
			}
			rewired++
			if debug {
				fmt.Printf("Rewire %v %v successors\n", b.Kind, b)
			}
		}
	}
	return replaced, rewired
}