internal/simdgen: simplify gen_simdrules.go

This gets the control flow out of the templates,
simplifies the templates, and allows better sorting
of the generated rules.

Change-Id: Ic31f2554bf3d2aaf1d3efd27a8a5060c8904767f
Reviewed-on: https://go-review.googlesource.com/c/arch/+/680275
Reviewed-by: Junyang Shao <shaojunyang@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/simdgen/gen_simdrules.go b/internal/simdgen/gen_simdrules.go
index 172282e..2f2178d 100644
--- a/internal/simdgen/gen_simdrules.go
+++ b/internal/simdgen/gen_simdrules.go
@@ -6,10 +6,76 @@
 
 import (
 	"fmt"
-	"sort"
+	"io"
+	"os"
+	"path/filepath"
+	"slices"
+	"strings"
+	"text/template"
 )
 
-const simdrulesTmpl = `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
+var (
+	ruleTemplates = template.Must(template.New("simdRules").Parse(`
+{{define "pureVregInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ReverseArgs}})
+{{end}}
+{{define "oneKmaskInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))
+{{end}}
+{{define "oneConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}})
+{{end}}
+{{define "oneKmaskConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))
+{{end}}
+{{define "pureVregInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}}))
+{{end}}
+{{define "oneKmaskInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)))
+{{end}}
+{{define "oneConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}}))
+{{end}}
+{{define "oneKmaskConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)))
+{{end}}
+`))
+)
+
+type tplRuleData struct {
+	tplName     string
+	GoOp        string
+	GoType      string
+	Args        string
+	Asm         string
+	ReverseArgs string
+	ElemBits    int
+	Lanes       int
+	Const       string
+}
+
+func compareTplRuleData(x, y tplRuleData) int {
+	// TODO should MaskedXYZ compare just after XYZ?
+	if c := strings.Compare(x.GoOp, y.GoOp); c != 0 {
+		return c
+	}
+	if c := strings.Compare(x.GoType, y.GoType); c != 0 {
+		return c
+	}
+	if c := strings.Compare(x.Const, y.Const); c != 0 {
+		return c
+	}
+	return 0
+}
+
+// writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules
+// within the specified directory.
+func writeSIMDRules(directory string, ops []Operation) error {
+
+	outPath := filepath.Join(directory, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules")
+	if err := os.MkdirAll(filepath.Dir(outPath), 0755); err != nil {
+		return fmt.Errorf("failed to create directory for %s: %w", outPath, err)
+	}
+	file, err := os.Create(outPath)
+	if err != nil {
+		return fmt.Errorf("failed to create %s: %w", outPath, err)
+	}
+	defer file.Close()
+
+	header := `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
 
 // The AVX instruction encodings orders vector register from right to left, for example:
 // VSUBPS X Y Z means Z=Y-X
@@ -17,88 +83,63 @@
 // left to right order.
 // TODO: we should offload the logic to simdssa.go, instead of here.
 //
-// Masks are always at the end, immediates always at the beginning.
 
-{{- range .Ops }}
-({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}}) => ({{.Op.Asm}} {{.ReverseArgs}})
-{{- end }}
-{{- range .OpsImm }}
-({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}}) => ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}})
-{{- end }}
-{{- range .OpsMask}}
-({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}} mask) => ({{.Op.Asm}} {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}}ToM <types.TypeMask> mask))
-{{- end }}
-{{- range .OpsImmMask}}
-({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}} mask) => ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}}ToM <types.TypeMask> mask))
-{{- end }}
-{{- range .OpsMaskOut}}
-({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}}) => (VPMOVMToVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}} ({{.Op.Asm}} {{.ReverseArgs}}))
-{{- end }}
-{{- range .OpsImmInMaskOut}}
-({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}}) => (VPMOVMToVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}} ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}}))
-{{- end }}
-{{- range .OpsMaskInMaskOut}}
-({{.Op.Go}}{{(index .Op.In 0).Go}} {{.Args}} mask) => (VPMOVMToVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}} ({{.Op.Asm}} {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 0).ElemBits}}x{{(index .Op.In 0).Lanes}}ToM <types.TypeMask> mask)))
-{{- end }}
-{{- range .OpsImmMaskInMaskOut}}
-({{.Op.Go}}{{(index .Op.In 1).Go}} {{.Args}} mask) => (VPMOVMToVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}} ({{.Op.Asm}} [{{(index .Op.In 0).Const}}] {{.ReverseArgs}} (VPMOVVec{{(index .Op.In 1).ElemBits}}x{{(index .Op.In 1).Lanes}}ToM <types.TypeMask> mask)))
-{{- end }}
 `
-
-// writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules
-// within the specified directory.
-func writeSIMDRules(directory string, ops []Operation) error {
-	file, t, err := openFileAndPrepareTemplate(directory, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules", simdrulesTmpl)
-	if err != nil {
-		return err
+	if _, err := io.WriteString(file, header); err != nil {
+		return fmt.Errorf("failed to write header to %s: %w", outPath, err)
 	}
-	defer file.Close()
-	type OpAndArgList struct {
-		Op          Operation
-		Args        string // "x y", does not include masks
-		ReverseArgs string // "y x", does not include masks
-	}
-	Ops := make([]OpAndArgList, 0)
-	OpsImm := make([]OpAndArgList, 0)
-	OpsMask := make([]OpAndArgList, 0)
-	OpsImmMask := make([]OpAndArgList, 0)
-	OpsMaskOut := make([]OpAndArgList, 0)
-	OpsImmInMaskOut := make([]OpAndArgList, 0)
-	OpsMaskInMaskOut := make([]OpAndArgList, 0)
-	OpsImmMaskInMaskOut := make([]OpAndArgList, 0)
 
-	for _, op := range ops {
-		opInShape, opOutShape, maskType, _, op, gOp, err := op.shape()
+	var allData []tplRuleData
+
+	for _, opr := range ops {
+		opInShape, opOutShape, maskType, _, o, gOp, err := opr.shape()
 		if err != nil {
 			return err
 		}
 		vregInCnt := len(gOp.In)
 		if maskType == OneMask {
-			op.Asm += "Masked"
+			o.Asm += "Masked"
 			vregInCnt--
 		}
-		op.Asm = fmt.Sprintf("%s%d", op.Asm, *op.Out[0].Bits)
-		opData := OpAndArgList{Op: op}
+		o.Asm = fmt.Sprintf("%s%d", o.Asm, *o.Out[0].Bits)
+
+		data := tplRuleData{
+			GoOp: o.Go,
+			Asm:  o.Asm,
+		}
+
 		if vregInCnt == 1 {
-			opData.Args = "x"
-			opData.ReverseArgs = "x"
+			data.Args = "x"
+			data.ReverseArgs = "x"
 		} else if vregInCnt == 2 {
-			opData.Args = "x y"
-			opData.ReverseArgs = "y x"
+			data.Args = "x y"
+			data.ReverseArgs = "y x"
 		} else {
 			return fmt.Errorf("simdgen does not support more than 2 vreg in inputs")
 		}
+
+		var tplName string
 		// If class overwrite is happening, that's not really a mask but a vreg.
-		if opOutShape == OneVregOut || op.Out[0].OverwriteClass != nil {
+		if opOutShape == OneVregOut || o.Out[0].OverwriteClass != nil {
 			switch opInShape {
 			case PureVregIn:
-				Ops = append(Ops, opData)
+				tplName = "pureVregInVregOut"
+				data.GoType = *o.In[0].Go
 			case OneKmaskIn:
-				OpsMask = append(OpsMask, opData)
+				tplName = "oneKmaskInVregOut"
+				data.GoType = *o.In[0].Go
+				data.ElemBits = *o.In[0].ElemBits
+				data.Lanes = *o.In[0].Lanes
 			case OneConstImmIn:
-				OpsImm = append(OpsImm, opData)
+				tplName = "oneConstImmInVregOut"
+				data.GoType = *o.In[1].Go
+				data.Const = *o.In[0].Const
 			case OneKmaskConstImmIn:
-				OpsImmMask = append(OpsImmMask, opData)
+				tplName = "oneKmaskConstImmInVregOut"
+				data.GoType = *o.In[1].Go
+				data.Const = *o.In[0].Const
+				data.ElemBits = *o.In[1].ElemBits
+				data.Lanes = *o.In[1].Lanes
 			case PureKmaskIn:
 				return fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations")
 			}
@@ -106,57 +147,42 @@
 			// OneKmaskOut case
 			switch opInShape {
 			case PureVregIn:
-				OpsMaskOut = append(OpsMaskOut, opData)
+				tplName = "pureVregInKmaskOut"
+				data.GoType = *o.In[0].Go
+				data.ElemBits = *o.In[0].ElemBits
+				data.Lanes = *o.In[0].Lanes
 			case OneKmaskIn:
-				OpsMaskInMaskOut = append(OpsMaskInMaskOut, opData)
+				tplName = "oneKmaskInKmaskOut"
+				data.GoType = *o.In[0].Go
+				data.ElemBits = *o.In[0].ElemBits
+				data.Lanes = *o.In[0].Lanes
 			case OneConstImmIn:
-				OpsImmInMaskOut = append(OpsImmInMaskOut, opData)
+				tplName = "oneConstImmInKmaskOut"
+				data.GoType = *o.In[1].Go
+				data.Const = *o.In[0].Const
+				data.ElemBits = *o.In[1].ElemBits
+				data.Lanes = *o.In[1].Lanes
 			case OneKmaskConstImmIn:
-				OpsImmMaskInMaskOut = append(OpsImmMaskInMaskOut, opData)
+				tplName = "oneKmaskConstImmInKmaskOut"
+				data.GoType = *o.In[1].Go
+				data.Const = *o.In[0].Const
+				data.ElemBits = *o.In[1].ElemBits
+				data.Lanes = *o.In[1].Lanes
 			case PureKmaskIn:
 				return fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations")
 			}
 		}
-	}
-	sortKey := func(op *OpAndArgList) string {
-		return *op.Op.In[0].Go + op.Op.Go
-	}
-	sortBySortKey := func(ops []OpAndArgList) {
-		sort.Slice(ops, func(i, j int) bool {
-			return sortKey(&ops[i]) < sortKey(&ops[j])
-		})
-	}
-	sortBySortKey(Ops)
-	sortBySortKey(OpsImm)
-	sortBySortKey(OpsMask)
-	sortBySortKey(OpsImmMask)
-	sortBySortKey(OpsMaskOut)
-	sortBySortKey(OpsImmInMaskOut)
-	sortBySortKey(OpsMaskInMaskOut)
-	sortBySortKey(OpsImmMaskInMaskOut)
 
-	type templateData struct {
-		Ops                 []OpAndArgList
-		OpsImm              []OpAndArgList
-		OpsMask             []OpAndArgList
-		OpsImmMask          []OpAndArgList
-		OpsMaskOut          []OpAndArgList
-		OpsImmInMaskOut     []OpAndArgList
-		OpsMaskInMaskOut    []OpAndArgList
-		OpsImmMaskInMaskOut []OpAndArgList
+		data.tplName = tplName
+		allData = append(allData, data)
 	}
 
-	err = t.Execute(file, templateData{
-		Ops,
-		OpsImm,
-		OpsMask,
-		OpsImmMask,
-		OpsMaskOut,
-		OpsImmInMaskOut,
-		OpsMaskInMaskOut,
-		OpsImmMaskInMaskOut})
-	if err != nil {
-		return fmt.Errorf("failed to execute template: %w", err)
+	slices.SortFunc(allData, compareTplRuleData)
+
+	for _, data := range allData {
+		if err := ruleTemplates.ExecuteTemplate(file, data.tplName, data); err != nil {
+			return fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err)
+		}
 	}
 
 	return nil