blob: e36defb2e18056770d16be63a225cf39f974d233 [file] [log] [blame] [edit]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build ignore
package main
import (
"bytes"
"fmt"
"io"
"log"
"os"
"slices"
"strconv"
"internal/runtime/gc"
"internal/runtime/gc/internal/gen"
)
// header is prepended to every generated file to mark it as
// machine-generated (per the standard "Code generated ... DO NOT EDIT"
// convention) so tools and reviewers know not to edit it by hand.
const header = "// Code generated by mkasm.go. DO NOT EDIT.\n\n"
// main generates expand_amd64.s, the AVX-512 GC bitmap expander routines.
func main() {
	generate("expand_amd64.s", genExpanders)
}
// generate renders the output of genFunc, preceded by the "DO NOT EDIT"
// header, to both stdout and fileName.
//
// All output is accumulated in a buffer and written to fileName in one
// shot, so the file is not created or truncated until generation has
// fully succeeded.
func generate(fileName string, genFunc func(*gen.File)) {
	var buf bytes.Buffer
	// Mirror all output to stdout so the generated code is visible when
	// running the generator by hand.
	tee := io.MultiWriter(&buf, os.Stdout)
	file := gen.NewFile(tee)
	genFunc(file)
	// Use WriteString rather than Fprintf(tee, header): header is data,
	// not a format string, and a future "%" in it must not be
	// misinterpreted as a verb.
	io.WriteString(tee, header)
	file.Compile()
	// os.WriteFile reports write and close errors in one place, unlike
	// Create+Write+deferred Close, which silently drops Close errors.
	if err := os.WriteFile(fileName, buf.Bytes(), 0o666); err != nil {
		log.Fatal(err)
	}
}
// genExpanders emits one AVX-512 bitmap-expansion routine per eligible
// small size class, plus a table mapping size class to its expander.
// Size classes spanning more than one page, or too large for a packed
// pointer/scalar bitmap, get no entry (their table slot stays nil).
func genExpanders(file *gen.File) {
	gcExpandersAVX512 := make([]*gen.Func, len(gc.SizeClassToSize))
	for sc, ob := range gc.SizeClassToSize {
		if gc.SizeClassToNPages[sc] != 1 {
			// These functions all produce a bitmap that covers exactly one
			// page.
			continue
		}
		if ob > gc.MinSizeForMallocHeader {
			// This size class is too big to have a packed pointer/scalar bitmap.
			break
		}
		// Expansion factor: per gfExpander's contract, each input bit
		// becomes xf consecutive output bits. (Presumably one input bit
		// per object word; ob is bytes, hence /8 — TODO confirm.)
		xf := int(ob) / 8
		log.Printf("size class %d bytes, expansion %dx", ob, xf)
		fn := gen.NewFunc(fmt.Sprintf("expandAVX512_%d<>", xf))
		ptrObjBits := gen.Arg[gen.Ptr[gen.Uint8x64]](fn)
		if xf == 1 {
			expandIdentity(ptrObjBits)
		} else {
			ok := gfExpander(xf, ptrObjBits)
			if !ok {
				// NOTE(review): failure is only logged; the incomplete
				// function is still added to the file and the table
				// below. Confirm this shouldn't be fatal instead.
				log.Printf("failed to generate expander for size class %d", sc)
			}
		}
		file.AddFunc(fn)
		gcExpandersAVX512[sc] = fn
	}
	// Generate table mapping size class to expander PC
	file.AddConst("·gcExpandersAVX512", gcExpandersAVX512)
}
// mat8x8 is an 8x8 bit matrix, stored one byte per row.
type mat8x8 struct {
	mat [8]uint8
}

