// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
"bytes"
"fmt"
"slices"
"strings"
"text/template"
)
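
// tplRuleData is the data substituted into ruleTemplates when rendering a
// single rewrite rule; the field comments show representative example values.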
type tplRuleData struct {
tplName string // e.g. "sftimm"
GoOp string // e.g. "ShiftAllLeft"
GoType string // e.g. "Uint32x8"
Args string // e.g. "x y"
Asm string // e.g. "VPSLLD256"
ArgsOut string // e.g. "x y"
MaskInConvert string // e.g. "VPMOVVec32x8ToM"
MaskOutConvert string // e.g. "VPMOVMToVec32x8"
ElementSize int // e.g. 32
Size int // e.g. 128
ArgsLoadAddr string // [Args] with its last vreg arg being a concrete "(VMOVDQUload* ptr mem)", and might contain mask.
ArgsAddr string // [Args] with its last vreg arg being replaced by "ptr", and might contain mask, and with a "mem" at the end.
FeatCheck string // e.g. "v.Block.CPUfeatures.hasFeature(CPUavx512)" -- for a ssa/_gen rules file.
}
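
// ruleTemplates renders one rewrite rule per tplRuleData. For illustration,
// with the example field values above, the "pureVreg" template expands to
//
//	(ShiftAllLeftUint32x8 x y) => (VPSLLD256 x y)
//
// The mask variants additionally convert between vector masks and K masks,
// the sftimm variants rewrite shifts by a constant, and the vregMem variants
// fold a load into the instruction.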
var (
ruleTemplates = template.Must(template.New("simdRules").Parse(`
{{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
{{end}}
{{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
{{end}}
{{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
{{end}}
{{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
{{end}}
{{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
{{end}}
{{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
{{end}}
{{define "vregMem"}}({{.Asm}} {{.ArgsLoadAddr}}) && canMergeLoad(v, l) && clobber(l) => ({{.Asm}}load {{.ArgsAddr}})
{{end}}
{{define "vregMemFeatCheck"}}({{.Asm}} {{.ArgsLoadAddr}}) && {{.FeatCheck}} && canMergeLoad(v, l) && clobber(l)=> ({{.Asm}}load {{.ArgsAddr}})
{{end}}
`))
)
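
// MaskOptimization returns a peephole rule that folds a masked move of an
// unmasked op's result into the masked form of that op, for example
// (illustrative, with d.Asm = "VPSLLDMasked256"):
//
//	(VMOVDQU32Masked256 (VPSLLD256 x y) mask) => (VPSLLDMasked256 x y mask)
//
// It returns "" when d.Asm is not a masked op, when the corresponding unmasked
// op was not generated (per asmCheck), or for instructions where the fold is
// not valid (loads, compress/expand, blends).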
func (d tplRuleData) MaskOptimization(asmCheck map[string]bool) string {
asmNoMask := d.Asm
if !strings.Contains(asmNoMask, "Masked") {
return ""
}
asmNoMask = strings.ReplaceAll(asmNoMask, "Masked", "")
if !asmCheck[asmNoMask] {
return ""
}
for _, nope := range []string{"VMOVDQU", "VPCOMPRESS", "VCOMPRESS", "VPEXPAND", "VEXPAND", "VPBLENDM", "VMOVUP"} {
if strings.HasPrefix(asmNoMask, nope) {
return ""
}
}
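// The vector width is encoded as a trailing 128/256/512 in the machine op
// name, possibly followed by a "const" suffix.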
size := asmNoMask[len(asmNoMask)-3:]
if strings.HasSuffix(asmNoMask, "const") {
sufLen := len("128const")
size = asmNoMask[len(asmNoMask)-sufLen:][:3]
}
switch size {
case "128", "256", "512":
default:
panic("Unexpected operation size on " + d.Asm)
}
switch d.ElementSize {
case 8, 16, 32, 64:
default:
panic(fmt.Errorf("Unexpected operation width %d on %v", d.ElementSize, d.Asm))
}
return fmt.Sprintf("(VMOVDQU%dMasked%s (%s %s) mask) => (%s %s mask)\n", d.ElementSize, size, asmNoMask, d.Args, d.Asm, d.Args)
}

// SSA rewrite rules must appear in most-to-least-specific order; sorting
// templates by this ranking (see compareTplRuleData) produces that order.
var tmplOrder = map[string]int{
"masksftimm": 0,
"sftimm": 1,
"maskInMaskOut": 2,
"maskOut": 3,
"maskIn": 4,
"pureVreg": 5,
"vregMem": 6,
}
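
// compareTplRuleData orders rules by Go op, Go type, and argument list, then
// by template specificity (tmplOrder), so that more specific rules for the
// same operation are emitted first.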
func compareTplRuleData(x, y tplRuleData) int {
if c := compareNatural(x.GoOp, y.GoOp); c != 0 {
return c
}
if c := compareNatural(x.GoType, y.GoType); c != 0 {
return c
}
if c := compareNatural(x.Args, y.Args); c != 0 {
return c
}
if x.tplName == y.tplName {
return 0
}
xo, xok := tmplOrder[x.tplName]
yo, yok := tmplOrder[y.tplName]
if !xok {
panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", x.tplName))
}
if !yok {
panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", y.tplName))
}
return xo - yo
}

// writeSIMDRules generates the SSA lowering and rewrite rules and returns
// them in a buffer; the caller writes the result to simdAMD64.rules.
func writeSIMDRules(ops []Operation) *bytes.Buffer {
buffer := new(bytes.Buffer)
buffer.WriteString(generatedHeader + "\n")
// asm -> masked merging rules
maskedMergeOpts := make(map[string]string)
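// s2n maps element bit widths to the size suffix letters used in
// instruction names (Byte, Word, Doubleword, Quadword).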
s2n := map[int]string{8: "B", 16: "W", 32: "D", 64: "Q"}
asmCheck := map[string]bool{}
var allData []tplRuleData
var optData []tplRuleData // for mask peephole optimizations, and other misc
var memOptData []tplRuleData // for memory peephole optimizations
memOpSeen := make(map[string]bool)
for _, opr := range ops {
opInShape, opOutShape, maskType, immType, gOp := opr.shape()
asm := machineOpName(maskType, gOp)
vregInCnt := len(gOp.In)
if maskType == OneMask {
vregInCnt--
}
data := tplRuleData{
GoOp: gOp.Go,
Asm: asm,
}
if vregInCnt == 1 {
data.Args = "x"
data.ArgsOut = data.Args
} else if vregInCnt == 2 {
data.Args = "x y"
data.ArgsOut = data.Args
} else if vregInCnt == 3 {
data.Args = "x y z"
data.ArgsOut = data.Args
} else {
panic(fmt.Errorf("simdgen does not support more than 3 vreg in inputs"))
}
if immType == ConstImm {
data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut)
} else if immType == VarImm {
data.Args = fmt.Sprintf("[a] %s", data.Args)
data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut)
} else if immType == ConstVarImm {
data.Args = fmt.Sprintf("[a] %s", data.Args)
data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut)
}
goType := func(op Operation) string {
if op.OperandOrder != nil {
switch *op.OperandOrder {
case "21Type1", "231Type1":
// Permute uses operand[1] for method receiver.
return *op.In[1].Go
}
}
return *op.In[0].Go
}
var tplName string
// If OverwriteClass is set on the output, it is really a vreg, not a mask.
if opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
switch opInShape {
case OneImmIn:
tplName = "pureVreg"
data.GoType = goType(gOp)
case PureVregIn:
tplName = "pureVreg"
data.GoType = goType(gOp)
case OneKmaskImmIn:
fallthrough
case OneKmaskIn:
tplName = "maskIn"
data.GoType = goType(gOp)
rearIdx := len(gOp.In) - 1
// Mask is at the end.
width := *gOp.In[rearIdx].ElemBits
data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", width, *gOp.In[rearIdx].Lanes)
data.ElementSize = width
case PureKmaskIn:
panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
}
} else if opOutShape == OneGregOut {
tplName = "pureVreg" // TODO this will be wrong
data.GoType = goType(gOp)
} else {
// OneKmaskOut case
data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes)
switch opInShape {
case OneImmIn:
fallthrough
case PureVregIn:
tplName = "maskOut"
data.GoType = goType(gOp)
case OneKmaskImmIn:
fallthrough
case OneKmaskIn:
tplName = "maskInMaskOut"
data.GoType = goType(gOp)
rearIdx := len(gOp.In) - 1
data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
case PureKmaskIn:
panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
}
}
if gOp.SpecialLower != nil {
if *gOp.SpecialLower == "sftimm" {
if data.GoType[0] == 'I' {
// Only do these for signed types; the rewrite would be a duplicate for unsigned.
sftImmData := data
if tplName == "maskIn" {
sftImmData.tplName = "masksftimm"
} else {
sftImmData.tplName = "sftimm"
}
allData = append(allData, sftImmData)
asmCheck[sftImmData.Asm+"const"] = true
}
} else {
panic("simdgen sees unknwon special lower " + *gOp.SpecialLower + ", maybe implement it?")
}
}
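// For ops marked MemFeatures == "vbcst", emit load-folding rules that merge
// a VMOVDQUload feeding the last vreg argument into the op's *load form.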
if gOp.MemFeatures != nil && *gOp.MemFeatures == "vbcst" {
// sanity check
selected := true
for _, a := range gOp.In {
if a.TreatLikeAScalarOfSize != nil {
selected = false
break
}
}
if _, ok := memOpSeen[data.Asm]; ok {
selected = false
}
if selected {
memOpSeen[data.Asm] = true
lastVreg := gOp.In[vregInCnt-1]
// sanity check
if lastVreg.Class != "vreg" {
panic(fmt.Errorf("simdgen expects vbcst replaced operand to be a vreg, but %v found", lastVreg))
}
memOpData := data
// Remove the last vreg from the arg and change it to a load.
origArgs := data.Args[:len(data.Args)-1]
// Prepare imm args.
immArg := ""
immArgCombineOff := " [off] "
if immType != NoImm && immType != InvalidImm {
_, after, found := strings.Cut(origArgs, "]")
if found {
origArgs = after
}
immArg = "[c] "
immArgCombineOff = " [makeValAndOff(int32(uint8(c)),off)] "
}
memOpData.ArgsLoadAddr = immArg + origArgs + fmt.Sprintf("l:(VMOVDQUload%d {sym} [off] ptr mem)", *lastVreg.Bits)
// Remove the last vreg from the arg and change it to "ptr".
memOpData.ArgsAddr = "{sym}" + immArgCombineOff + origArgs + "ptr"
if maskType == OneMask {
memOpData.ArgsAddr += " mask"
memOpData.ArgsLoadAddr += " mask"
}
memOpData.ArgsAddr += " mem"
if gOp.MemFeaturesData != nil {
_, feat2 := getVbcstData(*gOp.MemFeaturesData)
knownFeatChecks := map[string]string{
"AVX": "v.Block.CPUfeatures.hasFeature(CPUavx)",
"AVX2": "v.Block.CPUfeatures.hasFeature(CPUavx2)",
"AVX512": "v.Block.CPUfeatures.hasFeature(CPUavx512)",
}
memOpData.FeatCheck = knownFeatChecks[feat2]
memOpData.tplName = "vregMemFeatCheck"
} else {
memOpData.tplName = "vregMem"
}
memOptData = append(memOptData, memOpData)
asmCheck[memOpData.Asm+"load"] = true
}
}
// Generate the masked merging optimization rules
if gOp.hasMaskedMerging(maskType, opOutShape) {
// TODO: handle customized operand order and special lower.
maskElem := gOp.In[len(gOp.In)-1]
if maskElem.Bits == nil {
panic("mask has no bits")
}
if maskElem.ElemBits == nil {
panic("mask has no elemBits")
}
if maskElem.Lanes == nil {
panic("mask has no lanes")
}
switch *maskElem.Bits {
case 128, 256:
// VPBLENDVB cases.
noMaskName := machineOpName(NoMask, gOp)
ruleExisting, ok := maskedMergeOpts[noMaskName]
rule := fmt.Sprintf("(VPBLENDVB%d dst (%s %s) mask) && v.Block.CPUfeatures.hasFeature(CPUavx512) => (%sMerging dst %s (VPMOVVec%dx%dToM <types.TypeMask> mask))\n",
*maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args, *maskElem.ElemBits, *maskElem.Lanes)
if ok && ruleExisting != rule {
panic(fmt.Sprintf("multiple masked merge rules for one op:\n%s\n%s\n", ruleExisting, rule))
} else {
maskedMergeOpts[noMaskName] = rule
}
case 512:
// VPBLENDM[BWDQ] cases.
noMaskName := machineOpName(NoMask, gOp)
ruleExisting, ok := maskedMergeOpts[noMaskName]
rule := fmt.Sprintf("(VPBLENDM%sMasked%d dst (%s %s) mask) => (%sMerging dst %s mask)\n",
s2n[*maskElem.ElemBits], *maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args)
if ok && ruleExisting != rule {
panic(fmt.Sprintf("multiple masked merge rules for one op:\n%s\n%s\n", ruleExisting, rule))
} else {
maskedMergeOpts[noMaskName] = rule
}
}
}
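// Lowerings with identical argument lists on both sides can use the "..."
// shorthand, which copies the arguments verbatim in the rules file.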
if tplName == "pureVreg" && data.Args == data.ArgsOut {
data.Args = "..."
data.ArgsOut = "..."
}
data.tplName = tplName
if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" ||
opr.SkipMaskedMethod() {
optData = append(optData, data)
continue
}
allData = append(allData, data)
asmCheck[data.Asm] = true
}
slices.SortFunc(allData, compareTplRuleData)
for _, data := range allData {
if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err))
}
}
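// Emit the mask-elision peephole rules from MaskOptimization, deduplicated;
// several ops can produce the same rule.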
seen := make(map[string]bool)
for _, data := range optData {
if data.tplName == "maskIn" {
rule := data.MaskOptimization(asmCheck)
if seen[rule] {
continue
}
seen[rule] = true
buffer.WriteString(rule)
}
}
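// Emit the masked-merge rules, skipping any whose unmasked op was never
// generated, and sort them for deterministic output.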
maskedMergeOptsRules := []string{}
for asm, rule := range maskedMergeOpts {
if !asmCheck[asm] {
continue
}
maskedMergeOptsRules = append(maskedMergeOptsRules, rule)
}
slices.Sort(maskedMergeOptsRules)
for _, rule := range maskedMergeOptsRules {
buffer.WriteString(rule)
}
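// Finally, emit the load-folding rules for memory operands.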
for _, data := range memOptData {
if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.Asm, err))
}
}
return buffer
}