| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| //go:build ignore |
| |
| package main |
| |
| import ( |
| "bytes" |
| "fmt" |
| "io" |
| "log" |
| "os" |
| "slices" |
| "strconv" |
| |
| "internal/runtime/gc" |
| "internal/runtime/gc/internal/gen" |
| ) |
| |
| const header = "// Code generated by mkasm.go. DO NOT EDIT.\n\n" |
| |
| func main() { |
| generate("expand_amd64.s", genExpanders) |
| } |
| |
| func generate(fileName string, genFunc func(*gen.File)) { |
| var buf bytes.Buffer |
| tee := io.MultiWriter(&buf, os.Stdout) |
| |
| file := gen.NewFile(tee) |
| |
| genFunc(file) |
| |
| fmt.Fprintf(tee, header) |
| file.Compile() |
| |
| f, err := os.Create(fileName) |
| if err != nil { |
| log.Fatal(err) |
| } |
| defer f.Close() |
| _, err = f.Write(buf.Bytes()) |
| if err != nil { |
| log.Fatal(err) |
| } |
| } |
| |
| func genExpanders(file *gen.File) { |
| gcExpandersAVX512 := make([]*gen.Func, len(gc.SizeClassToSize)) |
| for sc, ob := range gc.SizeClassToSize { |
| if gc.SizeClassToNPages[sc] != 1 { |
| // These functions all produce a bitmap that covers exactly one |
| // page. |
| continue |
| } |
| if ob > gc.MinSizeForMallocHeader { |
| // This size class is too big to have a packed pointer/scalar bitmap. |
| break |
| } |
| |
| xf := int(ob) / 8 |
| log.Printf("size class %d bytes, expansion %dx", ob, xf) |
| |
| fn := gen.NewFunc(fmt.Sprintf("expandAVX512_%d<>", xf)) |
| ptrObjBits := gen.Arg[gen.Ptr[gen.Uint8x64]](fn) |
| |
| if xf == 1 { |
| expandIdentity(ptrObjBits) |
| } else { |
| ok := gfExpander(xf, ptrObjBits) |
| if !ok { |
| log.Printf("failed to generate expander for size class %d", sc) |
| } |
| } |
| file.AddFunc(fn) |
| gcExpandersAVX512[sc] = fn |
| } |
| |
| // Generate table mapping size class to expander PC |
| file.AddConst("·gcExpandersAVX512", gcExpandersAVX512) |
| } |
| |
// mat8x8 is an 8x8 bit matrix, stored one byte per row.
type mat8x8 struct {
	mat [8]uint8
}

