| // Copyright 2016 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package ssa |
| |
| import ( |
| "fmt" |
| "math" |
| ) |
| |
// branch records through which outgoing edge of a BlockIf a block is
// reached: unknown when this cannot be determined, positive for the edge
// taken when the condition is true (Succs[0]) and negative for the edge
// taken when it is false (Succs[1]).
type branch int

// The constants are declared with type branch (not as untyped integers) so
// that they carry the branch type wherever they are used, including when
// stored in interfaces.
const (
	unknown branch = iota
	positive
	negative
)
| |
// relation is a bit set over the three possible orderings of a pair of
// variables (v, w): lt (v < w), eq (v == w) and gt (v > w). With no prior
// knowledge every ordering is possible, so the set is lt|eq|gt. When the
// execution path branches on a condition `v op w`, every ordering that is
// incompatible with the condition being true (or false) is removed.
//
// For example:
//
//	r := relation(...)
//
//	if v < w {
//		newR := r & lt
//	}
//	if v >= w {
//		newR := r & (eq|gt)
//	}
//	if v != w {
//		newR := r & (lt|gt)
//	}
type relation uint

// The individual ordering bits; sets are formed by or-ing them together.
const (
	lt relation = 1 // v < w
	eq relation = 2 // v == w
	gt relation = 4 // v > w
)
| |
// domain identifies the interpretation under which a set of relations
// between a pair of values is known. Relations learned in one domain do not
// carry over to another: for example, the same bit pattern orders
// differently when read as a signed versus an unsigned integer.
type domain uint

// Domain bits; several may be or-ed together to describe an op that
// compares in more than one domain at once.
const (
	signed   domain = 1 << 0
	unsigned domain = 1 << 1
	pointer  domain = 1 << 2
	boolean  domain = 1 << 3
)
| |
| type pair struct { |
| v, w *Value // a pair of values, ordered by ID. |
| // v can be nil, to mean the zero value. |
| // for booleans the zero value (v == nil) is false. |
| d domain |
| } |
| |
| // fact is a pair plus a relation for that pair. |
| type fact struct { |
| p pair |
| r relation |
| } |
| |
// limit holds the known bounds of a single value, both for a signed and for
// an unsigned reading of its bits.
type limit struct {
	min, max   int64  // signed bounds: min <= value <= max
	umin, umax uint64 // unsigned bounds: umin <= value <= umax
}

// String formats the four bounds for debugging output.
func (lim limit) String() string {
	const layout = "sm,SM,um,UM=%d,%d,%d,%d"
	return fmt.Sprintf(layout, lim.min, lim.max, lim.umin, lim.umax)
}

