// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"fmt"
	"io"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"text/template"
)
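
// ruleTemplates holds one text/template fragment per supported input/output
// shape; each fragment renders a single SSA rewrite rule. For example, the
// "pureVregInVregOut" template produces a rule of the form
//
//	(AddInt32x4 x y) => (VPADDD128 y x)
//
// (op and instruction names here are illustrative; the real ones come from the
// parsed Operation definitions). Note the reversed arguments on the rewrite
// side; see the header written by writeSIMDRules below.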
var (
	ruleTemplates = template.Must(template.New("simdRules").Parse(`
{{define "pureVregInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ReverseArgs}})
{{end}}
{{define "oneKmaskInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))
{{end}}
{{define "oneConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}})
{{end}}
{{define "oneKmaskConstImmInVregOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask))
{{end}}
{{define "pureVregInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}}))
{{end}}
{{define "oneKmaskInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)))
{{end}}
{{define "oneConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}}))
{{end}}
{{define "oneKmaskConstImmInKmaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => (VPMOVMToVec{{.ElemBits}}x{{.Lanes}} ({{.Asm}} [{{.Const}}] {{.ReverseArgs}} (VPMOVVec{{.ElemBits}}x{{.Lanes}}ToM <types.TypeMask> mask)))
{{end}}
`))
)
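
// tplRuleData is the data passed to ruleTemplates; tplName selects which
// template to execute for a given operation.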
type tplRuleData struct {
	tplName     string // which of the ruleTemplates to execute
	GoOp        string // Go operation name forming the matched SSA op together with GoType
	GoType      string // Go vector type name appended to GoOp
	Args        string // match-side argument list, e.g. "x y"
	Asm         string // machine-level SSA op (asm mnemonic plus vector width)
	ReverseArgs string // Args reversed, used on the rewrite side
	ElemBits    int    // element width in bits, used for mask conversions
	Lanes       int    // number of lanes, used for mask conversions
	Const       string // immediate constant, if any
}
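
// compareTplRuleData orders rules by Go op name, then Go type, then immediate
// constant, so that the generated file is deterministic.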
func compareTplRuleData(x, y tplRuleData) int {
	// TODO: should MaskedXYZ sort just after XYZ?
	if c := strings.Compare(x.GoOp, y.GoOp); c != 0 {
		return c
	}
	if c := strings.Compare(x.GoType, y.GoType); c != 0 {
		return c
	}
	if c := strings.Compare(x.Const, y.Const); c != 0 {
		return c
	}
	return 0
}

// writeSIMDRules generates the SSA lowering and rewrite rules and writes them
// to simdAMD64.rules within the specified directory.
func writeSIMDRules(directory string, ops []Operation) error {
	outPath := filepath.Join(directory, "src/cmd/compile/internal/ssa/_gen/simdAMD64.rules")
	if err := os.MkdirAll(filepath.Dir(outPath), 0755); err != nil {
		return fmt.Errorf("failed to create directory for %s: %w", outPath, err)
	}
	file, err := os.Create(outPath)
	if err != nil {
		return fmt.Errorf("failed to create %s: %w", outPath, err)
	}
	defer file.Close()

	header := `// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.
// The AVX instruction encoding orders vector registers from right to left, for example:
// VSUBPS X Y Z means Z = Y - X.
// The rules here swap the order of such X and Y because the ssa-to-prog lowering in simdssa.go
// assumes a left-to-right order.
// TODO: we should offload the logic to simdssa.go, instead of here.
//
`
	if _, err := io.WriteString(file, header); err != nil {
		return fmt.Errorf("failed to write header to %s: %w", outPath, err)
	}

	var allData []tplRuleData
	for _, opr := range ops {
		opInShape, opOutShape, maskType, _, o, gOp, err := opr.shape()
		if err != nil {
			return err
		}
		vregInCnt := len(gOp.In)
		if maskType == OneMask {
			// Masked operations lower to the "Masked" machine op; the mask
			// is not counted as a vreg input.
			o.Asm += "Masked"
			vregInCnt--
		}
		// The machine op name carries the vector width in bits, e.g. VPADDD128.
		o.Asm = fmt.Sprintf("%s%d", o.Asm, *o.Out[0].Bits)
		data := tplRuleData{
			GoOp: o.Go,
			Asm:  o.Asm,
		}
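		// Args is the match-side argument list; ReverseArgs is the same list
		// reversed for the rewrite side, matching the AVX right-to-left
		// operand order described in the generated header.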
		if vregInCnt == 1 {
			data.Args = "x"
			data.ReverseArgs = "x"
		} else if vregInCnt == 2 {
			data.Args = "x y"
			data.ReverseArgs = "y x"
		} else {
			return fmt.Errorf("simdgen does not support more than 2 vreg inputs")
		}
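		// Pick the template matching the operation's input and output shapes;
		// Kmask inputs and outputs need conversions between mask and vector form.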
		var tplName string
		// If a class overwrite is in effect, the output is not really a mask but a vreg.
		if opOutShape == OneVregOut || o.Out[0].OverwriteClass != nil {
			switch opInShape {
			case PureVregIn:
				tplName = "pureVregInVregOut"
				data.GoType = *o.In[0].Go
			case OneKmaskIn:
				tplName = "oneKmaskInVregOut"
				data.GoType = *o.In[0].Go
				data.ElemBits = *o.In[0].ElemBits
				data.Lanes = *o.In[0].Lanes
			case OneConstImmIn:
				tplName = "oneConstImmInVregOut"
				data.GoType = *o.In[1].Go
				data.Const = *o.In[0].Const
			case OneKmaskConstImmIn:
				tplName = "oneKmaskConstImmInVregOut"
				data.GoType = *o.In[1].Go
				data.Const = *o.In[0].Const
				data.ElemBits = *o.In[1].ElemBits
				data.Lanes = *o.In[1].Lanes
			case PureKmaskIn:
				return fmt.Errorf("simdgen does not support pure Kmask instructions; they should be generated by compiler optimizations")
			}
		} else {
			// OneKmaskOut case.
			switch opInShape {
			case PureVregIn:
				tplName = "pureVregInKmaskOut"
				data.GoType = *o.In[0].Go
				data.ElemBits = *o.In[0].ElemBits
				data.Lanes = *o.In[0].Lanes
			case OneKmaskIn:
				tplName = "oneKmaskInKmaskOut"
				data.GoType = *o.In[0].Go
				data.ElemBits = *o.In[0].ElemBits
				data.Lanes = *o.In[0].Lanes
			case OneConstImmIn:
				tplName = "oneConstImmInKmaskOut"
				data.GoType = *o.In[1].Go
				data.Const = *o.In[0].Const
				data.ElemBits = *o.In[1].ElemBits
				data.Lanes = *o.In[1].Lanes
			case OneKmaskConstImmIn:
				tplName = "oneKmaskConstImmInKmaskOut"
				data.GoType = *o.In[1].Go
				data.Const = *o.In[0].Const
				data.ElemBits = *o.In[1].ElemBits
				data.Lanes = *o.In[1].Lanes
			case PureKmaskIn:
				return fmt.Errorf("simdgen does not support pure Kmask instructions; they should be generated by compiler optimizations")
			}
		}
		data.tplName = tplName
		allData = append(allData, data)
	}
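
	// Emit the rules in a deterministic order.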
	slices.SortFunc(allData, compareTplRuleData)

	for _, data := range allData {
		if err := ruleTemplates.ExecuteTemplate(file, data.tplName, data); err != nil {
			return fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err)
		}
	}
	return nil
}