// matGroupToVec packs a group of eight 8x8 bit matrices into eight
// 64-bit lanes, one matrix per lane, in the row order the matrix
// constants below require.
func matGroupToVec(mats *[8]mat8x8) [8]uint64 {
	var packed [8]uint64
	for lane := range mats {
		for r, row := range mats[lane].mat {
			// For some reason, Intel flips the rows.
			packed[lane] |= uint64(row) << (uint(7-r) * 8)
		}
	}
	return packed
}
| |
// expandIdentity implements 1x expansion (that is, no expansion).
// It loads the two 64-byte halves of the input bitmap and returns them
// unchanged. The two Derefs are emitted in this order deliberately so
// the loads line up with the (lo, hi) order of the Return.
func expandIdentity(ptrObjBits gen.Ptr[gen.Uint8x64]) {
	objBitsLo := gen.Deref(ptrObjBits)              // first 64 bytes (low 512 bits)
	objBitsHi := gen.Deref(ptrObjBits.AddConst(64)) // second 64 bytes (high 512 bits)
	gen.Return(objBitsLo, objBitsHi)
}
| |
// gfExpander produces a function that expands each bit in an input bitmap into
// f consecutive bits in an output bitmap. It does this by multiplying 8-byte
// chunks of the input by 8x8 GF(2) bit matrices (via GF2P8Affine) and then
// permuting the product bytes into their final positions.
//
// The input is
//
//	AX *[8]uint64 = A pointer to floor(1024/f) bits (f >= 2, so at most 512 bits)
//
// The output is
//
//	Z1 [64]uint8 = The bottom 512 bits of the expanded bitmap
//	Z2 [64]uint8 = The top 512 bits of the expanded bitmap
//
// It returns false (after logging the reason) if it cannot construct the
// expansion for this f; the partially-emitted function must not be used.
//
// TODO(austin): This should Z0/Z1.
func gfExpander(f int, ptrObjBits gen.Ptr[gen.Uint8x64]) bool {
	// TODO(austin): For powers of 2 >= 8, we can use mask expansion ops to make this much simpler.

	// TODO(austin): For f >= 8, I suspect there are better ways to do this.
	//
	// For example, we could use a mask expansion to get a full byte for each
	// input bit, and separately create the bytes that blend adjacent bits, then
	// shuffle those bytes together. Certainly for f >= 16 this makes sense
	// because each of those bytes will be used, possibly more than once.

	// Load the packed input bitmap.
	objBits := gen.Deref(ptrObjBits)

	// term records that output byte oByte is produced by multiplying input
	// byte iByte by the bit matrix mat.
	type term struct {
		iByte, oByte int
		mat          mat8x8
	}
	var terms []term

	// Iterate over all output bytes and construct the 8x8 GF2 matrix to compute
	// the output byte from the appropriate input byte. Gather all of these into
	// "terms".
	for oByte := 0; oByte < 1024/8; oByte++ {
		var byteMat mat8x8
		iByte := -1
		for oBit := oByte * 8; oBit < oByte*8+8; oBit++ {
			// Output bit oBit is a copy of input bit oBit/f.
			iBit := oBit / f
			if iByte == -1 {
				iByte = iBit / 8
			} else if iByte != iBit/8 {
				// A single matrix multiply can only read from one input
				// byte, so this scheme can't express the expansion.
				log.Printf("output byte %d straddles input bytes %d and %d", oByte, iByte, iBit/8)
				return false
			}
			// One way to view this is that the i'th row of the matrix will be
			// ANDed with the input byte, and the parity of the result will set
			// the i'th bit in the output. We use a simple 1 bit mask, so the
			// parity is irrelevant beyond selecting out that one bit.
			byteMat.mat[oBit%8] = 1 << (iBit % 8)
		}
		terms = append(terms, term{iByte, oByte, byteMat})
	}

	if false {
		// Debugging aid: print input byte -> output byte as a matrix, naming
		// each distinct 8x8 bit matrix with a letter.
		maxIByte, maxOByte := 0, 0
		for _, term := range terms {
			maxIByte = max(maxIByte, term.iByte)
			maxOByte = max(maxOByte, term.oByte)
		}
		iToO := make([][]rune, maxIByte+1)
		for i := range iToO {
			iToO[i] = make([]rune, maxOByte+1)
		}
		matMap := make(map[mat8x8]int)
		for _, term := range terms {
			i, ok := matMap[term.mat]
			if !ok {
				i = len(matMap)
				matMap[term.mat] = i
			}
			iToO[term.iByte][term.oByte] = 'A' + rune(i)
		}
		for o := range maxOByte + 1 {
			fmt.Printf("%d", o)
			for i := range maxIByte + 1 {
				fmt.Printf(",")
				if mat := iToO[i][o]; mat != 0 {
					fmt.Printf("%c", mat)
				}
			}
			fmt.Println()
		}
	}

	// In hardware, each (8 byte) matrix applies to 8 bytes of data in parallel,
	// and we get to operate on up to 8 matrixes in parallel (or 64 values). That is:
	//
	//	abcdefgh ijklmnop qrstuvwx yzABCDEF GHIJKLMN OPQRSTUV WXYZ0123 456789_+
	//	  mat0     mat1     mat2     mat3     mat4     mat5     mat6     mat7

	// Group the terms by matrix, but limit each group to 8 terms.
	const termsPerGroup = 8       // Number of terms we can multiply by the same matrix.
	const groupsPerSuperGroup = 8 // Number of matrixes we can fit in a vector.

	matMap := make(map[mat8x8]int) // matrix -> index in termGroups of its open (non-full) group
	allMats := make(map[mat8x8]bool)
	var termGroups [][]term
	for _, term := range terms {
		allMats[term.mat] = true

		i, ok := matMap[term.mat]
		if ok && f > groupsPerSuperGroup {
			// The output is ultimately produced in two [64]uint8 registers.
			// Getting every byte in the right place of each of these requires a
			// final permutation that often requires more than one source.
			//
			// Up to 8x expansion, we can get a really nice grouping so we can use
			// the same 8 matrix vector several times, without producing
			// permutations that require more than two sources.
			//
			// Above 8x, however, we can't get nice matrixes anyway, so we
			// instead prefer reducing the complexity of the permutations we
			// need to produce the final outputs. To do this, avoid grouping
			// together terms that are split across the two registers.
			outRegister := termGroups[i][0].oByte / 64
			if term.oByte/64 != outRegister {
				ok = false
			}
		}
		if !ok {
			// Start a new term group.
			i = len(termGroups)
			matMap[term.mat] = i
			termGroups = append(termGroups, nil)
		}

		termGroups[i] = append(termGroups[i], term)

		if len(termGroups[i]) == termsPerGroup {
			// This term group is full. Drop it from matMap so the next term
			// with this matrix starts a fresh group.
			delete(matMap, term.mat)
		}
	}

	// Log the grouping for debugging the generator.
	for i, termGroup := range termGroups {
		log.Printf("term group %d:", i)
		for _, term := range termGroup {
			log.Printf("  %+v", term)
		}
	}

	// We can do 8 matrix multiplies in parallel, which is 8 term groups. Pack
	// as many term groups as we can into each super-group to minimize the
	// number of matrix multiplies.
	//
	// Ideally, we use the same matrix in each super-group, which might mean
	// doing fewer than 8 multiplies at a time. That's fine because it never
	// increases the total number of matrix multiplies.
	//
	// TODO: Packing the matrixes less densely may let us use more broadcast
	// loads instead of general permutations, though. That replaces a load of
	// the permutation with a load of the matrix, but is probably still slightly
	// better.
	var sgSize, nSuperGroups int
	oneMatVec := f <= groupsPerSuperGroup
	if oneMatVec {
		// We can use the same matrix in each multiply by doing sgSize
		// multiplies at a time.
		sgSize = groupsPerSuperGroup / len(allMats) * len(allMats)
		nSuperGroups = (len(termGroups) + sgSize - 1) / sgSize
	} else {
		// We can't use the same matrix for each multiply. Just do as many at a
		// time as we can.
		//
		// TODO: This is going to produce several distinct matrixes, when we
		// probably only need two. Be smarter about how we create super-groups
		// in this case. Maybe we build up an array of super-groups and then the
		// loop below just turns them into ops? 
		sgSize = 8
		nSuperGroups = (len(termGroups) + groupsPerSuperGroup - 1) / groupsPerSuperGroup
	}

	// Construct each super-group: one input shuffle + one matrix multiply,
	// while accumulating the final output permutation in perm.
	var matGroup [8]mat8x8
	var matMuls []gen.Uint8x64
	var perm [128]int
	for sgi := range nSuperGroups {
		var iperm [64]uint8
		for i := range iperm {
			iperm[i] = 0xff // "Don't care"
		}
		// Pick off sgSize term groups.
		superGroup := termGroups[:min(len(termGroups), sgSize)]
		termGroups = termGroups[len(superGroup):]
		// Build the matrix and permutations for this super-group.
		var thisMatGroup [8]mat8x8
		for i, termGroup := range superGroup {
			// All terms in this group have the same matrix. Pick one.
			thisMatGroup[i] = termGroup[0].mat
			for j, term := range termGroup {
				// Build the input permutation.
				iperm[i*termsPerGroup+j] = uint8(term.iByte)
				// Build the output permutation.
				perm[term.oByte] = sgi*groupsPerSuperGroup*termsPerGroup + i*termsPerGroup + j
			}
		}
		log.Printf("input permutation %d: %v", sgi, iperm)

		// Check that we're not making more distinct matrixes than expected.
		if oneMatVec {
			if sgi == 0 {
				matGroup = thisMatGroup
			} else if matGroup != thisMatGroup {
				log.Printf("super-groups have different matrixes:\n%+v\n%+v", matGroup, thisMatGroup)
				return false
			}
		}

		// Emit matrix op.
		matConst := gen.ConstUint64x8(matGroupToVec(&thisMatGroup), fmt.Sprintf("*_mat%d<>", sgi))
		inOp := objBits.Shuffle(gen.ConstUint8x64(iperm, fmt.Sprintf("*_inShuf%d<>", sgi)))
		matMul := matConst.GF2P8Affine(inOp)
		matMuls = append(matMuls, matMul)
	}

	log.Printf("output permutation: %v", perm)

	// Gather the matrix-multiply outputs into the two result registers.
	outLo, ok := genShuffle("*_outShufLo", (*[64]int)(perm[:64]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	outHi, ok := genShuffle("*_outShufHi", (*[64]int)(perm[64:]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	gen.Return(outLo, outHi)

	return true
}
| |
// genShuffle emits ops that build a 64-byte vector whose byte i is byte
// perm[i] of the concatenation of args (each arg contributing 64 bytes).
// It returns false if the permutation draws from more than four distinct
// args, which the two-source shuffles used here cannot express.
func genShuffle(name string, perm *[64]int, args ...gen.Uint8x64) (gen.Uint8x64, bool) {
	// Construct flattened permutation.
	var vperm [64]byte

	// Get the inputs used by this permutation.
	var inputs []int
	for i, src := range perm {
		inputIdx := slices.Index(inputs, src/64)
		if inputIdx == -1 {
			inputIdx = len(inputs)
			inputs = append(inputs, src/64)
		}
		// Pack each index as: low 6 bits = byte within the input,
		// bits 6-7 = which of up to 4 distinct inputs.
		//
		// NOTE(review): inputIdx 3 with source byte 63 encodes as 0xff,
		// which collides with the "don't care" sentinel used below; that
		// case would be caught only by the disjointness panic. Confirm
		// callers never produce that combination.
		vperm[i] = byte(src%64 | (inputIdx << 6))
	}

	// Emit instructions for easy cases.
	switch len(inputs) {
	case 1:
		// Single source: one plain shuffle suffices.
		constOp := gen.ConstUint8x64(vperm, name)
		return args[inputs[0]].Shuffle(constOp), true
	case 2:
		// Two sources: one two-source shuffle; bit 6 of each index
		// presumably selects between the two sources.
		constOp := gen.ConstUint8x64(vperm, name)
		return args[inputs[0]].Shuffle2(args[inputs[1]], constOp), true
	}

	// Harder case, we need to shuffle in from up to 2 more tables.
	//
	// Perform two shuffles. One shuffle will get its data from the first
	// two inputs, the other shuffle will get its data from the other one
	// or two inputs. All values they don't care each don't care about will
	// be zeroed.
	var vperms [2][64]byte
	var masks [2]uint64
	for j, idx := range vperm {
		for i := range vperms {
			vperms[i][j] = 0xff // "Don't care"
		}
		if idx == 0xff {
			continue
		}
		// Bit 7 picks which of the two shuffles produces byte j; the low
		// 7 bits are a two-source index within that shuffle.
		vperms[idx/128][j] = idx % 128
		masks[idx/128] |= uint64(1) << j
	}

	// Validate that the masks are fully disjoint.
	// Every output byte must be produced by exactly one of the two shuffles,
	// so XORing the masks must set all 64 bits.
	if masks[0]^masks[1] != ^uint64(0) {
		panic("bad shuffle!")
	}

	// Generate constants.
	constOps := make([]gen.Uint8x64, len(vperms))
	for i, v := range vperms {
		constOps[i] = gen.ConstUint8x64(v, name+strconv.Itoa(i))
	}

	// Generate shuffles. Bytes outside each shuffle's mask are zeroed, so
	// ORing the two partial results yields the full permutation.
	switch len(inputs) {
	case 3:
		r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
		r1 := args[inputs[2]].ShuffleZeroed(constOps[1], gen.ConstMask64(masks[1]))
		return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
	case 4:
		r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0]))
		r1 := args[inputs[2]].Shuffle2Zeroed(args[inputs[3]], constOps[1], gen.ConstMask64(masks[1]))
		return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true
	}

	// Too many inputs. To support more, we'd need to separate tables much earlier.
	// Right now all the indices fit in a byte, but with >4 inputs they might not (>256 bytes).
	return args[0], false
}