blob: b0fc7e62cde1f78bd9ddef43591cf0d2a542bb39 [file]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"slices"
"text/template"
)
// tplRuleData carries the values substituted into one rule template from
// ruleTemplates. Each instance renders to a single rewrite rule.
type tplRuleData struct {
// tplName is the name of the template used to render this rule,
// e.g. "sftimm". It also selects the rule's rank in tmplOrder when sorting.
tplName string
// GoOp is the generic operation name, e.g. "ShiftAllLeft". The templates
// concatenate it with GoType to form the op on the left-hand side.
GoOp string
// GoType is the Go vector type name appended to GoOp, e.g. "Uint32x8".
GoType string
// Args is the argument list matched on the left-hand side, e.g. "x y".
Args string
// Asm is the machine op name emitted on the right-hand side, e.g. "VPSLLD256".
Asm string
// ArgsOut is the argument list passed to Asm on the right-hand side, e.g. "x y".
ArgsOut string
// MaskInConvert is the op the mask templates wrap around the trailing mask
// argument, e.g. "VPMOVVec32x8ToM".
MaskInConvert string
// MaskOutConvert is the op the mask-out templates wrap around the whole
// Asm expression, e.g. "VPMOVMToVec32x8".
MaskOutConvert string
}
// ruleTemplates is the parsed set of named rule templates. Each template is
// executed by name with a value exposing the tplRuleData fields and renders
// one rewrite rule followed by a newline:
//
//   - pureVreg:      generic op => machine op
//   - maskIn:        as pureVreg, with a trailing mask converted by MaskInConvert
//   - maskOut:       as pureVreg, with the result wrapped in MaskOutConvert
//   - maskInMaskOut: both of the above
//   - sftimm:        machine op with a MOVQconst second argument => its "const" form
//   - masksftimm:    as sftimm, keeping the trailing mask argument
//
// Parsing happens during package initialization; a malformed template panics
// via template.Must.
var ruleTemplates = func() *template.Template {
	const rulesText = `
{{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
{{end}}
{{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
{{end}}
{{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
{{end}}
{{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
{{end}}
{{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
{{end}}
{{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
{{end}}
`
	return template.Must(template.New("simdRules").Parse(rulesText))
}()
// tmplOrder maps each rule template name to its sort rank; a lower rank sorts
// earlier. SSA rewrite rules need to appear in most-to-least-specific order,
// so the more specific templates receive the lower ranks.
var tmplOrder = func() map[string]int {
	// Listed from most specific to least specific; the index is the rank.
	mostToLeastSpecific := [...]string{
		"masksftimm",
		"sftimm",
		"maskInMaskOut",
		"maskOut",
		"maskIn",
		"pureVreg",
	}
	ranks := make(map[string]int, len(mostToLeastSpecific))
	for rank, name := range mostToLeastSpecific {
		ranks[name] = rank
	}
	return ranks
}()
// compareTplRuleData is the sort comparison for rule data. It compares GoOp,
// then GoType, then Args using compareNatural, returning the first non-zero
// result. If all three are equal, it orders by the rank of the template name
// in tmplOrder, so a lower-ranked template sorts first.
//
// It panics if the two template names differ and either one is missing from
// tmplOrder.
func compareTplRuleData(x, y tplRuleData) int {
	fieldPairs := [...][2]string{
		{x.GoOp, y.GoOp},
		{x.GoType, y.GoType},
		{x.Args, y.Args},
	}
	for _, pair := range fieldPairs {
		if diff := compareNatural(pair[0], pair[1]); diff != 0 {
			return diff
		}
	}

	// Identical template names need no rank lookup (and no validation).
	if x.tplName == y.tplName {
		return 0
	}

	// rank returns the position of name in tmplOrder, panicking on an
	// unregistered template so the omission is noticed immediately.
	rank := func(name string) int {
		pos, known := tmplOrder[name]
		if !known {
			panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", name))
		}
		return pos
	}
	xRank := rank(x.tplName)
	yRank := rank(y.tplName)
	return xRank - yRank
}
// writeSIMDRules renders the SSA lowering and rewrite rules for ops and
// returns them in a buffer that starts with generatedHeader. Writing the
// result to a file is left to the caller.
//
// Operations whose NoGenericOps is "true" are skipped. Every other operation
// yields one rule, plus an extra shift-immediate rule when its SpecialLower is
// "sftimm" and its Go type name starts with 'I'. The rules are sorted with
// compareTplRuleData before being rendered through ruleTemplates.
//
// It panics when an operation cannot be expressed: a vreg input count outside
// 1..3, a pure k-mask input shape, an input shape with no matching template,
// an unknown SpecialLower value, or a template execution failure.
func writeSIMDRules(ops []Operation) *bytes.Buffer {
	buffer := new(bytes.Buffer)
	buffer.WriteString(generatedHeader + "\n")

	// goType picks the Go type name that the generic op is suffixed with.
	goType := func(op Operation) string {
		if op.OperandOrder != nil {
			switch *op.OperandOrder {
			case "21Type1", "231Type1":
				// Permute uses operand[1] for method receiver.
				return *op.In[1].Go
			}
		}
		return *op.In[0].Go
	}

	// maskInConvert names the conversion applied to the mask input, which is
	// the last input of op.
	maskInConvert := func(op Operation) string {
		rearIdx := len(op.In) - 1
		return fmt.Sprintf("VPMOVVec%dx%dToM", *op.In[rearIdx].ElemBits, *op.In[rearIdx].Lanes)
	}

	const pureKmaskMsg = "simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"

	allData := make([]tplRuleData, 0, len(ops))
	for _, opr := range ops {
		if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" {
			continue
		}
		opInShape, opOutShape, maskType, immType, gOp := opr.shape()
		asm := machineOpName(maskType, gOp)

		// The mask, when present, is counted in gOp.In but is not a vreg.
		vregInCnt := len(gOp.In)
		if maskType == OneMask {
			vregInCnt--
		}

		data := tplRuleData{
			GoOp: gOp.Go,
			Asm:  asm,
		}
		switch vregInCnt {
		case 1:
			data.Args = "x"
		case 2:
			data.Args = "x y"
		case 3:
			data.Args = "x y z"
		default:
			panic(fmt.Errorf("simdgen supports 1 to 3 vreg inputs, but %s has %d", gOp.Go, vregInCnt))
		}
		data.ArgsOut = data.Args

		// Immediates are written as a bracketed prefix of the argument list.
		switch immType {
		case ConstImm:
			data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut)
		case VarImm:
			data.Args = fmt.Sprintf("[a] %s", data.Args)
			data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut)
		case ConstVarImm:
			data.Args = fmt.Sprintf("[a] %s", data.Args)
			data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut)
		}

		var tplName string
		switch {
		// If class overwrite is happening, that's not really a mask but a vreg.
		case opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil:
			switch opInShape {
			case OneImmIn, PureVregIn:
				tplName = "pureVreg"
				data.GoType = goType(gOp)
			case OneKmaskImmIn, OneKmaskIn:
				tplName = "maskIn"
				data.GoType = goType(gOp)
				data.MaskInConvert = maskInConvert(gOp)
			case PureKmaskIn:
				panic(fmt.Errorf(pureKmaskMsg))
			}
		case opOutShape == OneGregOut:
			tplName = "pureVreg" // TODO this will be wrong
			data.GoType = goType(gOp)
		default:
			// OneKmaskOut case
			// NOTE(review): element width comes from the output but the lane
			// count from the first input — confirm this is intentional.
			data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes)
			switch opInShape {
			case OneImmIn, PureVregIn:
				tplName = "maskOut"
				data.GoType = goType(gOp)
			case OneKmaskImmIn, OneKmaskIn:
				tplName = "maskInMaskOut"
				data.GoType = goType(gOp)
				data.MaskInConvert = maskInConvert(gOp)
			case PureKmaskIn:
				panic(fmt.Errorf(pureKmaskMsg))
			}
		}
		if tplName == "" {
			// No case above matched. Without a template the rule could never
			// be rendered, so fail here where the offending op is known.
			panic(fmt.Errorf("simdgen has no rule template for %s (input shape %v, output shape %v)", gOp.Go, opInShape, opOutShape))
		}

		if gOp.SpecialLower != nil {
			if *gOp.SpecialLower != "sftimm" {
				panic("simdgen sees unknown special lower " + *gOp.SpecialLower + ", maybe implement it?")
			}
			if data.GoType[0] == 'I' {
				// only do these for signed types, it is a duplicate rewrite for unsigned
				sftImmData := data
				if tplName == "maskIn" {
					sftImmData.tplName = "masksftimm"
				} else {
					sftImmData.tplName = "sftimm"
				}
				allData = append(allData, sftImmData)
			}
		}

		// A pureVreg rule that passes its arguments through unchanged is
		// written with "..." on both sides.
		if tplName == "pureVreg" && data.Args == data.ArgsOut {
			data.Args = "..."
			data.ArgsOut = "..."
		}
		data.tplName = tplName
		allData = append(allData, data)
	}

	slices.SortFunc(allData, compareTplRuleData)

	for _, data := range allData {
		if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
			panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err))
		}
	}

	return buffer
}