internal/simdgen: simplify gen_simdrules.go
This gets the control flow out of the templates,
simplifies the templates, and allows better sorting
of the generated rules.
Change-Id: Ic31f2554bf3d2aaf1d3efd27a8a5060c8904767f
Reviewed-on: https://go-review.googlesource.com/c/arch/+/680275
Reviewed-by: Junyang Shao <shaojunyang@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/simdgen/gen_simdrules.go b/internal/simdgen/gen_simdrules.go
index 172282e..2f2178d 100644
--- a/internal/simdgen/gen_simdrules.go
+++ b/internal/simdgen/gen_simdrules.go
@@ -6,10 +6,76 @@
import (
"fmt"
- "sort"
+ "io"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+ "text/template"
)
-const simdrulesTmpl = `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
+var (
+ ruleTemplates = template.Must(template.New("simdRules").Parse(`
+{{define "pureVregInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ReverseArgs}})
+{{end}}
+{{define "oneKmaskInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))
+{{end}}
+{{define "oneConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}})
+{{end}}
+{{define "oneKmaskConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))
+{{end}}
+{{define "pureVregInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}}))
+{{end}}
+{{define "oneKmaskInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)))
+{{end}}
+{{define "oneConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}}))
+{{end}}
+{{define "oneKmaskConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)))
+{{end}}
+`))
+)
+
+type tplRuleData struct {
+ tplName string
+ GoOp string
+ GoType string
+ Args string
+ Asm string
+ ReverseArgs string
+ ElemBits int
+ Lanes int
+ Const string
+}
+
+func compareTplRuleData(x, y tplRuleData) int {
+ // TODO should MaskedXYZ compare just after XYZ?
+ if c := strings.Compare(x.GoOp, y.GoOp); c != 0 {
+ return c
+ }
+ if c := strings.Compare(x.GoType, y.GoType); c != 0 {
+ return c
+ }
+ if c := strings.Compare(x.Const, y.Const); c != 0 {
+ return c
+ }
+ return 0
+}
+
+// writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules
+// within the specified directory.
+func writeSIMDRules(directory string, ops []Operation) error {
+
+ outPath := filepath.Join(directory, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules")
+ if err := os.MkdirAll(filepath.Dir(outPath), 0755); err != nil {
+ return fmt.Errorf("failed to create directory for %s: %w", outPath, err)
+ }
+ file, err := os.Create(outPath)
+ if err != nil {
+ return fmt.Errorf("failed to create %s: %w", outPath, err)
+ }
+ defer file.Close()
+
+ header := `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
// The AVX instruction encodings orders vector register from right to left, for example:
// VSUBPS X Y Z means Z=Y-X
@@ -17,88 +83,63 @@
// left to right order.
// TODO: we should offload the logic to simdssa.go, instead of here.
//
-// Masks are always at the end, immediates always at the beginning.
-{{- range .Ops }}
-({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}}) => ({{.Op.Asm}} {{.ReverseArgs}})
-{{- end }}
-{{- range .OpsImm }}
-({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}}) => ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}})
-{{- end }}
-{{- range .OpsMask}}
-({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}} mask) => ({{.Op.Asm}} {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}}ToM <types.TypeMask> mask))
-{{- end }}
-{{- range .OpsImmMask}}
-({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}} mask) => ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}}ToM <types.TypeMask> mask))
-{{- end }}
-{{- range .OpsMaskOut}}
-({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}}) => (VPMOVMToVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}} ({{.Op.Asm}} {{.ReverseArgs}}))
-{{- end }}
-{{- range .OpsImmInMaskOut}}
-({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}}) => (VPMOVMToVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}} ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}}))
-{{- end }}
-{{- range .OpsMaskInMaskOut}}
-({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}} mask) => (VPMOVMToVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}} ({{.Op.Asm}} {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}}ToM <types.TypeMask> mask)))
-{{- end }}
-{{- range .OpsImmMaskInMaskOut}}
-({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}} mask) => (VPMOVMToVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}} ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}}ToM <types.TypeMask> mask)))
-{{- end }}
`
-
-// writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules
-// within the specified directory.
-func writeSIMDRules(directory string, ops []Operation) error {
- file, t, err := openFileAndPrepareTemplate(directory, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules", simdrulesTmpl)
- if err != nil {
- return err
+ if _, err := io.WriteString(file, header); err != nil {
+ return fmt.Errorf("failed to write header to %s: %w", outPath, err)
}
- defer file.Close()
- type OpAndArgList struct {
- Op Operation
- Args string // "x y", does not include masks
- ReverseArgs string // "y x", does not include masks
- }
- Ops := make([]OpAndArgList, 0)
- OpsImm := make([]OpAndArgList, 0)
- OpsMask := make([]OpAndArgList, 0)
- OpsImmMask := make([]OpAndArgList, 0)
- OpsMaskOut := make([]OpAndArgList, 0)
- OpsImmInMaskOut := make([]OpAndArgList, 0)
- OpsMaskInMaskOut := make([]OpAndArgList, 0)
- OpsImmMaskInMaskOut := make([]OpAndArgList, 0)
- for _, op := range ops {
- opInShape, opOutShape, maskType, _, op, gOp, err := op.shape()
+ var allData []tplRuleData
+
+ for _, opr := range ops {
+ opInShape, opOutShape, maskType, _, o, gOp, err := opr.shape()
if err != nil {
return err
}
vregInCnt := len(gOp.In)
if maskType == OneMask {
- op.Asm += "Masked"
+ o.Asm += "Masked"
vregInCnt--
}
- op.Asm = fmt.Sprintf("%s%d", op.Asm, *op.Out[0].Bits)
- opData := OpAndArgList{Op: op}
+ o.Asm = fmt.Sprintf("%s%d", o.Asm, *o.Out[0].Bits)
+
+ data := tplRuleData{
+ GoOp: o.Go,
+ Asm: o.Asm,
+ }
+
if vregInCnt == 1 {
- opData.Args = "x"
- opData.ReverseArgs = "x"
+ data.Args = "x"
+ data.ReverseArgs = "x"
} else if vregInCnt == 2 {
- opData.Args = "x y"
- opData.ReverseArgs = "y x"
+ data.Args = "x y"
+ data.ReverseArgs = "y x"
} else {
return fmt.Errorf("simdgen does not support more than 2 vreg in inputs")
}
+
+ var tplName string
// If class overwrite is happening, that's not really a mask but a vreg.
- if opOutShape == OneVregOut || op.Out[0].OverwriteClass != nil {
+ if opOutShape == OneVregOut || o.Out[0].OverwriteClass != nil {
switch opInShape {
case PureVregIn:
- Ops = append(Ops, opData)
+ tplName = "pureVregInVregOut"
+ data.GoType = *o.In[0].Go
case OneKmaskIn:
- OpsMask = append(OpsMask, opData)
+ tplName = "oneKmaskInVregOut"
+ data.GoType = *o.In[0].Go
+ data.ElemBits = *o.In[0].ElemBits
+ data.Lanes = *o.In[0].Lanes
case OneConstImmIn:
- OpsImm = append(OpsImm, opData)
+ tplName = "oneConstImmInVregOut"
+ data.GoType = *o.In[1].Go
+ data.Const = *o.In[0].Const
case OneKmaskConstImmIn:
- OpsImmMask = append(OpsImmMask, opData)
+ tplName = "oneKmaskConstImmInVregOut"
+ data.GoType = *o.In[1].Go
+ data.Const = *o.In[0].Const
+ data.ElemBits = *o.In[1].ElemBits
+ data.Lanes = *o.In[1].Lanes
case PureKmaskIn:
return fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations")
}
@@ -106,57 +147,42 @@
// OneKmaskOut case
switch opInShape {
case PureVregIn:
- OpsMaskOut = append(OpsMaskOut, opData)
+ tplName = "pureVregInKmaskOut"
+ data.GoType = *o.In[0].Go
+ data.ElemBits = *o.In[0].ElemBits
+ data.Lanes = *o.In[0].Lanes
case OneKmaskIn:
- OpsMaskInMaskOut = append(OpsMaskInMaskOut, opData)
+ tplName = "oneKmaskInKmaskOut"
+ data.GoType = *o.In[0].Go
+ data.ElemBits = *o.In[0].ElemBits
+ data.Lanes = *o.In[0].Lanes
case OneConstImmIn:
- OpsImmInMaskOut = append(OpsImmInMaskOut, opData)
+ tplName = "oneConstImmInKmaskOut"
+ data.GoType = *o.In[1].Go
+ data.Const = *o.In[0].Const
+ data.ElemBits = *o.In[1].ElemBits
+ data.Lanes = *o.In[1].Lanes
case OneKmaskConstImmIn:
- OpsImmMaskInMaskOut = append(OpsImmMaskInMaskOut, opData)
+ tplName = "oneKmaskConstImmInKmaskOut"
+ data.GoType = *o.In[1].Go
+ data.Const = *o.In[0].Const
+ data.ElemBits = *o.In[1].ElemBits
+ data.Lanes = *o.In[1].Lanes
case PureKmaskIn:
return fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations")
}
}
- }
- sortKey := func(op *OpAndArgList) string {
- return *op.Op.In[0].Go + op.Op.Go
- }
- sortBySortKey := func(ops []OpAndArgList) {
- sort.Slice(ops, func(i, j int) bool {
- return sortKey(&ops[i]) < sortKey(&ops[j])
- })
- }
- sortBySortKey(Ops)
- sortBySortKey(OpsImm)
- sortBySortKey(OpsMask)
- sortBySortKey(OpsImmMask)
- sortBySortKey(OpsMaskOut)
- sortBySortKey(OpsImmInMaskOut)
- sortBySortKey(OpsMaskInMaskOut)
- sortBySortKey(OpsImmMaskInMaskOut)
- type templateData struct {
- Ops []OpAndArgList
- OpsImm []OpAndArgList
- OpsMask []OpAndArgList
- OpsImmMask []OpAndArgList
- OpsMaskOut []OpAndArgList
- OpsImmInMaskOut []OpAndArgList
- OpsMaskInMaskOut []OpAndArgList
- OpsImmMaskInMaskOut []OpAndArgList
+ data.tplName = tplName
+ allData = append(allData, data)
}
- err = t.Execute(file, templateData{
- Ops,
- OpsImm,
- OpsMask,
- OpsImmMask,
- OpsMaskOut,
- OpsImmInMaskOut,
- OpsMaskInMaskOut,
- OpsImmMaskInMaskOut})
- if err != nil {
- return fmt.Errorf("failed to execute template: %w", err)
+ slices.SortFunc(allData, compareTplRuleData)
+
+ for _, data := range allData {
+ if err := ruleTemplates.ExecuteTemplate(file, data.tplName, data); err != nil {
+ return fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err)
+ }
}
return nil