| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package main |
| |
| import ( |
| "bytes" |
| "fmt" |
| "slices" |
| "text/template" |
| ) |
| |
// tplRuleData holds the values substituted into one rewrite-rule template.
// One value renders exactly one rule line.
type tplRuleData struct {
	// tplName is the name of the template used to render this rule and the
	// final tie-breaker when rules are sorted, e.g. "sftimm".
	tplName string

	// Generic (Go-level) operation that the lowering templates match on.
	GoOp   string // e.g. "ShiftAllLeft"
	GoType string // e.g. "Uint32x8"
	Args   string // e.g. "x y"

	// Machine operation the rule produces, and the arguments handed to it.
	Asm     string // e.g. "VPSLLD256"
	ArgsOut string // e.g. "x y"

	// Conversion ops wrapped around a mask operand / a mask result.
	MaskInConvert  string // e.g. "VPMOVVec32x8ToM"
	MaskOutConvert string // e.g. "VPMOVMToVec32x8"
}
| |
// ruleTemplates is the set of named templates, each rendering one rule line
// from a tplRuleData value:
//
//	pureVreg       generic op => machine op, no masks involved
//	maskIn         generic op with a trailing mask => machine op with the mask converted by MaskInConvert
//	maskOut        generic op => machine op whose result is wrapped in MaskOutConvert
//	maskInMaskOut  both of the above
//	sftimm         machine op with a MOVQconst second operand => its "const" form with a uint8 immediate
//	masksftimm     same as sftimm, with a trailing mask carried through
//
// Parsing happens once at package initialization and panics if the template
// text is malformed.
var ruleTemplates = func() *template.Template {
	root := template.New("simdRules")
	return template.Must(root.Parse(`
{{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
{{end}}
{{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
{{end}}
{{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
{{end}}
{{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
{{end}}
{{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
{{end}}
{{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
{{end}}
`))
}()
| |
// tmplOrder ranks template names for sorting: SSA rewrite rules must appear
// from most to least specific, and a lower rank sorts earlier. The rank of a
// template is its position in the list below, so adding a new template means
// inserting its name at the right spot.
var tmplOrder = func() map[string]int {
	mostToLeastSpecific := []string{
		"masksftimm",
		"sftimm",
		"maskInMaskOut",
		"maskOut",
		"maskIn",
		"pureVreg",
	}
	ranks := make(map[string]int, len(mostToLeastSpecific))
	for rank, name := range mostToLeastSpecific {
		ranks[name] = rank
	}
	return ranks
}()
| |
| func compareTplRuleData(x, y tplRuleData) int { |
| if c := compareNatural(x.GoOp, y.GoOp); c != 0 { |
| return c |
| } |
| if c := compareNatural(x.GoType, y.GoType); c != 0 { |
| return c |
| } |
| if c := compareNatural(x.Args, y.Args); c != 0 { |
| return c |
| } |
| if x.tplName == y.tplName { |
| return 0 |
| } |
| xo, xok := tmplOrder[x.tplName] |
| yo, yok := tmplOrder[y.tplName] |
| if !xok { |
| panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", x.tplName)) |
| } |
| if !yok { |
| panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", y.tplName)) |
| } |
| return xo - yo |
| } |
| |
| // writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules |
| // within the specified directory. |
| func writeSIMDRules(ops []Operation) *bytes.Buffer { |
| buffer := new(bytes.Buffer) |
| buffer.WriteString(generatedHeader + "\n") |
| |
| var allData []tplRuleData |
| |
| for _, opr := range ops { |
| if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" { |
| continue |
| } |
| opInShape, opOutShape, maskType, immType, gOp := opr.shape() |
| asm := machineOpName(maskType, gOp) |
| vregInCnt := len(gOp.In) |
| if maskType == OneMask { |
| vregInCnt-- |
| } |
| |
| data := tplRuleData{ |
| GoOp: gOp.Go, |
| Asm: asm, |
| } |
| |
| if vregInCnt == 1 { |
| data.Args = "x" |
| data.ArgsOut = data.Args |
| } else if vregInCnt == 2 { |
| data.Args = "x y" |
| data.ArgsOut = data.Args |
| } else if vregInCnt == 3 { |
| data.Args = "x y z" |
| data.ArgsOut = data.Args |
| } else { |
| panic(fmt.Errorf("simdgen does not support more than 3 vreg in inputs")) |
| } |
| if immType == ConstImm { |
| data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut) |
| } else if immType == VarImm { |
| data.Args = fmt.Sprintf("[a] %s", data.Args) |
| data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut) |
| } else if immType == ConstVarImm { |
| data.Args = fmt.Sprintf("[a] %s", data.Args) |
| data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut) |
| } |
| |
| goType := func(op Operation) string { |
| if op.OperandOrder != nil { |
| switch *op.OperandOrder { |
| case "21Type1", "231Type1": |
| // Permute uses operand[1] for method receiver. |
| return *op.In[1].Go |
| } |
| } |
| return *op.In[0].Go |
| } |
| var tplName string |
| // If class overwrite is happening, that's not really a mask but a vreg. |
| if opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil { |
| switch opInShape { |
| case OneImmIn: |
| tplName = "pureVreg" |
| data.GoType = goType(gOp) |
| case PureVregIn: |
| tplName = "pureVreg" |
| data.GoType = goType(gOp) |
| case OneKmaskImmIn: |
| fallthrough |
| case OneKmaskIn: |
| tplName = "maskIn" |
| data.GoType = goType(gOp) |
| rearIdx := len(gOp.In) - 1 |
| // Mask is at the end. |
| data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes) |
| case PureKmaskIn: |
| panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations")) |
| } |
| } else if opOutShape == OneGregOut { |
| tplName = "pureVreg" // TODO this will be wrong |
| data.GoType = goType(gOp) |
| } else { |
| // OneKmaskOut case |
| data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes) |
| switch opInShape { |
| case OneImmIn: |
| fallthrough |
| case PureVregIn: |
| tplName = "maskOut" |
| data.GoType = goType(gOp) |
| case OneKmaskImmIn: |
| fallthrough |
| case OneKmaskIn: |
| tplName = "maskInMaskOut" |
| data.GoType = goType(gOp) |
| rearIdx := len(gOp.In) - 1 |
| data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes) |
| case PureKmaskIn: |
| panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations")) |
| } |
| } |
| |
| if gOp.SpecialLower != nil { |
| if *gOp.SpecialLower == "sftimm" { |
| if data.GoType[0] == 'I' { |
| // only do these for signed types, it is a duplicate rewrite for unsigned |
| sftImmData := data |
| if tplName == "maskIn" { |
| sftImmData.tplName = "masksftimm" |
| } else { |
| sftImmData.tplName = "sftimm" |
| } |
| allData = append(allData, sftImmData) |
| } |
| } else { |
| panic("simdgen sees unknwon special lower " + *gOp.SpecialLower + ", maybe implement it?") |
| } |
| } |
| |
| if tplName == "pureVreg" && data.Args == data.ArgsOut { |
| data.Args = "..." |
| data.ArgsOut = "..." |
| } |
| data.tplName = tplName |
| allData = append(allData, data) |
| } |
| |
| slices.SortFunc(allData, compareTplRuleData) |
| |
| for _, data := range allData { |
| if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil { |
| panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err)) |
| } |
| } |
| |
| return buffer |
| } |