blob: a7e8ede5bca82b52295939b2d80849806ad9aa84 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssa
import (
"cmd/compile/internal/base"
"cmd/compile/internal/types"
"cmd/internal/src"
"sort"
)
// memcombine merges groups of narrow loads, and groups of narrow stores,
// into single wider memory operations.
// The patterns it targets are the ones produced by encoding/binary-style
// code, though other code may benefit as well.
func memcombine(f *Func) {
	// The wide accesses we create are not necessarily aligned, so the
	// pass only runs on architectures that allow unaligned loads and stores.
	if f.Config.unalignedOK {
		memcombineLoads(f)
		memcombineStores(f)
	}
}
// memcombineLoads finds the trees of OR operations in f and asks
// combineLoads to replace each OR whose operands are narrow loads
// with one wider load.
func memcombineLoads(f *Func) {
	// isOr reports whether v is one of the OR ops this pass handles.
	isOr := func(v *Value) bool {
		switch v.Op {
		case OpOr16, OpOr32, OpOr64:
			return true
		}
		return false
	}

	// nonRoot holds the ID of every value that is an argument of an OR.
	// An OR that is not in this set is the root of an OR tree.
	nonRoot := f.newSparseSet(f.NumValues())
	defer f.retSparseSet(nonRoot)
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			if !isOr(v) {
				continue
			}
			nonRoot.add(v.Args[0].ID)
			nonRoot.add(v.Args[1].ID)
		}
	}

	// worklist is reused across blocks to avoid reallocating.
	var worklist []*Value
	for _, b := range f.Blocks {
		worklist = worklist[:0]
		for _, v := range b.Values {
			if !isOr(v) || nonRoot.contains(v.ID) {
				// Not an OR, or an OR that feeds another OR.
				continue
			}
			// Append the whole OR tree rooted at v, breadth first.
			// Any traversal that lists a node before its children would do.
			head := len(worklist)
			worklist = append(worklist, v)
			for ; head < len(worklist); head++ {
				for _, arg := range worklist[head].Args[:2] {
					if isOr(arg) {
						worklist = append(worklist, arg)
					}
				}
			}
		}

		for _, v := range worklist {
			// Largest number of loads worth trying for this OR width.
			limit := f.Config.RegSize
			switch v.Op {
			case OpOr64:
				// keep the register size
			case OpOr32:
				limit = 4
			case OpOr16:
				limit = 2
			default:
				// No longer an OR (for instance already rewritten).
				continue
			}
			// Try the widest combination first and stop at the first success.
			for n := limit; n > 1 && !combineLoads(v, n); n /= 2 {
			}
		}
	}
}
// A BaseAddress describes the address ptr+idx.
// ptr is a value of pointer type and idx a value of integer type.
// A nil idx stands for an index of 0.
type BaseAddress struct {
	ptr, idx *Value
}
// splitPtr decomposes ptr into a base address plus a constant offset
// from that base.
// Returning BaseAddress{ptr,nil},0 would always be correct; splitPtr
// instead moves as many constant terms as it can find into the offset.
func splitPtr(ptr *Value) (BaseAddress, int64) {
	var (
		index  *Value
		offset int64
	)
	isConst := func(v *Value) bool {
		return v.Op == OpConst32 || v.Op == OpConst64
	}
	for {
		switch ptr.Op {
		case OpOffPtr:
			// Constant displacement: fold it into the offset.
			offset += ptr.AuxInt
		case OpAddPtr:
			if index != nil {
				// A second index value: keep the first one found and
				// treat this AddPtr as the base pointer.
				return BaseAddress{ptr: ptr, idx: index}, offset
			}
			index = ptr.Args[1]
			if index.Op == OpAdd32 || index.Op == OpAdd64 {
				// index = const + x (either order): peel the constant.
				lhs, rhs := index.Args[0], index.Args[1]
				if isConst(lhs) {
					offset += lhs.AuxInt
					index = rhs
				} else if isConst(rhs) {
					offset += rhs.AuxInt
					index = lhs
				}
			}
		default:
			return BaseAddress{ptr: ptr, idx: index}, offset
		}
		// Both OffPtr and AddPtr carry the underlying pointer in Args[0].
		ptr = ptr.Args[0]
	}
}
// combineLoads tries to rewrite the OR tree rooted at root as a single
// wide load. It succeeds when the tree ORs together exactly n values of
// the form shift(zeroextend(load)) (the shift being optional) whose loads
// read adjacent memory from the same memory state, and whose shift amounts
// lay the pieces out in little-endian or big-endian order.
// It reports whether root was rewritten.
func combineLoads(root *Value, n int64) bool {
	orOp := root.Op
	var shiftOp Op
	switch orOp {
	case OpOr16:
		shiftOp = OpLsh16x64
	case OpOr32:
		shiftOp = OpLsh32x64
	case OpOr64:
		shiftOp = OpLsh64x64
	default:
		return false
	}
	cfg := root.Block.Func.Config

	// Flatten the OR tree into a list of operands, stopping as soon as
	// n of them have been collected.
	leaves := make([]*Value, 0, 8)
	leaves = append(leaves, root)
	for i := 0; i < len(leaves) && int64(len(leaves)) < n; {
		v := leaves[i]
		if v != root && v.Uses != 1 {
			// Part of this subtree is also used elsewhere.
			return false
		}
		if v.Op != orOp {
			i++
			continue
		}
		// Replace the OR by its two arguments and look at slot i again.
		leaves[i] = v.Args[0]
		leaves = append(leaves, v.Args[1])
	}
	if int64(len(leaves)) != n {
		return false
	}

	// The first operand tells us which extension op to expect.
	// Every operand must look like shift(extend(load)), the shift being optional.
	first := leaves[0]
	if first.Op == shiftOp {
		first = first.Args[0]
	}
	extOK := false
	switch orOp {
	case OpOr64:
		extOK = first.Op == OpZeroExt8to64 || first.Op == OpZeroExt16to64 || first.Op == OpZeroExt32to64
	case OpOr32:
		extOK = first.Op == OpZeroExt8to32 || first.Op == OpZeroExt16to32
	case OpOr16:
		extOK = first.Op == OpZeroExt8to16
	}
	if !extOK {
		return false
	}
	extOp := first.Op
	firstLoad := first.Args[0]
	if firstLoad.Op != OpLoad {
		return false
	}
	base, _ := splitPtr(firstLoad.Args[0])
	mem := firstLoad.Args[1]
	size := firstLoad.Type.Size()

	// s390x can't handle unaligned accesses to global variables.
	if cfg.arch == "S390X" && base.ptr.Op == OpAddr {
		return false
	}

	// Validate every operand and record what we need from it.
	type LoadRecord struct {
		load   *Value
		offset int64 // offset of load address from base
		shift  int64
	}
	recs := make([]LoadRecord, n, 8)
	for i, v := range leaves {
		if v.Uses != 1 {
			return false
		}
		var shift int64
		if v.Op == shiftOp {
			amount := v.Args[1]
			if amount.Op != OpConst64 {
				return false
			}
			shift = amount.AuxInt
			v = v.Args[0]
			if v.Uses != 1 {
				return false
			}
		}
		if v.Op != extOp {
			return false
		}
		load := v.Args[0]
		if load.Op != OpLoad || load.Uses != 1 || load.Args[1] != mem {
			return false
		}
		addr, off := splitPtr(load.Args[0])
		if addr != base {
			return false
		}
		recs[i] = LoadRecord{load: load, offset: off, shift: shift}
	}

	// Order the records by increasing address.
	sort.Slice(recs, func(i, j int) bool {
		return recs[i].offset < recs[j].offset
	})

	// The loads must cover one contiguous range.
	for i := int64(1); i < n; i++ {
		if recs[i].offset != recs[0].offset+i*size {
			return false
		}
	}

	// Determine whether the shift amounts grow with the address
	// (little-endian) or shrink with it (big-endian).
	shift0 := recs[0].shift
	isLittleEndian, isBigEndian := true, true
	for i := int64(1); i < n; i++ {
		step := i * size * 8
		if recs[i].shift != shift0+step {
			isLittleEndian = false
		}
		if recs[i].shift != shift0-step {
			isBigEndian = false
		}
	}
	if !isLittleEndian && !isBigEndian {
		return false
	}

	// Choose the block for the new load. It must be somewhere its memory
	// argument is live, so root.Block is not necessarily usable;
	// mergePoint picks the block for us.
	loads := make([]*Value, n, 8)
	for i := range loads {
		loads[i] = recs[i].load
	}
	loadBlock := mergePoint(root.Block, loads...)
	if loadBlock == nil {
		return false
	}
	// Borrow the source position of a load that lives in that block.
	pos := src.NoXPos
	for _, ld := range loads {
		if ld.Block == loadBlock {
			pos = ld.Pos
			break
		}
	}
	if pos == src.NoXPos {
		return false
	}

	// A byte swap is needed when the order assembled by the shifts is
	// the opposite of the target's byte order.
	needSwap := isBigEndian
	if cfg.BigEndian {
		needSwap = isLittleEndian
	}
	if needSwap && (size != 1 || !cfg.haveByteSwap(n)) {
		return false
	}

	// This is the commit point.

	// Load the whole range, starting at the lowest address.
	wide := loadBlock.NewValue2(pos, OpLoad, sizeType(n*size), recs[0].load.Args[0], mem)

	// Byte swap if needed.
	if needSwap {
		wide = byteSwap(loadBlock, pos, wide)
	}

	// Extend to the width of root if needed.
	if rootSize := root.Type.Size(); n*size < rootSize {
		wide = zeroExtend(loadBlock, pos, wide, n*size, rootSize)
	}

	// Reapply the smallest shift of the group if it is nonzero.
	if isLittleEndian && shift0 != 0 {
		wide = leftShift(loadBlock, pos, wide, shift0)
	}
	if lowShift := shift0 - (n-1)*size*8; isBigEndian && lowShift != 0 {
		wide = leftShift(loadBlock, pos, wide, lowShift)
	}

	// Turn root into a copy of the combined value.
	root.reset(OpCopy)
	root.AddArg(wide)

	// Clobber the original loads so that no further work is done on the
	// now unreachable subtrees.
	for _, rec := range recs {
		clobber(rec.load)
	}
	return true
}
// memcombineStores walks the store chains of each block of f and asks
// combineStores to merge runs of narrow stores into wider ones.
func memcombineStores(f *Func) {
	superseded := f.newSparseSet(f.NumValues())
	defer f.retSparseSet(superseded)

	// queue is reused across blocks to avoid reallocating.
	var queue []*Value
	for _, b := range f.Blocks {
		// Record the memory argument of every store in this block.
		// A store found in this set is followed by another store, so it
		// is not the last one of its chain.
		superseded.clear()
		for _, v := range b.Values {
			if v.Op == OpStore {
				superseded.add(v.MemoryArg().ID)
			}
		}

		// Build the visiting order: start from the last store of each
		// chain and walk backwards, so later stores are queued first.
		queue = queue[:0]
		for _, v := range b.Values {
			if v.Op != OpStore || superseded.contains(v.ID) {
				continue
			}
			queue = append(queue, v)
			for prev := v.Args[2]; prev.Block == b && prev.Op == OpStore; prev = prev.Args[2] {
				queue = append(queue, prev)
			}
		}

		// Look for combining opportunities at each store in queue order.
		for _, v := range queue {
			if v.Op != OpStore { // already rewritten
				continue
			}
			size := v.Aux.(*types.Type).Size()
			if size == 0 || size >= f.Config.RegSize {
				continue
			}
			// Try the widest combination first. The result is deliberately
			// not used to stop the loop: smaller n are still attempted
			// after a success, now against the widened store.
			for n := f.Config.RegSize / size; n > 1; n /= 2 {
				combineStores(v, n)
			}
		}
	}
}
// combineStores tries to combine the n stores ending in root into a single
// store that is n times as wide.
// Returns true if successful.
//
// root and the n-1 stores before it on its memory chain must be in the same
// block, store elements of the same size, and together cover one contiguous
// range of addresses off a common base. The stored values must be either
// all constants, all loads of a matching contiguous range, or all
// truncations of one common value (optionally right shifted by constants
// that place the pieces in little-endian or big-endian order).
// On success root becomes the wide store and the other stores are clobbered.
func combineStores(root *Value, n int64) bool {
	// Helper functions.
	type StoreRecord struct {
		store  *Value
		offset int64
	}
	// getShiftBase looks at the values stored by the first two records and
	// returns the value both are truncations (of shifts) of, or nil if
	// there is no such common value.
	getShiftBase := func(a []StoreRecord) *Value {
		x := a[0].store.Args[1]
		y := a[1].store.Args[1]
		switch x.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			x = x.Args[0]
		default:
			return nil
		}
		switch y.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			y = y.Args[0]
		default:
			return nil
		}
		var x2 *Value
		switch x.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			x2 = x.Args[0]
		default:
		}
		var y2 *Value
		switch y.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			y2 = y.Args[0]
		default:
		}
		if y2 == x {
			// a shift of x and x itself.
			return x
		}
		if x2 == y {
			// a shift of y and y itself.
			return y
		}
		if x2 == y2 {
			// 2 shifts both of the same argument.
			return x2
		}
		return nil
	}
	// isShiftBase reports whether store v writes a truncation of base,
	// or a truncation of an unsigned right shift of base.
	isShiftBase := func(v, base *Value) bool {
		val := v.Args[1]
		switch val.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			val = val.Args[0]
		default:
			return false
		}
		if val == base {
			return true
		}
		switch val.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			val = val.Args[0]
		default:
			return false
		}
		return val == base
	}
	// shift returns the constant amount by which base is shifted right in
	// the value written by store v (0 when base is stored unshifted), or -1
	// if the value does not have the expected shape.
	shift := func(v, base *Value) int64 {
		val := v.Args[1]
		switch val.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			val = val.Args[0]
		default:
			return -1
		}
		if val == base {
			return 0
		}
		switch val.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			val = val.Args[1]
		default:
			return -1
		}
		if val.Op != OpConst64 {
			return -1
		}
		return val.AuxInt
	}

	// Element size of the individual stores.
	size := root.Aux.(*types.Type).Size()
	if size*n > root.Block.Func.Config.RegSize {
		return false
	}
	switch size * n {
	case 2, 4, 8:
	default:
		// The rewrites below only know how to build 2, 4 and 8 byte
		// values. Reject any other width before anything is modified.
		return false
	}

	// Gather n stores to look at. Check easy conditions we require.
	a := make([]StoreRecord, 0, 8)
	rbase, roff := splitPtr(root.Args[0])
	if root.Block.Func.Config.arch == "S390X" {
		// s390x can't handle unaligned accesses to global variables.
		if rbase.ptr.Op == OpAddr {
			return false
		}
	}
	a = append(a, StoreRecord{root, roff})
	for i, x := int64(1), root.Args[2]; i < n; i, x = i+1, x.Args[2] {
		if x.Op != OpStore {
			return false
		}
		if x.Block != root.Block {
			return false
		}
		if x.Uses != 1 { // Note: root can have more than one use.
			return false
		}
		if x.Aux.(*types.Type).Size() != size {
			// TODO: the constant source and consecutive load source cases
			// do not need all the stores to be the same size.
			return false
		}
		base, off := splitPtr(x.Args[0])
		if base != rbase {
			return false
		}
		a = append(a, StoreRecord{x, off})
	}
	// Before we sort, grab the memory arg the result should have.
	mem := a[n-1].store.Args[2]
	// Also grab position of first store (last in array = first in memory order).
	pos := a[n-1].store.Pos

	// Sort stores in increasing address order.
	sort.Slice(a, func(i, j int) bool {
		return a[i].offset < a[j].offset
	})

	// Check that everything is written to sequential locations.
	for i := int64(0); i < n; i++ {
		if a[i].offset != a[0].offset+i*size {
			return false
		}
	}

	// Memory location we're going to write at (the lowest one).
	ptr := a[0].store.Args[0]

	// install turns root into a store of val at ptr with memory argument
	// mem, and clobbers the other n-1 stores.
	install := func(val *Value) {
		for i := int64(0); i < n; i++ {
			v := a[i].store
			if v == root {
				v.Aux = val.Type // widen store type
				v.Pos = pos
				v.SetArg(0, ptr)
				v.SetArg(1, val)
				v.SetArg(2, mem)
			} else {
				clobber(v)
				v.Type = types.Types[types.TBOOL] // erase memory type
			}
		}
	}

	// Check for constant stores.
	// The loop condition ends the scan at the first non-constant; a break
	// inside the switch would only leave the switch.
	isConst := true
	for i := int64(0); i < n && isConst; i++ {
		switch a[i].store.Args[1].Op {
		case OpConst32, OpConst16, OpConst8, OpConstBool:
		default:
			isConst = false
		}
	}
	if isConst {
		// Assemble the combined constant, placing each element according
		// to the target's byte order.
		var c int64
		mask := int64(1)<<(8*size) - 1
		for i := int64(0); i < n; i++ {
			s := 8 * size * i
			if root.Block.Func.Config.BigEndian {
				s = 8*size*(n-1) - s
			}
			c |= (a[i].store.Args[1].AuxInt & mask) << s
		}
		var cv *Value
		switch size * n { // one case always matches, see the width check above
		case 2:
			cv = root.Block.Func.ConstInt16(types.Types[types.TUINT16], int16(c))
		case 4:
			cv = root.Block.Func.ConstInt32(types.Types[types.TUINT32], int32(c))
		case 8:
			cv = root.Block.Func.ConstInt64(types.Types[types.TUINT64], c)
		}

		// Move all the stores to the root.
		install(cv)
		return true
	}

	// Check for consecutive loads as the source of the stores.
	var loadMem *Value
	var loadBase BaseAddress
	var loadIdx int64
	for i := int64(0); i < n; i++ {
		load := a[i].store.Args[1]
		if load.Op != OpLoad {
			loadMem = nil
			break
		}
		if load.Uses != 1 {
			loadMem = nil
			break
		}
		if load.Type.IsPtr() {
			// Don't combine stores containing a pointer, as we need
			// a write barrier for those. This can't currently happen,
			// but might in the future if we ever have another
			// 8-byte-reg/4-byte-ptr architecture like amd64p32.
			loadMem = nil
			break
		}
		lmem := load.Args[1]
		base, idx := splitPtr(load.Args[0])
		if loadMem == nil {
			// First one we found
			loadMem = lmem
			loadBase = base
			loadIdx = idx
			continue
		}
		if base != loadBase || lmem != loadMem {
			loadMem = nil
			break
		}
		// The loads must be laid out exactly like the stores.
		if idx != loadIdx+(a[i].offset-a[0].offset) {
			loadMem = nil
			break
		}
	}
	if loadMem != nil {
		// Modify the first load to do a larger load instead.
		load := a[0].store.Args[1]
		switch size * n { // one case always matches, see the width check above
		case 2:
			load.Type = types.Types[types.TUINT16]
		case 4:
			load.Type = types.Types[types.TUINT32]
		case 8:
			load.Type = types.Types[types.TUINT64]
		}

		// Modify root to do the store.
		install(load)
		return true
	}

	// Check that all the shift/trunc are of the same base value.
	shiftBase := getShiftBase(a)
	if shiftBase == nil {
		return false
	}
	for i := int64(0); i < n; i++ {
		if !isShiftBase(a[i].store, shiftBase) {
			return false
		}
	}

	// Check for writes in little-endian or big-endian order.
	isLittleEndian := true
	shift0 := shift(a[0].store, shiftBase)
	for i := int64(1); i < n; i++ {
		if shift(a[i].store, shiftBase) != shift0+i*size*8 {
			isLittleEndian = false
			break
		}
	}
	isBigEndian := true
	for i := int64(1); i < n; i++ {
		if shift(a[i].store, shiftBase) != shift0-i*size*8 {
			isBigEndian = false
			break
		}
	}
	if !isLittleEndian && !isBigEndian {
		return false
	}

	// Check to see if we need byte swap before storing.
	needSwap := isLittleEndian && root.Block.Func.Config.BigEndian ||
		isBigEndian && !root.Block.Func.Config.BigEndian
	if needSwap && (size != 1 || !root.Block.Func.Config.haveByteSwap(n)) {
		return false
	}

	// This is the commit point.

	// Build the value to store: the base, shifted down by the smallest
	// shift of the group, truncated to the combined width, and byte
	// swapped if required.
	sv := shiftBase
	if isLittleEndian && shift0 != 0 {
		sv = rightShift(root.Block, root.Pos, sv, shift0)
	}
	if isBigEndian && shift0-(n-1)*size*8 != 0 {
		sv = rightShift(root.Block, root.Pos, sv, shift0-(n-1)*size*8)
	}
	if sv.Type.Size() > size*n {
		sv = truncate(root.Block, root.Pos, sv, sv.Type.Size(), size*n)
	}
	if needSwap {
		sv = byteSwap(root.Block, root.Pos, sv)
	}

	// Move all the stores to the root.
	install(sv)
	return true
}
// sizeType returns the unsigned integer type that is size bytes wide.
// Only 2, 4 and 8 are supported; any other size is a fatal error.
func sizeType(size int64) *types.Type {
	switch size {
	case 2:
		return types.Types[types.TUINT16]
	case 4:
		return types.Types[types.TUINT32]
	case 8:
		return types.Types[types.TUINT64]
	}
	base.Fatalf("bad size %d\n", size)
	return nil
}
// truncate adds to b a value that truncates v from a width of from bytes
// to a width of to bytes, and returns it.
// The supported (from, to) pairs are (8,2), (8,4) and (4,2); anything else
// is a fatal error.
func truncate(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
	var (
		op Op
		t  *types.Type
	)
	// Encode the pair as a two-digit number: tens digit from, units digit to.
	switch from*10 + to {
	case 42:
		op, t = OpTrunc32to16, types.Types[types.TUINT16]
	case 82:
		op, t = OpTrunc64to16, types.Types[types.TUINT16]
	case 84:
		op, t = OpTrunc64to32, types.Types[types.TUINT32]
	default:
		base.Fatalf("bad sizes %d %d\n", from, to)
		return nil
	}
	return b.NewValue1(pos, op, t, v)
}
// zeroExtend adds to b a value that zero-extends v from a width of from
// bytes to a width of to bytes, and returns it.
// The supported (from, to) pairs are (2,4), (2,8) and (4,8); anything else
// is a fatal error.
func zeroExtend(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
	var (
		op Op
		t  *types.Type
	)
	// Encode the pair as a two-digit number: tens digit from, units digit to.
	switch from*10 + to {
	case 48:
		op, t = OpZeroExt32to64, types.Types[types.TUINT64]
	case 28:
		op, t = OpZeroExt16to64, types.Types[types.TUINT64]
	case 24:
		op, t = OpZeroExt16to32, types.Types[types.TUINT32]
	default:
		base.Fatalf("bad sizes %d %d\n", from, to)
		return nil
	}
	return b.NewValue1(pos, op, t, v)
}
// leftShift adds to b a value that shifts v left by the constant shift,
// and returns it. The result has the same type as v.
// v must be 2, 4 or 8 bytes wide; any other width is a fatal error.
func leftShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
	// The shift amount is always passed as a 64-bit constant.
	amount := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
	var op Op
	switch width := v.Type.Size(); width {
	case 2:
		op = OpLsh16x64
	case 4:
		op = OpLsh32x64
	case 8:
		op = OpLsh64x64
	default:
		base.Fatalf("bad size %d\n", width)
		return nil
	}
	return b.NewValue2(pos, op, v.Type, v, amount)
}
// rightShift adds to b a value that does an unsigned right shift of v by
// the constant shift, and returns it. The result has the same type as v.
// v must be 2, 4 or 8 bytes wide; any other width is a fatal error.
func rightShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
	// The shift amount is always passed as a 64-bit constant.
	amount := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
	var op Op
	switch width := v.Type.Size(); width {
	case 2:
		op = OpRsh16Ux64
	case 4:
		op = OpRsh32Ux64
	case 8:
		op = OpRsh64Ux64
	default:
		base.Fatalf("bad size %d\n", width)
		return nil
	}
	return b.NewValue2(pos, op, v.Type, v, amount)
}
// byteSwap adds to b a value that reverses the byte order of v, and
// returns it. The result has the same type as v.
// v must be 2, 4 or 8 bytes wide; any other width is a fatal error.
func byteSwap(b *Block, pos src.XPos, v *Value) *Value {
	var op Op
	switch width := v.Type.Size(); width {
	case 2:
		op = OpBswap16
	case 4:
		op = OpBswap32
	case 8:
		op = OpBswap64
	default:
		v.Fatalf("bad size %d\n", width)
		return nil
	}
	return b.NewValue1(pos, op, v.Type, v)
}