// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
"bytes"
"fmt"
"log"
"strings"
"text/template"
)
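
// ssaTemplates assembles the generated simdssa.go: a header with the
// code-generation banner and imports, one switch case per register shape,
// a default footer, an optional zeroing-suffix pass, and the closing
// return. The emitted function has roughly this shape (the op name is
// illustrative, not an exact listing):
//
//	func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
//		var p *obj.Prog
//		switch v.Op {
//		case ssa.OpAMD64VPADDD128:
//			p = simdV21(s, v)
//		default:
//			return false
//		}
//		return true
//	}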
var (
ssaTemplates = template.Must(template.New("simdSSA").Parse(`
{{define "header"}}// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.

package amd64

import (
"cmd/compile/internal/ssa"
"cmd/compile/internal/ssagen"
"cmd/internal/obj"
"cmd/internal/obj/x86"
)

func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
var p *obj.Prog
switch v.Op {{"{"}}{{end}}
{{define "case"}}
case {{.Cases}}:
p = {{.Helper}}(s, v)
{{end}}
{{define "footer"}}
default:
// Unknown reg shape
return false
}
{{end}}
{{define "zeroing"}}
// Masked operations are always compiled with zeroing.
switch v.Op {
case {{.}}:
x86.ParseSuffix(p, "Z")
}
{{end}}
{{define "ending"}}
return true
}
{{end}}`))
)
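
// tplSSAData is the data for the "case" template: Cases is a
// comma-separated list of ssa.Op names that share a register shape, and
// Helper is the name of the function that emits the Prog for that shape.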
type tplSSAData struct {
Cases string
Helper string
}

// writeSIMDSSA generates the SSA-to-Prog lowering code for simdssa.go
// and returns it as a buffer for the caller to write out.
func writeSIMDSSA(ops []Operation) *bytes.Buffer {
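// ZeroingMask collects the ssa.Op case strings of masked operations with
// zeroing semantics; the "zeroing" template adds the "Z" suffix for them.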
var ZeroingMask []string
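// regInfoKeys lists every register-shape name this generator can lower.
// Each key must have a matching helper in the generated file's package
// (the "case" template calls "simd" + capitalizeFirst(key), e.g.
// simdV11 for "v11"). The slice order fixes the order of the generated
// switch cases.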
regInfoKeys := []string{
"v11",
"v21",
"v2k",
"v2kv",
"v2kk",
"vkv",
"v31",
"v3kv",
"v11Imm8",
"vkvImm8",
"v21Imm8",
"v2kImm8",
"v2kkImm8",
"v31ResultInArg0",
"v3kvResultInArg0",
"vfpv",
"vfpkv",
"vgpvImm8",
"vgpImm8",
"v2kvImm8",
"vkvload",
"v21load",
"v31loadResultInArg0",
"v3kvloadResultInArg0",
"v2kvload",
"v2kload",
"v11load",
"v11loadImm8",
"vkvloadImm8",
"v21loadImm8",
"v2kloadImm8",
"v2kkloadImm8",
"v2kvloadImm8",
"v31ResultInArg0Imm8",
"v31loadResultInArg0Imm8",
"v21ResultInArg0",
"v21ResultInArg0Imm8",
"v31x0AtIn2ResultInArg0",
"v2kvResultInArg0",
}
regInfoSet := map[string][]string{}
for _, key := range regInfoKeys {
regInfoSet[key] = []string{}
}
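// seen dedupes machine op names so each ssa.Op is classified once;
// allUnseen and allUnseenCaseStr record register shapes that have no
// helper, for the error report after the loop.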
seen := map[string]struct{}{}
allUnseen := make(map[string][]Operation)
allUnseenCaseStr := make(map[string][]string)
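// classifyOp buckets op's ssa.Op case string by register shape. For
// masked ops that also support merging (NoMem only), it additionally
// classifies a merging variant whose result is tied to an input register.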
classifyOp := func(op Operation, maskType maskShape, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error {
regShape, err := op.regShape(mem)
if err != nil {
return err
}
if regShape == "v01load" {
regShape = "vload"
}
if shapeOut == OneVregOutAtIn {
regShape += "ResultInArg0"
}
if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
regShape += "Imm8"
}
regShape, err = rewriteVecAsScalarRegInfo(op, regShape)
if err != nil {
return err
}
if _, ok := regInfoSet[regShape]; !ok {
allUnseen[regShape] = append(allUnseen[regShape], op)
allUnseenCaseStr[regShape] = append(allUnseenCaseStr[regShape], caseStr)
}
regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
if mem == NoMem && op.hasMaskedMerging(maskType, shapeOut) {
regShapeMerging := regShape
if shapeOut != OneVregOutAtIn {
// Copy the slice before appending and sorting: op.In may share its
// backing array with other Operations, and mutating it in place would
// be visible through those aliases.
newIn := make([]Operand, len(op.In), len(op.In)+1)
copy(newIn, op.In)
op.In = newIn
op.In = append(op.In, op.Out[0])
op.sortOperand()
regShapeMerging, err = op.regShape(mem)
if err != nil {
return err
}
regShapeMerging += "ResultInArg0"
}
if _, ok := regInfoSet[regShapeMerging]; !ok {
allUnseen[regShapeMerging] = append(allUnseen[regShapeMerging], op)
allUnseenCaseStr[regShapeMerging] = append(allUnseenCaseStr[regShapeMerging], caseStr+"Merging")
}
regInfoSet[regShapeMerging] = append(regInfoSet[regShapeMerging], caseStr+"Merging")
}
return nil
}
for _, op := range ops {
shapeIn, shapeOut, maskType, _, gOp := op.shape()
asm := machineOpName(maskType, gOp)
if _, ok := seen[asm]; ok {
continue
}
seen[asm] = struct{}{}
caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm)
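// Masked inputs use zeroing semantics unless the op explicitly sets
// Zeroing to false.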
isZeroMasking := false
if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn {
if gOp.Zeroing == nil || *gOp.Zeroing {
ZeroingMask = append(ZeroingMask, caseStr)
isZeroMasking = true
}
}
if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr, NoMem); err != nil {
panic(err)
}
if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
// Make a full-vector memory-operand variant.
op = rewriteLastVregToMem(op)
// Ignore the error: it could be triggered by [checkVecAsScalar].
// TODO: make [checkVecAsScalar] aware of mem ops.
if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil {
if *Verbose {
log.Printf("Seen error: %e", err)
}
} else if isZeroMasking {
ZeroingMask = append(ZeroingMask, caseStr+"load")
}
}
}
if len(allUnseen) != 0 {
allKeys := make([]string, 0)
for k := range allUnseen {
allKeys = append(allKeys, k)
}
panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v\n, cases: %v\n", allUnseen, allKeys, allUnseenCaseStr))
}
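
// Emit the file: header, one "case" block per populated register shape,
// the default footer, the optional zeroing pass, and the closing return.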
buffer := new(bytes.Buffer)
if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil {
panic(fmt.Errorf("failed to execute header template: %w", err))
}
for _, regShape := range regInfoKeys {
// Iterate over regInfoKeys for a stable traversal of regInfoSet.
cases := regInfoSet[regShape]
if len(cases) == 0 {
continue
}
data := tplSSAData{
Cases: strings.Join(cases, ",\n\t\t"),
Helper: "simd" + capitalizeFirst(regShape),
}
if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil {
panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err))
}
}
if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil {
panic(fmt.Errorf("failed to execute footer template: %w", err))
}
if len(ZeroingMask) != 0 {
if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil {
panic(fmt.Errorf("failed to execute footer template: %w", err))
}
}
if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil {
panic(fmt.Errorf("failed to execute footer template: %w", err))
}
return buffer
}