blob: 007bfc9060f805026e5bc2d6b762c964e80b75bd [file] [log] [blame]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package asmgen
// mulAddVWW emits the assembly for mulAddVWW, which computes z, c = x*m + a.
//
// Architectures with virtual (register-computed) carries use the
// synthetic-carry loop; every other architecture uses the flag-based
// loop with no addend vector.
func mulAddVWW(a *Asm) {
	fn := a.Func("func mulAddVWW(z, x []Word, m, a Word) (c Word)")
	switch {
	case a.AltCarry().Valid():
		addMulVirtualCarry(fn, 0)
	default:
		addMul(fn, "", "x", 0)
	}
}
// addMulVVWW emits the assembly for addMulVVWW, which computes
// z, c = x + y*m + a. (Strictly speaking it is a multiply-add-add,
// so addMulAddVVWW would be the more pedantic name.)
//
// One of three code shapes is generated:
//   - architectures with virtual carries always get the synthetic-carry loop;
//   - architectures where a second hardware carry is optional get both the
//     one-carry loop and the two-carry loop, selected at run time through
//     the "altcarry" jump;
//   - all remaining architectures get only the one-carry loop.
func addMulVVWW(a *Asm) {
	fn := a.Func("func addMulVVWW(z, x, y []Word, m, a Word) (c Word)")
	switch {
	case a.AltCarry().Valid():
		addMulVirtualCarry(fn, 1)
	case a.JmpEnable(OptionAltCarry, "altcarry"):
		// Snapshot the register allocation before emitting the first loop
		// and restore it before emitting the second, so the alternate loop
		// starts from the same register state as the one-carry loop.
		saved := a.RegsUsed()
		addMul(fn, "x", "y", 1)
		a.Label("altcarry")
		a.SetOption(OptionAltCarry, true)
		a.SetRegsUsed(saved)
		addMulAlt(fn)
		a.SetOption(OptionAltCarry, false)
	default:
		addMul(fn, "x", "y", 1)
	}
}
// Computing z = addsrc + m*mulsrc + a, we need:
//
// for i := range z {
// lo, hi := m * mulsrc[i]
// lo, carry = bits.Add(lo, a, 0)
// lo, carryAlt = bits.Add(lo, addsrc[i], 0)
// z[i] = lo
// a = hi + carry + carryAlt // cannot overflow
// }
//
// The final addition cannot overflow because after processing N words,
// the maximum possible value is (for a 64-bit system):
//
// (2**(64N) - 1) + (2**64 - 1)*(2**(64N) - 1) + (2**64 - 1)
// = (2**64)*(2**(64N) - 1) + (2**64 - 1)
// = 2**(64(N+1)) - 1,
//
// which fits in N+1 words (the high order one being the new value of a).
//
// (For example, with 3 decimal words, 999 + 9*999 + 9 = 999*10 + 9 = 9999.)
//
// If we unroll the loop a bit, then we can chain the carries in two passes.
// Consider:
//
// lo0, hi0 := m * mulsrc[i]
// lo0, carry = bits.Add(lo0, a, 0)
// lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
// z[i] = lo0
// a = hi0 + carry + carryAlt // cannot overflow
//
// lo1, hi1 := m * mulsrc[i+1]
// lo1, carry = bits.Add(lo1, a, 0)
// lo1, carryAlt = bits.Add(lo1, addsrc[i+1], 0)
// z[i+1] = lo1
// a = hi1 + carry + carryAlt // cannot overflow
//
// lo2, hi2 := m * mulsrc[i+2]
// lo2, carry = bits.Add(lo2, a, 0)
// lo2, carryAlt = bits.Add(lo2, addsrc[i+2], 0)
// z[i+2] = lo2
// a = hi2 + carry + carryAlt // cannot overflow
//
// lo3, hi3 := m * mulsrc[i+3]
// lo3, carry = bits.Add(lo3, a, 0)
// lo3, carryAlt = bits.Add(lo3, addsrc[i+3], 0)
// z[i+3] = lo3
// a = hi3 + carry + carryAlt // cannot overflow
//
// There are three ways we can optimize this sequence.
//
// (1) By reordering, we can chain carries so that we can use one hardware carry flag
// but amortize the cost of saving and restoring it across multiple instructions:
//
// // multiply
// lo0, hi0 := m * mulsrc[i]
// lo1, hi1 := m * mulsrc[i+1]
// lo2, hi2 := m * mulsrc[i+2]
// lo3, hi3 := m * mulsrc[i+3]
//
// lo0, carry = bits.Add(lo0, a, 0)
// lo1, carry = bits.Add(lo1, hi0, carry)
// lo2, carry = bits.Add(lo2, hi1, carry)
// lo3, carry = bits.Add(lo3, hi2, carry)
// a = hi3 + carry // cannot overflow
//
// // add
// lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
// lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
// lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
// lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
// a = a + carryAlt // cannot overflow
//
// z[i] = lo0
// z[i+1] = lo1
// z[i+2] = lo2
// z[i+3] = lo3
//
// addMul takes this approach, using the hardware carry flag
// first for carry and then for carryAlt.
//
// (2) addMulAlt assumes there are two hardware carry flags available.
// It dedicates one each to carry and carryAlt, so that a multi-block
// unrolling can keep the flags in hardware across all the blocks.
// So even if the block size is 1, the code can do:
//
// // multiply and add
// lo0, hi0 := m * mulsrc[i]
// lo0, carry = bits.Add(lo0, a, 0)
// lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
// z[i] = lo0
//
// lo1, hi1 := m * mulsrc[i+1]
// lo1, carry = bits.Add(lo1, hi0, carry)
// lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
// z[i+1] = lo1
//
// lo2, hi2 := m * mulsrc[i+2]
// lo2, carry = bits.Add(lo2, hi1, carry)
// lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
// z[i+2] = lo2
//
// lo3, hi3 := m * mulsrc[i+3]
// lo3, carry = bits.Add(lo3, hi2, carry)
// lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
// z[i+3] = lo3
//
// a = hi3 + carry + carryAlt // cannot overflow
//
// (3) addMulVirtualCarry optimizes for systems with explicitly computed carry bits
// (loong64, mips, riscv64), cutting the number of actual instructions almost by half.
// Look again at the original word-at-a-time version:
//
// lo1, hi1 := m * mulsrc[i]
// lo1, carry = bits.Add(lo1, a, 0)
// lo1, carryAlt = bits.Add(lo1, addsrc[i], 0)
// z[i] = lo1
// a = hi1 + carry + carryAlt // cannot overflow
//
// Although it uses four adds per word, those are cheap adds: the two bits.Add adds
// use two instructions each (ADD+SLTU) and the final + adds only use one ADD each,
// for a total of 6 instructions per word. In contrast, the middle stanzas in (2) use
// only two “adds” per word, but these are SetCarry|UseCarry adds, which compile to
// five instructions each, for a total of 10 instructions per word. So the word-at-a-time
// loop is actually better. And we can reorder things slightly to use only a single carry bit:
//
// lo1, hi1 := m * mulsrc[i]
// lo1, carry = bits.Add(lo1, a, 0)
// a = hi1 + carry
// lo1, carry = bits.Add(lo1, addsrc[i], 0)
// a = a + carry
// z[i] = lo1
func addMul(f *Func, addsrc, mulsrc string, mulIndex int) {
// addMul emits the single-hardware-carry loop (optimization strategy 1)
// computing z = addsrc + m*mulsrc + a, and stores the final carry word as c.
//
// addsrc names the input slice to add, or is "" when there is no addend
// vector (as in mulAddVWW); when set, its words are read from in[0].
// mulsrc names the input slice multiplied by m; its words are read from
// in[mulIndex], so callers must pass the matching index.
a := f.Asm
mh := HintNone
if a.Arch == Arch386 && addsrc != "" {
mh = HintMemOK // too few registers otherwise
}
m := f.ArgHint("m", mh)
// c holds the running carry word; it starts as the argument a.
c := f.Arg("a")
n := f.Arg("z_len")
p := f.Pipe()
if addsrc != "" {
// Presumably lets addend words be used directly as memory operands.
p.SetHint(addsrc, HintMemOK)
}
// Mark the multiplicand words as feeding MulWide.
p.SetHint(mulsrc, HintMulSrc)
// Block sizes handed to p.Start, tuned per architecture below.
unroll := []int{1, 4}
switch a.Arch {
case Arch386:
unroll = []int{1} // too few registers
case ArchARM:
p.SetMaxColumns(2) // too few registers (but more than 386)
case ArchARM64:
unroll = []int{1, 8} // 5% speedup on c4as16
}
// See the large comment above for an explanation of the code being generated.
// This is optimization strategy 1.
p.Start(n, unroll...)
p.Loop(func(in, out [][]Reg) {
// Pass 1: multiply each word and chain every product's high word into
// the next product's low word through the carry flag. The first add
// brings in the incoming carry word c and only sets the carry.
a.Comment("multiply")
prev := c
flag := SetCarry
for i, x := range in[mulIndex] {
hi := a.RegHint(HintMulHi)
// The product's low word overwrites x; its high word goes to hi.
a.MulWide(m, x, x, hi)
a.Add(prev, x, x, flag)
flag = UseCarry | SetCarry
// The previous high word has been consumed; release it, but never
// free c, which carries the running value across iterations.
if prev != c {
a.Free(prev)
}
out[0][i] = x
prev = hi
}
// c = last high word + carry (cannot overflow; see comment above).
a.Add(a.Imm(0), prev, c, UseCarry|SmashCarry)
if addsrc != "" {
// Pass 2: reuse the same carry flag to chain the addend words
// into the products, then fold the final carry into c.
a.Comment("add")
flag := SetCarry
for i, x := range in[0] {
a.Add(x, out[0][i], out[0][i], flag)
flag = UseCarry | SetCarry
}
a.Add(a.Imm(0), c, c, UseCarry|SmashCarry)
}
p.StoreN(out)
})
// Return the final carry word as the result c.
f.StoreArg(c, "c")
a.Ret()
}
// addMulAlt emits the two-hardware-carry loop (optimization strategy 2)
// computing z, c = x + y*m + a from the inputs "x" and "y". The primary
// carry flag chains each product's high word into the next word, while the
// alternate carry flag (AltCarry) independently chains the addend x.
// Neither flag needs to be saved between words. addMulVVWW calls it with
// OptionAltCarry enabled.
func addMulAlt(f *Func) {
a := f.Asm
m := f.ArgHint("m", HintMulSrc)
// c holds the running carry word; it starts as the argument a.
c := f.Arg("a")
n := f.Arg("z_len")
// On amd64, we need a non-immediate for the AtUnrollEnd adds.
// Use the hardware zero register when one exists; otherwise materialize
// 0 in a scratch register.
r0 := a.ZR()
if !r0.Valid() {
r0 = a.Reg()
a.Mov(a.Imm(0), r0)
}
p := f.Pipe()
// NOTE(review): presumably keeps this loop's labels distinct from the
// addMul loop emitted into the same function; confirm against Pipe.SetLabel.
p.SetLabel("alt")
p.SetHint("x", HintMemOK)
p.SetHint("y", HintMemOK)
if a.Arch == ArchAMD64 {
// NOTE(review): column limit on amd64; the reason is not stated here
// (register pressure?).
p.SetMaxColumns(2)
}
// See the large comment above for an explanation of the code being generated.
// This is optimization strategy (2).
// prev is the high word to add into the next word's low half; it starts
// as the incoming carry word c. hi receives each product's high word.
var hi Reg
prev := c
p.Start(n, 1, 8)
p.AtUnrollStart(func() {
a.Comment("multiply and add")
// Clear both carry chains before the first word of the block.
// NOTE(review): presumably the first call clears the alternate carry
// and the second the primary carry; confirm against Asm.ClearCarry.
a.ClearCarry(AddCarry | AltCarry)
a.ClearCarry(AddCarry)
hi = a.Reg()
})
p.AtUnrollEnd(func() {
// Fold the last high word plus the primary carry into c, then fold
// in the alternate carry, and restart the chain from c.
a.Add(r0, prev, c, UseCarry|SmashCarry)
a.Add(r0, c, c, UseCarry|SmashCarry|AltCarry)
prev = c
})
p.Loop(func(in, out [][]Reg) {
for i, y := range in[1] {
x := in[0][i]
// Reuse y's register for the low word unless y is a memory operand.
lo := y
if lo.IsMem() {
lo = a.Reg()
}
a.MulWide(m, y, lo, hi)
// Primary chain: add the previous high word (or c).
a.Add(prev, lo, lo, UseCarry|SetCarry)
// Alternate chain: add the addend word.
a.Add(x, lo, lo, UseCarry|SetCarry|AltCarry)
out[0][i] = lo
// Rotate registers: this word's hi becomes prev, and the register
// just consumed as prev is recycled to receive the next hi. This
// includes c on the first word; c is rewritten in AtUnrollEnd.
prev, hi = hi, prev
}
p.StoreN(out)
})
// Return the final carry word as the result c.
f.StoreArg(c, "c")
a.Ret()
}
// addMulVirtualCarry emits optimization strategy (3), the word-at-a-time
// loop used when a.AltCarry().Valid() at the call sites (architectures
// whose carry bits are computed explicitly). Each carry is folded into a
// register immediately rather than chained across words:
//
//	mulIndex == 0: z, c = x*m + a      (mulAddVWW)
//	mulIndex == 1: z, c = x + y*m + a  (addMulVVWW; in[0] is the addend x)
func addMulVirtualCarry(f *Func, mulIndex int) {
a := f.Asm
m := f.Arg("m")
// c holds the running carry word; it starts as the argument a.
c := f.Arg("a")
n := f.Arg("z_len")
// See the large comment above for an explanation of the code being generated.
// This is optimization strategy (3).
p := f.Pipe()
p.Start(n, 1, 4)
p.Loop(func(in, out [][]Reg) {
a.Comment("synthetic carry, one column at a time")
// lo/hi are allocated once per loop body and reused for every column.
lo, hi := a.Reg(), a.Reg()
for i, x := range in[mulIndex] {
// hi:lo = m * x
a.MulWide(m, x, lo, hi)
if mulIndex == 1 {
// lo += addend word; propagate its carry into hi.
a.Add(in[0][i], lo, lo, SetCarry)
a.Add(a.Imm(0), hi, hi, UseCarry|SmashCarry)
}
// x = lo + c (the input register is reused for the output word),
// then c = hi + carry.
a.Add(c, lo, x, SetCarry)
a.Add(a.Imm(0), hi, c, UseCarry|SmashCarry)
out[0][i] = x
}
p.StoreN(out)
})
// Return the final carry word as the result c.
f.StoreArg(c, "c")
a.Ret()
}