internal/simdgen: simplify gen_simdrules.go This gets the control flow out of the templates, simplifies the templates, and allows better sorting of the generated rules. Change-Id: Ic31f2554bf3d2aaf1d3efd27a8a5060c8904767f Reviewed-on: https://go-review.googlesource.com/c/arch/+/680275 Reviewed-by: Junyang Shao <shaojunyang@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/simdgen/gen_simdrules.go b/internal/simdgen/gen_simdrules.go index 172282e..2f2178d 100644 --- a/internal/simdgen/gen_simdrules.go +++ b/internal/simdgen/gen_simdrules.go
@@ -6,10 +6,76 @@ import ( "fmt" - "sort" + "io" + "os" + "path/filepath" + "slices" + "strings" + "text/template" ) -const simdrulesTmpl = `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT. +var ( + ruleTemplates = template.Must(template.New("simdRules").Parse(` +{{define "pureVregInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ReverseArgs}}) +{{end}} +{{define "oneKmaskInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)) +{{end}} +{{define "oneConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}}) +{{end}} +{{define "oneKmaskConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)) +{{end}} +{{define "pureVregInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}})) +{{end}} +{{define "oneKmaskInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))) +{{end}} +{{define "oneConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}})) +{{end}} +{{define "oneKmaskConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))) +{{end}} +`)) +) + +type tplRuleData struct { + tplName string + GoOp string + GoType string + Args string + Asm string + ReverseArgs string + ElemBits int + Lanes int + Const string +} + +func compareTplRuleData(x, y tplRuleData) int { + // TODO should MaskedXYZ compare just after XYZ? + if c := strings.Compare(x.GoOp, y.GoOp); c != 0 { + return c + } + if c := strings.Compare(x.GoType, y.GoType); c != 0 { + return c + } + if c := strings.Compare(x.Const, y.Const); c != 0 { + return c + } + return 0 +} + +// writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules +// within the specified directory. +func writeSIMDRules(directory string, ops []Operation) error { + + outPath := filepath.Join(directory, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules") + if err := os.MkdirAll(filepath.Dir(outPath), 0755); err != nil { + return fmt.Errorf("failed to create directory for %s: %w", outPath, err) + } + file, err := os.Create(outPath) + if err != nil { + return fmt.Errorf("failed to create %s: %w", outPath, err) + } + defer file.Close() + + header := `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT. // The AVX instruction encodings orders vector register from right to left, for example: // VSUBPS X Y Z means Z=Y-X @@ -17,88 +83,63 @@ // left to right order. // TODO: we should offload the logic to simdssa.go, instead of here. // -// Masks are always at the end, immediates always at the beginning. -{{- range .Ops }} -({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}}) => ({{.Op.Asm}} {{.ReverseArgs}}) -{{- end }} -{{- range .OpsImm }} -({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}}) => ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}}) -{{- end }} -{{- range .OpsMask}} -({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}} mask) => ({{.Op.Asm}} {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}}ToM <types.TypeMask> mask)) -{{- end }} -{{- range .OpsImmMask}} -({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}} mask) => ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}}ToM <types.TypeMask> mask)) -{{- end }} -{{- range .OpsMaskOut}} -({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}}) => (VPMOVMToVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}} ({{.Op.Asm}} {{.ReverseArgs}})) -{{- end }} -{{- range .OpsImmInMaskOut}} -({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}}) => (VPMOVMToVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}} ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}})) -{{- end }} -{{- range .OpsMaskInMaskOut}} -({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}} mask) => (VPMOVMToVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}} ({{.Op.Asm}} {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}}ToM <types.TypeMask> mask))) -{{- end }} -{{- range .OpsImmMaskInMaskOut}} -({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}} mask) => (VPMOVMToVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}} ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}}ToM <types.TypeMask> mask))) -{{- end }} ` - -// writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules -// within the specified directory. -func writeSIMDRules(directory string, ops []Operation) error { - file, t, err := openFileAndPrepareTemplate(directory, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules", simdrulesTmpl) - if err != nil { - return err + if _, err := io.WriteString(file, header); err != nil { + return fmt.Errorf("failed to write header to %s: %w", outPath, err) } - defer file.Close() - type OpAndArgList struct { - Op Operation - Args string // "x y", does not include masks - ReverseArgs string // "y x", does not include masks - } - Ops := make([]OpAndArgList, 0) - OpsImm := make([]OpAndArgList, 0) - OpsMask := make([]OpAndArgList, 0) - OpsImmMask := make([]OpAndArgList, 0) - OpsMaskOut := make([]OpAndArgList, 0) - OpsImmInMaskOut := make([]OpAndArgList, 0) - OpsMaskInMaskOut := make([]OpAndArgList, 0) - OpsImmMaskInMaskOut := make([]OpAndArgList, 0) - for _, op := range ops { - opInShape, opOutShape, maskType, _, op, gOp, err := op.shape() + var allData []tplRuleData + + for _, opr := range ops { + opInShape, opOutShape, maskType, _, o, gOp, err := opr.shape() if err != nil { return err } vregInCnt := len(gOp.In) if maskType == OneMask { - op.Asm += "Masked" + o.Asm += "Masked" vregInCnt-- } - op.Asm = fmt.Sprintf("%s%d", op.Asm, *op.Out[0].Bits) - opData := OpAndArgList{Op: op} + o.Asm = fmt.Sprintf("%s%d", o.Asm, *o.Out[0].Bits) + + data := tplRuleData{ + GoOp: o.Go, + Asm: o.Asm, + } + if vregInCnt == 1 { - opData.Args = "x" - opData.ReverseArgs = "x" + data.Args = "x" + data.ReverseArgs = "x" } else if vregInCnt == 2 { - opData.Args = "x y" - opData.ReverseArgs = "y x" + data.Args = "x y" + data.ReverseArgs = "y x" } else { return fmt.Errorf("simdgen does not support more than 2 vreg in inputs") } + + var tplName string // If class overwrite is happening, that's not really a mask but a vreg. - if opOutShape == OneVregOut || op.Out[0].OverwriteClass != nil { + if opOutShape == OneVregOut || o.Out[0].OverwriteClass != nil { switch opInShape { case PureVregIn: - Ops = append(Ops, opData) + tplName = "pureVregInVregOut" + data.GoType = *o.In[0].Go case OneKmaskIn: - OpsMask = append(OpsMask, opData) + tplName = "oneKmaskInVregOut" + data.GoType = *o.In[0].Go + data.ElemBits = *o.In[0].ElemBits + data.Lanes = *o.In[0].Lanes case OneConstImmIn: - OpsImm = append(OpsImm, opData) + tplName = "oneConstImmInVregOut" + data.GoType = *o.In[1].Go + data.Const = *o.In[0].Const case OneKmaskConstImmIn: - OpsImmMask = append(OpsImmMask, opData) + tplName = "oneKmaskConstImmInVregOut" + data.GoType = *o.In[1].Go + data.Const = *o.In[0].Const + data.ElemBits = *o.In[1].ElemBits + data.Lanes = *o.In[1].Lanes case PureKmaskIn: return fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations") } @@ -106,57 +147,42 @@ // OneKmaskOut case switch opInShape { case PureVregIn: - OpsMaskOut = append(OpsMaskOut, opData) + tplName = "pureVregInKmaskOut" + data.GoType = *o.In[0].Go + data.ElemBits = *o.In[0].ElemBits + data.Lanes = *o.In[0].Lanes case OneKmaskIn: - OpsMaskInMaskOut = append(OpsMaskInMaskOut, opData) + tplName = "oneKmaskInKmaskOut" + data.GoType = *o.In[0].Go + data.ElemBits = *o.In[0].ElemBits + data.Lanes = *o.In[0].Lanes case OneConstImmIn: - OpsImmInMaskOut = append(OpsImmInMaskOut, opData) + tplName = "oneConstImmInKmaskOut" + data.GoType = *o.In[1].Go + data.Const = *o.In[0].Const + data.ElemBits = *o.In[1].ElemBits + data.Lanes = *o.In[1].Lanes case OneKmaskConstImmIn: - OpsImmMaskInMaskOut = append(OpsImmMaskInMaskOut, opData) + tplName = "oneKmaskConstImmInKmaskOut" + data.GoType = *o.In[1].Go + data.Const = *o.In[0].Const + data.ElemBits = *o.In[1].ElemBits + data.Lanes = *o.In[1].Lanes case PureKmaskIn: return fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations") } } - } - sortKey := func(op *OpAndArgList) string { - return *op.Op.In[0].Go + op.Op.Go - } - sortBySortKey := func(ops []OpAndArgList) { - sort.Slice(ops, func(i, j int) bool { - return sortKey(&ops[i]) < sortKey(&ops[j]) - }) - } - sortBySortKey(Ops) - sortBySortKey(OpsImm) - sortBySortKey(OpsMask) - sortBySortKey(OpsImmMask) - sortBySortKey(OpsMaskOut) - sortBySortKey(OpsImmInMaskOut) - sortBySortKey(OpsMaskInMaskOut) - sortBySortKey(OpsImmMaskInMaskOut) - type templateData struct { - Ops []OpAndArgList - OpsImm []OpAndArgList - OpsMask []OpAndArgList - OpsImmMask []OpAndArgList - OpsMaskOut []OpAndArgList - OpsImmInMaskOut []OpAndArgList - OpsMaskInMaskOut []OpAndArgList - OpsImmMaskInMaskOut []OpAndArgList + data.tplName = tplName + allData = append(allData, data) } - err = t.Execute(file, templateData{ - Ops, - OpsImm, - OpsMask, - OpsImmMask, - OpsMaskOut, - OpsImmInMaskOut, - OpsMaskInMaskOut, - OpsImmMaskInMaskOut}) - if err != nil { - return fmt.Errorf("failed to execute template: %w", err) + slices.SortFunc(allData, compareTplRuleData) + + for _, data := range allData { + if err := ruleTemplates.ExecuteTemplate(file, data.tplName, data); err != nil { + return fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err) + } } return nil