| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package main |
| |
| import ( |
| "fmt" |
| "io" |
| "os" |
| "path/filepath" |
| "slices" |
| "strings" |
| "text/template" |
| ) |
| |
// ruleTemplates is the set of named text templates used to render one SSA
// rewrite rule per SIMD operation. Each template is named after the operand
// shape it handles ("<input shape>In<output shape>Out") and is executed with
// a tplRuleData value.
//
// Templates whose name mentions "Kmask" on the input side append a trailing
// "mask" argument and convert it with VPMOVVec<ElemBits>x<Lanes>ToM; templates
// with "KmaskOut" wrap the result in VPMOVMToVec<ElemBits>x<Lanes>; templates
// with "ConstImm" emit the immediate as [Const] ahead of the register
// arguments.
var ruleTemplates = template.Must(template.New("simdRules").Parse(`
{{define "pureVregInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ReverseArgs}})
{{end}}
{{define "oneKmaskInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))
{{end}}
{{define "oneConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}})
{{end}}
{{define "oneKmaskConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))
{{end}}
{{define "pureVregInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}}))
{{end}}
{{define "oneKmaskInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)))
{{end}}
{{define "oneConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}}))
{{end}}
{{define "oneKmaskConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)))
{{end}}
`))
| |
// tplRuleData carries the values substituted into one rule template from
// ruleTemplates. The exported fields are the ones the templates reference.
type tplRuleData struct {
	// tplName is the name of the template in ruleTemplates used to render
	// this rule; it is not referenced by the templates themselves.
	tplName string
	// GoOp and GoType are concatenated to form the op name on the left-hand
	// side of the rule.
	GoOp   string
	GoType string
	// Args is the register argument list on the left-hand side of the rule.
	Args string
	// Asm is the op name on the right-hand side of the rule.
	Asm string
	// ReverseArgs is the register argument list on the right-hand side of
	// the rule.
	ReverseArgs string
	// ElemBits and Lanes select the VPMOVVec<ElemBits>x<Lanes>ToM and
	// VPMOVMToVec<ElemBits>x<Lanes> conversions in templates that involve masks.
	ElemBits, Lanes int
	// Const is the immediate rendered as [Const] by the ConstImm templates.
	Const string
}
| |
| func compareTplRuleData(x, y tplRuleData) int { |
| // TODO should MaskedXYZ compare just after XYZ? |
| if c := strings.Compare(x.GoOp, y.GoOp); c != 0 { |
| return c |
| } |
| if c := strings.Compare(x.GoType, y.GoType); c != 0 { |
| return c |
| } |
| if c := strings.Compare(x.Const, y.Const); c != 0 { |
| return c |
| } |
| return 0 |
| } |
| |
| // writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules |
| // within the specified directory. |
| func writeSIMDRules(directory string, ops []Operation) error { |
| |
| outPath := filepath.Join(directory, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules") |
| if err := os.MkdirAll(filepath.Dir(outPath), 0755); err != nil { |
| return fmt.Errorf("failed to create directory for %s: %w", outPath, err) |
| } |
| file, err := os.Create(outPath) |
| if err != nil { |
| return fmt.Errorf("failed to create %s: %w", outPath, err) |
| } |
| defer file.Close() |
| |
| header := `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT. |
| |
| // The AVX instruction encodings orders vector register from right to left, for example: |
| // VSUBPS X Y Z means Z=Y-X |
| // The rules here swapped the order of such X and Y because the ssa to prog lowering in simdssa.go assumes a |
| // left to right order. |
| // TODO: we should offload the logic to simdssa.go, instead of here. |
| // |
| |
| ` |
| if _, err := io.WriteString(file, header); err != nil { |
| return fmt.Errorf("failed to write header to %s: %w", outPath, err) |
| } |
| |
| var allData []tplRuleData |
| |
| for _, opr := range ops { |
| opInShape, opOutShape, maskType, _, o, gOp, err := opr.shape() |
| if err != nil { |
| return err |
| } |
| vregInCnt := len(gOp.In) |
| if maskType == OneMask { |
| o.Asm += "Masked" |
| vregInCnt-- |
| } |
| o.Asm = fmt.Sprintf("%s%d", o.Asm, *o.Out[0].Bits) |
| |
| data := tplRuleData{ |
| GoOp: o.Go, |
| Asm: o.Asm, |
| } |
| |
| if vregInCnt == 1 { |
| data.Args = "x" |
| data.ReverseArgs = "x" |
| } else if vregInCnt == 2 { |
| data.Args = "x y" |
| data.ReverseArgs = "y x" |
| } else { |
| return fmt.Errorf("simdgen does not support more than 2 vreg in inputs") |
| } |
| |
| var tplName string |
| // If class overwrite is happening, that's not really a mask but a vreg. |
| if opOutShape == OneVregOut || o.Out[0].OverwriteClass != nil { |
| switch opInShape { |
| case PureVregIn: |
| tplName = "pureVregInVregOut" |
| data.GoType = *o.In[0].Go |
| case OneKmaskIn: |
| tplName = "oneKmaskInVregOut" |
| data.GoType = *o.In[0].Go |
| data.ElemBits = *o.In[0].ElemBits |
| data.Lanes = *o.In[0].Lanes |
| case OneConstImmIn: |
| tplName = "oneConstImmInVregOut" |
| data.GoType = *o.In[1].Go |
| data.Const = *o.In[0].Const |
| case OneKmaskConstImmIn: |
| tplName = "oneKmaskConstImmInVregOut" |
| data.GoType = *o.In[1].Go |
| data.Const = *o.In[0].Const |
| data.ElemBits = *o.In[1].ElemBits |
| data.Lanes = *o.In[1].Lanes |
| case PureKmaskIn: |
| return fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations") |
| } |
| } else { |
| // OneKmaskOut case |
| switch opInShape { |
| case PureVregIn: |
| tplName = "pureVregInKmaskOut" |
| data.GoType = *o.In[0].Go |
| data.ElemBits = *o.In[0].ElemBits |
| data.Lanes = *o.In[0].Lanes |
| case OneKmaskIn: |
| tplName = "oneKmaskInKmaskOut" |
| data.GoType = *o.In[0].Go |
| data.ElemBits = *o.In[0].ElemBits |
| data.Lanes = *o.In[0].Lanes |
| case OneConstImmIn: |
| tplName = "oneConstImmInKmaskOut" |
| data.GoType = *o.In[1].Go |
| data.Const = *o.In[0].Const |
| data.ElemBits = *o.In[1].ElemBits |
| data.Lanes = *o.In[1].Lanes |
| case OneKmaskConstImmIn: |
| tplName = "oneKmaskConstImmInKmaskOut" |
| data.GoType = *o.In[1].Go |
| data.Const = *o.In[0].Const |
| data.ElemBits = *o.In[1].ElemBits |
| data.Lanes = *o.In[1].Lanes |
| case PureKmaskIn: |
| return fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations") |
| } |
| } |
| |
| data.tplName = tplName |
| allData = append(allData, data) |
| } |
| |
| slices.SortFunc(allData, compareTplRuleData) |
| |
| for _, data := range allData { |
| if err := ruleTemplates.ExecuteTemplate(file, data.tplName, data); err != nil { |
| return fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err) |
| } |
| } |
| |
| return nil |
| } |