// matGroupToVec packs a group of eight 8x8 bit matrices into the
// [8]uint64 vector layout consumed by the matrix-multiply constants:
// matrix i becomes word i, with its rows in reversed byte order.
func matGroupToVec(mats *[8]mat8x8) [8]uint64 {
	var packed [8]uint64
	for m := range mats {
		var word uint64
		for r := 0; r < 8; r++ {
			// For some reason, Intel flips the rows, so row r lands in
			// byte 7-r of the word.
			word |= uint64(mats[m].mat[r]) << ((7 - r) * 8)
		}
		packed[m] = word
	}
	return packed
}
// expandIdentity implements 1x expansion (that is, no expansion).
//
// It loads the two consecutive 64-byte halves of the input bitmap and
// returns them unchanged as the low and high halves of the result.
func expandIdentity(ptrObjBits gen.Ptr[gen.Uint8x64]) {
	objBitsLo := gen.Deref(ptrObjBits)
	objBitsHi := gen.Deref(ptrObjBits.AddConst(64))
	gen.Return(objBitsLo, objBitsHi)
}
// gfExpander produces a function that expands each bit in an input bitmap into
// f consecutive bits in an output bitmap.
//
// The input is
//
//	AX *[8]uint64 = A pointer to floor(1024/f) bits (f >= 2, so at most 512 bits)
//
// The output is
//
//	Z1 [64]uint8 = The bottom 512 bits of the expanded bitmap
//	Z2 [64]uint8 = The top 512 bits of the expanded bitmap
//
// It returns false (after logging the reason) if it cannot construct the
// expansion for this f; in that case the emitted function is incomplete.
//
// TODO(austin): This should be Z0/Z1.
func gfExpander(f int, ptrObjBits gen.Ptr[gen.Uint8x64]) bool {
	// TODO(austin): For powers of 2 >= 8, we can use mask expansion ops to make this much simpler.
	// TODO(austin): For f >= 8, I suspect there are better ways to do this.
	//
	// For example, we could use a mask expansion to get a full byte for each
	// input bit, and separately create the bytes that blend adjacent bits, then
	// shuffle those bytes together. Certainly for f >= 16 this makes sense
	// because each of those bytes will be used, possibly more than once.
	objBits := gen.Deref(ptrObjBits)
	// term records that output byte oByte is computed from input byte
	// iByte via the 8x8 bit matrix mat.
	type term struct {
		iByte, oByte int
		mat          mat8x8
	}
	var terms []term
	// Iterate over all output bytes and construct the 8x8 GF2 matrix to compute
	// the output byte from the appropriate input byte. Gather all of these into
	// "terms".
	for oByte := 0; oByte < 1024/8; oByte++ {
		var byteMat mat8x8
		iByte := -1
		for oBit := oByte * 8; oBit < oByte*8+8; oBit++ {
			// Output bit oBit is a copy of input bit oBit/f.
			iBit := oBit / f
			if iByte == -1 {
				iByte = iBit / 8
			} else if iByte != iBit/8 {
				// A single matrix multiply reads only one input byte,
				// so every bit of this output byte must come from the
				// same input byte.
				log.Printf("output byte %d straddles input bytes %d and %d", oByte, iByte, iBit/8)
				return false
			}
			// One way to view this is that the i'th row of the matrix will be
			// ANDed with the input byte, and the parity of the result will set
			// the i'th bit in the output. We use a simple 1 bit mask, so the
			// parity is irrelevant beyond selecting out that one bit.
			byteMat.mat[oBit%8] = 1 << (iBit % 8)
		}
		terms = append(terms, term{iByte, oByte, byteMat})
	}
	// Debug aid, normally disabled: print the input-byte -> output-byte
	// relation as a matrix, labeling distinct matrices A, B, C, ...
	if false {
		// Print input byte -> output byte as a matrix
		maxIByte, maxOByte := 0, 0
		for _, term := range terms {
			maxIByte = max(maxIByte, term.iByte)
			maxOByte = max(maxOByte, term.oByte)
		}
		iToO := make([][]rune, maxIByte+1)
		for i := range iToO {
			iToO[i] = make([]rune, maxOByte+1)
		}
		matMap := make(map[mat8x8]int)
		for _, term := range terms {
			i, ok := matMap[term.mat]
			if !ok {
				i = len(matMap)
				matMap[term.mat] = i
			}
			iToO[term.iByte][term.oByte] = 'A' + rune(i)
		}
		for o := range maxOByte + 1 {
			fmt.Printf("%d", o)
			for i := range maxIByte + 1 {
				fmt.Printf(",")
				if mat := iToO[i][o]; mat != 0 {
					fmt.Printf("%c", mat)
				}
			}
			fmt.Println()
		}
	}
	// In hardware, each (8 byte) matrix applies to 8 bytes of data in parallel,
	// and we get to operate on up to 8 matrixes in parallel (or 64 values). That is:
	//
	//	abcdefgh ijklmnop qrstuvwx yzABCDEF GHIJKLMN OPQRSTUV WXYZ0123 456789_+
	//	mat0     mat1     mat2     mat3     mat4     mat5     mat6     mat7
	//
	// Group the terms by matrix, but limit each group to 8 terms.
	const termsPerGroup = 8       // Number of terms we can multiply by the same matrix.
	const groupsPerSuperGroup = 8 // Number of matrixes we can fit in a vector.
	// matMap maps a matrix to the index of its current (non-full) term
	// group; allMats counts the distinct matrices seen.
	matMap := make(map[mat8x8]int)
	allMats := make(map[mat8x8]bool)
	var termGroups [][]term
	for _, term := range terms {
		allMats[term.mat] = true
		i, ok := matMap[term.mat]
		if ok && f > groupsPerSuperGroup {
			// The output is ultimately produced in two [64]uint8 registers.
			// Getting every byte in the right place of each of these requires a
			// final permutation that often requires more than one source.
			//
			// Up to 8x expansion, we can get a really nice grouping so we can use
			// the same 8 matrix vector several times, without producing
			// permutations that require more than two sources.
			//
			// Above 8x, however, we can't get nice matrixes anyway, so we
			// instead prefer reducing the complexity of the permutations we
			// need to produce the final outputs. To do this, avoid grouping
			// together terms that are split across the two registers.
			outRegister := termGroups[i][0].oByte / 64
			if term.oByte/64 != outRegister {
				ok = false
			}
		}
		if !ok {
			// Start a new term group.
			i = len(termGroups)
			matMap[term.mat] = i
			termGroups = append(termGroups, nil)
		}
		termGroups[i] = append(termGroups[i], term)
		if len(termGroups[i]) == termsPerGroup {
			// This term group is full.
			delete(matMap, term.mat)
		}
	}
	for i, termGroup := range termGroups {
		log.Printf("term group %d:", i)
		for _, term := range termGroup {
			log.Printf("  %+v", term)
		}
	}
	// We can do 8 matrix multiplies in parallel, which is 8 term groups. Pack
	// as many term groups as we can into each super-group to minimize the
	// number of matrix multiplies.
	//
	// Ideally, we use the same matrix in each super-group, which might mean
	// doing fewer than 8 multiplies at a time. That's fine because it never
	// increases the total number of matrix multiplies.
	//
	// TODO: Packing the matrixes less densely may let us use more broadcast
	// loads instead of general permutations, though. That replaces a load of
	// the permutation with a load of the matrix, but is probably still slightly
	// better.
	var sgSize, nSuperGroups int
	oneMatVec := f <= groupsPerSuperGroup
	if oneMatVec {
		// We can use the same matrix in each multiply by doing sgSize
		// multiplies at a time.
		//
		// NOTE(review): sgSize is the largest multiple of len(allMats)
		// that is <= 8; this assumes len(allMats) <= 8 whenever f <= 8,
		// otherwise sgSize would be 0 and the division below would
		// panic — TODO confirm that invariant.
		sgSize = groupsPerSuperGroup / len(allMats) * len(allMats)
		nSuperGroups = (len(termGroups) + sgSize - 1) / sgSize
	} else {
		// We can't use the same matrix for each multiply. Just do as many at a
		// time as we can.
		//
		// TODO: This is going to produce several distinct matrixes, when we
		// probably only need two. Be smarter about how we create super-groups
		// in this case. Maybe we build up an array of super-groups and then the
		// loop below just turns them into ops?
		sgSize = 8
		nSuperGroups = (len(termGroups) + groupsPerSuperGroup - 1) / groupsPerSuperGroup
	}
	// Construct each super-group: one input shuffle, one matrix constant,
	// and one matrix-multiply op per super-group, while accumulating the
	// final output permutation in perm.
	var matGroup [8]mat8x8
	var matMuls []gen.Uint8x64
	var perm [128]int
	for sgi := range nSuperGroups {
		var iperm [64]uint8
		for i := range iperm {
			iperm[i] = 0xff // "Don't care"
		}
		// Pick off sgSize term groups.
		superGroup := termGroups[:min(len(termGroups), sgSize)]
		termGroups = termGroups[len(superGroup):]
		// Build the matrix and permutations for this super-group.
		var thisMatGroup [8]mat8x8
		for i, termGroup := range superGroup {
			// All terms in this group have the same matrix. Pick one.
			thisMatGroup[i] = termGroup[0].mat
			for j, term := range termGroup {
				// Build the input permutation.
				iperm[i*termsPerGroup+j] = uint8(term.iByte)
				// Build the output permutation: output byte oByte is lane
				// i*termsPerGroup+j of matrix-multiply result sgi.
				perm[term.oByte] = sgi*groupsPerSuperGroup*termsPerGroup + i*termsPerGroup + j
			}
		}
		log.Printf("input permutation %d: %v", sgi, iperm)
		// Check that we're not making more distinct matrixes than expected.
		if oneMatVec {
			if sgi == 0 {
				matGroup = thisMatGroup
			} else if matGroup != thisMatGroup {
				log.Printf("super-groups have different matrixes:\n%+v\n%+v", matGroup, thisMatGroup)
				return false
			}
		}
		// Emit matrix op: gather the needed input bytes, then multiply by
		// the packed matrix group (GF(2) affine transform).
		matConst := gen.ConstUint64x8(matGroupToVec(&thisMatGroup), fmt.Sprintf("*_mat%d<>", sgi))
		inOp := objBits.Shuffle(gen.ConstUint8x64(iperm, fmt.Sprintf("*_inShuf%d<>", sgi)))
		matMul := matConst.GF2P8Affine(inOp)
		matMuls = append(matMuls, matMul)
	}
	log.Printf("output permutation: %v", perm)
	// Gather the matrix-multiply results into the two output registers.
	outLo, ok := genShuffle("*_outShufLo", (*[64]int)(perm[:64]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	outHi, ok := genShuffle("*_outShufHi", (*[64]int)(perm[64:]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	gen.Return(outLo, outHi)
	return true
}
// genShuffle emits ops computing a [64]uint8 result whose i'th byte is
// byte perm[i]%64 of source vector args[perm[i]/64]. name is used to
// label the emitted permutation constants.
//
// It returns false if the permutation draws from more than four distinct
// source vectors, which the two-shuffle sequence below cannot express.
func genShuffle(name string, perm *[64]int, args ...gen.Uint8x64) (gen.Uint8x64, bool) {
	// Construct flattened permutation.
	var vperm [64]byte
	// Get the inputs used by this permutation. inputs[k] is the args
	// index backing "virtual input" k; each vperm byte encodes the
	// virtual input in its top two bits and the source byte in the low
	// six.
	var inputs []int
	for i, src := range perm {
		inputIdx := slices.Index(inputs, src/64)
		if inputIdx == -1 {
			inputIdx = len(inputs)
			inputs = append(inputs, src/64)
		}
		// NOTE(review): with four inputs, inputIdx == 3 and src%64 == 63
		// encode to 0xff, colliding with the "don't care" marker used
		// below. That combination trips the mask panic rather than
		// mis-compiling — confirm it cannot occur in practice.
		vperm[i] = byte(src%64 | (inputIdx << 6))
	}
	// Emit instructions for easy cases: one or two sources fit a single
	// (two-source) shuffle directly.
	switch len(inputs) {
	case 1:
		constOp := gen.ConstUint8x64(vperm, name)
		return args[inputs[0]].Shuffle(constOp), true
	case 2:
		constOp := gen.ConstUint8x64(vperm, name)
		return args[inputs[0]].Shuffle2(args[inputs[1]], constOp), true
	}
	// Harder case, we need to shuffle in from up to 2 more tables.
	//
	// Perform two shuffles. One shuffle gets its data from the first two
	// inputs, the other from the remaining one or two inputs. Every lane
	// a shuffle doesn't supply is zeroed, so the two results can simply
	// be ORed together at the end.
	var vperms [2][64]byte
	var masks [2]uint64
	for j, idx := range vperm {
		for i := range vperms {
			vperms[i][j] = 0xff // "Don't care"
		}
		if idx == 0xff {
			continue
		}
		// Bit 7 of idx selects which of the two shuffles supplies lane j.
		vperms[idx/128][j] = idx % 128
		masks[idx/128] |= uint64(1) << j
	}
	// Validate that every lane is supplied by exactly one shuffle: the
	// masks must be disjoint and together cover all 64 lanes.
	if masks[0]^masks[1] != ^uint64(0) {
		panic("bad shuffle!")
	}
	// Generate constants.
	constOps := make([]gen.Uint8x64, len(vperms))
	for i, v := range vperms {
		constOps[i] = gen.ConstUint8x64(v, name+strconv.Itoa(i))
	}
	// Generate shuffles.
	switch len(inputs) {
	case 3:
		r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
		r1 := args[inputs[2]].ShuffleZeroed(constOps[1], gen.ConstMask64(masks[1]))
		return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
	case 4:
		r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
		r1 := args[inputs[2]].Shuffle2Zeroed(args[inputs[3]], constOps[1], gen.ConstMask64(masks[1]))
		return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
	}
	// Too many inputs. To support more, we'd need to separate tables much earlier.
	// Right now all the indices fit in a byte, but with >4 inputs they might not (>256 bytes).
	return args[0], false
}