// noLimit is the limit of a value about which nothing is known: every int64
// and every uint64 remains possible.
var noLimit = limit{
	min:  math.MinInt64,
	max:  math.MaxInt64,
	umin: 0,
	umax: math.MaxUint64,
}
| |
| // a limitFact is a limit known for a particular value. |
| type limitFact struct { |
| vid ID |
| limit limit |
| } |
| |
| // factsTable keeps track of relations between pairs of values. |
| type factsTable struct { |
| facts map[pair]relation // current known set of relation |
| stack []fact // previous sets of relations |
| |
| // known lower and upper bounds on individual values. |
| limits map[ID]limit |
| limitStack []limitFact // previous entries |
| } |
| |
| // checkpointFact is an invalid value used for checkpointing |
| // and restoring factsTable. |
| var checkpointFact = fact{} |
| var checkpointBound = limitFact{} |
| |
| func newFactsTable() *factsTable { |
| ft := &factsTable{} |
| ft.facts = make(map[pair]relation) |
| ft.stack = make([]fact, 4) |
| ft.limits = make(map[ID]limit) |
| ft.limitStack = make([]limitFact, 4) |
| return ft |
| } |
| |
// get returns the known possible relations between v and w in domain d,
// expressed from v's side (a result r means "v r w" may hold).
// If nothing is known it returns lt|eq|gt, i.e. any order (or eq when v and
// w are the same value).
//
// When either operand is a generic integer constant, the answer is derived
// from the limits recorded for the other operand; ft.facts is not consulted
// on that path. Only the signed and unsigned domains use limits; other
// domains yield lt|eq|gt there.
//
// NOTE(review): v may be nil (simplifyBlock calls get(nil, ...) in the
// boolean domain), so the isGenericIntConst calls below rely on that method
// tolerating a nil receiver. If v were nil while w is an integer constant,
// v.ID would panic; callers presumably never pass such a combination.
func (ft *factsTable) get(v, w *Value, d domain) relation {
	if v.isGenericIntConst() || w.isGenericIntConst() {
		// Normalize so that w is the constant; flip the result at the end.
		reversed := false
		if v.isGenericIntConst() {
			v, w = w, v
			reversed = true
		}
		r := lt | eq | gt
		lim, ok := ft.limits[v.ID]
		if !ok {
			return r
		}
		c := w.AuxInt
		switch d {
		case signed:
			// Compare the constant with v's signed bounds [lim.min, lim.max].
			switch {
			case c < lim.min:
				r = gt
			case c > lim.max:
				r = lt
			case c == lim.min && c == lim.max:
				r = eq
			case c == lim.min:
				r = gt | eq
			case c == lim.max:
				r = lt | eq
			}
		case unsigned:
			// TODO: also use signed data if lim.min >= 0?
			// Reinterpret the constant as an unsigned value of its own width.
			var uc uint64
			switch w.Op {
			case OpConst64:
				uc = uint64(c)
			case OpConst32:
				uc = uint64(uint32(c))
			case OpConst16:
				uc = uint64(uint16(c))
			case OpConst8:
				uc = uint64(uint8(c))
			}
			// Compare it with v's unsigned bounds [lim.umin, lim.umax].
			switch {
			case uc < lim.umin:
				r = gt
			case uc > lim.umax:
				r = lt
			case uc == lim.umin && uc == lim.umax:
				r = eq
			case uc == lim.umin:
				r = gt | eq
			case uc == lim.umax:
				r = lt | eq
			}
		}
		if reversed {
			return reverseBits[r]
		}
		return r
	}

	// Facts are keyed with the lower-ID value first (update stores them that
	// way); swap to match and flip the answer back afterwards.
	reversed := false
	if lessByID(w, v) {
		v, w = w, v
		reversed = !reversed
	}

	p := pair{v, w, d}
	r, ok := ft.facts[p]
	if !ok {
		if p.v == p.w {
			// A value is always equal to itself.
			r = eq
		} else {
			r = lt | eq | gt
		}
	}

	if reversed {
		return reverseBits[r]
	}
	return r
}
| |
// update records that "v r w" holds in domain d by intersecting r with the
// relations already known for the pair. The previous relation set is pushed
// onto ft.stack so that restore can undo the change.
//
// If one side is a generic integer constant, update also tightens the
// limits of the other value in the signed or unsigned domain, pushing the
// previous limit onto ft.limitStack. parent is only used to position debug
// messages.
func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) {
	// Store the pair with the lower-ID value first; flip r to match.
	if lessByID(w, v) {
		v, w = w, v
		r = reverseBits[r]
	}

	p := pair{v, w, d}
	oldR := ft.get(v, w, d)
	ft.stack = append(ft.stack, fact{p, oldR})
	ft.facts[p] = oldR & r

	// Extract bounds when comparing against constants.
	// First normalize so that w is the constant and r reads "v r constant".
	if v.isGenericIntConst() {
		v, w = w, v
		r = reverseBits[r]
	}
	if v != nil && w.isGenericIntConst() {
		c := w.AuxInt
		// Note: all the +1/-1 below could overflow/underflow. Either will
		// still generate correct results, it will just lead to imprecision.
		// In fact if there is overflow/underflow, the corresponding
		// code is unreachable because the known range is outside the range
		// of the value's type.
		old, ok := ft.limits[v.ID]
		if !ok {
			old = noLimit
		}
		lim := old
		// Update lim with the new information we know.
		switch d {
		case signed:
			switch r {
			case lt:
				// v < c implies v <= c-1.
				if c-1 < lim.max {
					lim.max = c - 1
				}
			case lt | eq:
				if c < lim.max {
					lim.max = c
				}
			case gt | eq:
				if c > lim.min {
					lim.min = c
				}
			case gt:
				// v > c implies v >= c+1.
				if c+1 > lim.min {
					lim.min = c + 1
				}
			case lt | gt:
				// v != c only helps when c sits exactly on a bound.
				if c == lim.min {
					lim.min++
				}
				if c == lim.max {
					lim.max--
				}
			case eq:
				lim.min = c
				lim.max = c
			}
		case unsigned:
			// Reinterpret the constant as an unsigned value of its own width.
			var uc uint64
			switch w.Op {
			case OpConst64:
				uc = uint64(c)
			case OpConst32:
				uc = uint64(uint32(c))
			case OpConst16:
				uc = uint64(uint16(c))
			case OpConst8:
				uc = uint64(uint8(c))
			}
			// Same reasoning as the signed case, on the unsigned bounds.
			switch r {
			case lt:
				if uc-1 < lim.umax {
					lim.umax = uc - 1
				}
			case lt | eq:
				if uc < lim.umax {
					lim.umax = uc
				}
			case gt | eq:
				if uc > lim.umin {
					lim.umin = uc
				}
			case gt:
				if uc+1 > lim.umin {
					lim.umin = uc + 1
				}
			case lt | gt:
				if uc == lim.umin {
					lim.umin++
				}
				if uc == lim.umax {
					lim.umax--
				}
			case eq:
				lim.umin = uc
				lim.umax = uc
			}
		}
		// In other domains lim is unchanged, but the entry is still logged
		// and stored so restore stays balanced.
		ft.limitStack = append(ft.limitStack, limitFact{v.ID, old})
		ft.limits[v.ID] = lim
		if v.Block.Func.pass.debug > 2 {
			v.Block.Func.Config.Warnl(parent.Line, "parent=%s, new limits %s %s %s", parent, v, w, lim.String())
		}
	}
}
| |
| // isNonNegative returns true if v is known to be non-negative. |
| func (ft *factsTable) isNonNegative(v *Value) bool { |
| if isNonNegative(v) { |
| return true |
| } |
| l, has := ft.limits[v.ID] |
| return has && (l.min >= 0 || l.umax <= math.MaxInt64) |
| } |
| |
| // checkpoint saves the current state of known relations. |
| // Called when descending on a branch. |
| func (ft *factsTable) checkpoint() { |
| ft.stack = append(ft.stack, checkpointFact) |
| ft.limitStack = append(ft.limitStack, checkpointBound) |
| } |
| |
| // restore restores known relation to the state just |
| // before the previous checkpoint. |
| // Called when backing up on a branch. |
| func (ft *factsTable) restore() { |
| for { |
| old := ft.stack[len(ft.stack)-1] |
| ft.stack = ft.stack[:len(ft.stack)-1] |
| if old == checkpointFact { |
| break |
| } |
| if old.r == lt|eq|gt { |
| delete(ft.facts, old.p) |
| } else { |
| ft.facts[old.p] = old.r |
| } |
| } |
| for { |
| old := ft.limitStack[len(ft.limitStack)-1] |
| ft.limitStack = ft.limitStack[:len(ft.limitStack)-1] |
| if old.vid == 0 { // checkpointBound |
| break |
| } |
| if old.limit == noLimit { |
| delete(ft.limits, old.vid) |
| } else { |
| ft.limits[old.vid] = old.limit |
| } |
| } |
| } |
| |
| func lessByID(v, w *Value) bool { |
| if v == nil && w == nil { |
| // Should not happen, but just in case. |
| return false |
| } |
| if v == nil { |
| return true |
| } |
| return w != nil && v.ID < w.ID |
| } |
| |
| var ( |
| reverseBits = [...]relation{0, 4, 2, 6, 1, 5, 3, 7} |
| |
| // maps what we learn when the positive branch is taken. |
| // For example: |
| // OpLess8: {signed, lt}, |
| // v1 = (OpLess8 v2 v3). |
| // If v1 branch is taken than we learn that the rangeMaks |
| // can be at most lt. |
| domainRelationTable = map[Op]struct { |
| d domain |
| r relation |
| }{ |
| OpEq8: {signed | unsigned, eq}, |
| OpEq16: {signed | unsigned, eq}, |
| OpEq32: {signed | unsigned, eq}, |
| OpEq64: {signed | unsigned, eq}, |
| OpEqPtr: {pointer, eq}, |
| |
| OpNeq8: {signed | unsigned, lt | gt}, |
| OpNeq16: {signed | unsigned, lt | gt}, |
| OpNeq32: {signed | unsigned, lt | gt}, |
| OpNeq64: {signed | unsigned, lt | gt}, |
| OpNeqPtr: {pointer, lt | gt}, |
| |
| OpLess8: {signed, lt}, |
| OpLess8U: {unsigned, lt}, |
| OpLess16: {signed, lt}, |
| OpLess16U: {unsigned, lt}, |
| OpLess32: {signed, lt}, |
| OpLess32U: {unsigned, lt}, |
| OpLess64: {signed, lt}, |
| OpLess64U: {unsigned, lt}, |
| |
| OpLeq8: {signed, lt | eq}, |
| OpLeq8U: {unsigned, lt | eq}, |
| OpLeq16: {signed, lt | eq}, |
| OpLeq16U: {unsigned, lt | eq}, |
| OpLeq32: {signed, lt | eq}, |
| OpLeq32U: {unsigned, lt | eq}, |
| OpLeq64: {signed, lt | eq}, |
| OpLeq64U: {unsigned, lt | eq}, |
| |
| OpGeq8: {signed, eq | gt}, |
| OpGeq8U: {unsigned, eq | gt}, |
| OpGeq16: {signed, eq | gt}, |
| OpGeq16U: {unsigned, eq | gt}, |
| OpGeq32: {signed, eq | gt}, |
| OpGeq32U: {unsigned, eq | gt}, |
| OpGeq64: {signed, eq | gt}, |
| OpGeq64U: {unsigned, eq | gt}, |
| |
| OpGreater8: {signed, gt}, |
| OpGreater8U: {unsigned, gt}, |
| OpGreater16: {signed, gt}, |
| OpGreater16U: {unsigned, gt}, |
| OpGreater32: {signed, gt}, |
| OpGreater32U: {unsigned, gt}, |
| OpGreater64: {signed, gt}, |
| OpGreater64U: {unsigned, gt}, |
| |
| // TODO: OpIsInBounds actually test 0 <= a < b. This means |
| // that the positive branch learns signed/LT and unsigned/LT |
| // but the negative branch only learns unsigned/GE. |
| OpIsInBounds: {unsigned, lt}, |
| OpIsSliceInBounds: {unsigned, lt | eq}, |
| } |
| ) |
| |
// prove removes redundant BlockIf controls that can be inferred in a straight line.
//
// By far, the most common redundant pair are generated by bounds checking.
// For example for the code:
//
//	a[i] = 4
//	foo(a[i])
//
// The compiler will generate the following code:
//
//	if i >= len(a) {
//		panic("not in bounds")
//	}
//	a[i] = 4
//	if i >= len(a) {
//		panic("not in bounds")
//	}
//	foo(a[i])
//
// The second comparison i >= len(a) is clearly redundant because if the
// else branch of the first comparison is executed, we already know that i < len(a).
// The code for the second panic can be removed.
//
// The pass walks the dominator tree depth-first with an explicit stack.
// On entering a block that can only be reached through one side of its
// immediate dominator's BlockIf (see getBranch), the outcome of that
// condition is added to a factsTable. Each block is simplified after its
// dominated subtree has been processed, and the facts it added are then
// rolled back.
func prove(f *Func) {
	// current node state
	type walkState int
	const (
		descend walkState = iota // record facts, then schedule children
		simplify                 // simplify the block, then drop its facts
	)
	// work maintains the DFS stack.
	type bp struct {
		block *Block    // current handled block
		state walkState // what's to do
	}
	work := make([]bp, 0, 256)
	work = append(work, bp{
		block: f.Entry,
		state: descend,
	})

	ft := newFactsTable()
	idom := f.Idom()
	sdom := f.sdom()

	// DFS on the dominator tree.
	for len(work) > 0 {
		node := work[len(work)-1]
		work = work[:len(work)-1]
		// NOTE(review): presumably idom of the entry block is nil;
		// getBranch returns unknown for a nil parent.
		parent := idom[node.block.ID]
		branch := getBranch(sdom, parent, node.block)

		switch node.state {
		case descend:
			if branch != unknown {
				ft.checkpoint()
				c := parent.Control
				// Record whether the control itself is true or false on
				// this path: relative to the zero value (nil, i.e. false),
				// true is lt|gt.
				updateRestrictions(parent, ft, boolean, nil, c, lt|gt, branch)
				if tr, has := domainRelationTable[parent.Control.Op]; has {
					// When we branched from parent we learned a new set of
					// restrictions. Update the factsTable accordingly.
					updateRestrictions(parent, ft, tr.d, c.Args[0], c.Args[1], tr.r, branch)
				}
			}

			// Push simplify before the children so it is popped only after
			// the whole dominated subtree has been processed; the facts
			// added above stay in effect until then.
			work = append(work, bp{
				block: node.block,
				state: simplify,
			})
			for s := sdom.Child(node.block); s != nil; s = sdom.Sibling(s) {
				work = append(work, bp{
					block: s,
					state: descend,
				})
			}

		case simplify:
			succ := simplifyBlock(ft, node.block)
			if succ != unknown {
				// The condition is decided: drop the control and make
				// the block BlockFirst, swapping successors first so the
				// always-taken one comes first.
				b := node.block
				b.Kind = BlockFirst
				b.SetControl(nil)
				if succ == negative {
					b.swapSuccessors()
				}
			}

			// Undo the facts added in the matching descend step.
			if branch != unknown {
				ft.restore()
			}
		}
	}
}
| |
| // getBranch returns the range restrictions added by p |
| // when reaching b. p is the immediate dominator of b. |
| func getBranch(sdom SparseTree, p *Block, b *Block) branch { |
| if p == nil || p.Kind != BlockIf { |
| return unknown |
| } |
| // If p and p.Succs[0] are dominators it means that every path |
| // from entry to b passes through p and p.Succs[0]. We care that |
| // no path from entry to b passes through p.Succs[1]. If p.Succs[0] |
| // has one predecessor then (apart from the degenerate case), |
| // there is no path from entry that can reach b through p.Succs[1]. |
| // TODO: how about p->yes->b->yes, i.e. a loop in yes. |
| if sdom.isAncestorEq(p.Succs[0].b, b) && len(p.Succs[0].b.Preds) == 1 { |
| return positive |
| } |
| if sdom.isAncestorEq(p.Succs[1].b, b) && len(p.Succs[1].b.Preds) == 1 { |
| return negative |
| } |
| return unknown |
| } |
| |
| // updateRestrictions updates restrictions from the immediate |
| // dominating block (p) using r. r is adjusted according to the branch taken. |
| func updateRestrictions(parent *Block, ft *factsTable, t domain, v, w *Value, r relation, branch branch) { |
| if t == 0 || branch == unknown { |
| // Trivial case: nothing to do, or branch unknown. |
| // Shoult not happen, but just in case. |
| return |
| } |
| if branch == negative { |
| // Negative branch taken, complement the relations. |
| r = (lt | eq | gt) ^ r |
| } |
| for i := domain(1); i <= t; i <<= 1 { |
| if t&i != 0 { |
| ft.update(parent, v, w, i, r) |
| } |
| } |
| } |
| |
// simplifyBlock simplifies block b using the restrictions known in ft.
//
// It first rewrites OpSlicemask values of b whose argument is provably
// positive into the constant -1. Then, if b is a BlockIf, it tries to decide
// the branch condition and returns which branch must always be taken
// (positive or negative), or unknown if the condition cannot be decided.
// The caller rewrites the block's control flow.
func simplifyBlock(ft *factsTable, b *Block) branch {
	for _, v := range b.Values {
		if v.Op != OpSlicemask {
			continue
		}
		add := v.Args[0]
		if add.Op != OpAdd64 && add.Op != OpAdd32 {
			continue
		}
		// Note that the arg of slicemask was originally a sub, but
		// was rewritten to an add by generic.rules (if the thing
		// being subtracted was a constant).
		x := add.Args[0]
		y := add.Args[1]
		if x.Op == OpConst64 || x.Op == OpConst32 {
			x, y = y, x
		}
		if y.Op != OpConst64 && y.Op != OpConst32 {
			continue
		}
		// slicemask(x + y)
		// if x is larger than -y (y is negative), then slicemask is -1.
		// Only x's recorded unsigned lower bound is consulted.
		lim, ok := ft.limits[x.ID]
		if !ok {
			continue
		}
		if lim.umin > uint64(-y.AuxInt) {
			if v.Args[0].Op == OpAdd64 {
				v.reset(OpConst64)
			} else {
				v.reset(OpConst32)
			}
			if b.Func.pass.debug > 0 {
				b.Func.Config.Warnl(v.Line, "Proved slicemask not needed")
			}
			v.AuxInt = -1
		}
	}

	if b.Kind != BlockIf {
		return unknown
	}

	// First, check if the condition itself is already known. In the boolean
	// domain the control is compared against the zero value (nil, i.e.
	// false): lt|gt means it is true, eq means it is false.
	m := ft.get(nil, b.Control, boolean)
	if m == lt|gt {
		if b.Func.pass.debug > 0 {
			if b.Func.pass.debug > 1 {
				b.Func.Config.Warnl(b.Line, "Proved boolean %s (%s)", b.Control.Op, b.Control)
			} else {
				b.Func.Config.Warnl(b.Line, "Proved boolean %s", b.Control.Op)
			}
		}
		return positive
	}
	if m == eq {
		if b.Func.pass.debug > 0 {
			if b.Func.pass.debug > 1 {
				b.Func.Config.Warnl(b.Line, "Disproved boolean %s (%s)", b.Control.Op, b.Control)
			} else {
				b.Func.Config.Warnl(b.Line, "Disproved boolean %s", b.Control.Op)
			}
		}
		return negative
	}

	// Next, check whether the known relations between the control's two
	// arguments decide the comparison.
	c := b.Control
	tr, has := domainRelationTable[c.Op]
	if !has {
		return unknown
	}

	a0, a1 := c.Args[0], c.Args[1]
	for d := domain(1); d <= tr.d; d <<= 1 {
		if d&tr.d == 0 {
			continue
		}

		// tr.r represents in which case the positive branch is taken.
		// m represents which cases are possible because of previous relations.
		// If the set of possible relations m is included in the set of relations
		// need to take the positive branch (or negative) then that branch will
		// always be taken.
		// For shortcut, if m == 0 then this block is dead code.
		m := ft.get(a0, a1, d)
		if m != 0 && tr.r&m == m {
			if b.Func.pass.debug > 0 {
				if b.Func.pass.debug > 1 {
					b.Func.Config.Warnl(b.Line, "Proved %s (%s)", c.Op, c)
				} else {
					b.Func.Config.Warnl(b.Line, "Proved %s", c.Op)
				}
			}
			return positive
		}
		if m != 0 && ((lt|eq|gt)^tr.r)&m == m {
			if b.Func.pass.debug > 0 {
				if b.Func.pass.debug > 1 {
					b.Func.Config.Warnl(b.Line, "Disproved %s (%s)", c.Op, c)
				} else {
					b.Func.Config.Warnl(b.Line, "Disproved %s", c.Op)
				}
			}
			return negative
		}
	}

	// HACK: If the first argument of IsInBounds or IsSliceInBounds is known
	// to be non-negative and the signed relation to the upper bound already
	// implies the check, then the check is proven: with 0 <= a0 and a0 < a1
	// (or <= for IsSliceInBounds) as signed values, the unsigned comparison
	// the op performs holds too. Most useful in cases such as:
	//	if len(a) <= 1 { return }
	//	do something with a[1]
	if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && ft.isNonNegative(c.Args[0]) {
		m := ft.get(a0, a1, signed)
		if m != 0 && tr.r&m == m {
			if b.Func.pass.debug > 0 {
				if b.Func.pass.debug > 1 {
					b.Func.Config.Warnl(b.Line, "Proved non-negative bounds %s (%s)", c.Op, c)
				} else {
					b.Func.Config.Warnl(b.Line, "Proved non-negative bounds %s", c.Op)
				}
			}
			return positive
		}
	}

	return unknown
}
| |
| // isNonNegative returns true is v is known to be greater or equal to zero. |
| func isNonNegative(v *Value) bool { |
| switch v.Op { |
| case OpConst64: |
| return v.AuxInt >= 0 |
| |
| case OpConst32: |
| return int32(v.AuxInt) >= 0 |
| |
| case OpStringLen, OpSliceLen, OpSliceCap, |
| OpZeroExt8to64, OpZeroExt16to64, OpZeroExt32to64: |
| return true |
| |
| case OpRsh64x64: |
| return isNonNegative(v.Args[0]) |
| } |
| return false |
| } |