| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package main |
| |
| import ( |
| "bytes" |
| "fmt" |
| "log" |
| "strings" |
| "text/template" |
| ) |
| |
| var ( |
| ssaTemplates = template.Must(template.New("simdSSA").Parse(` |
| {{define "header"}}// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT. |
| |
| package amd64 |
| |
| import ( |
| "cmd/compile/internal/ssa" |
| "cmd/compile/internal/ssagen" |
| "cmd/internal/obj" |
| "cmd/internal/obj/x86" |
| ) |
| |
| func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool { |
| var p *obj.Prog |
| switch v.Op {{"{"}}{{end}} |
| {{define "case"}} |
| case {{.Cases}}: |
| p = {{.Helper}}(s, v) |
| {{end}} |
| {{define "footer"}} |
| default: |
| // Unknown reg shape |
| return false |
| } |
| {{end}} |
| {{define "zeroing"}} |
| // Masked operation are always compiled with zeroing. |
| switch v.Op { |
| case {{.}}: |
| x86.ParseSuffix(p, "Z") |
| } |
| {{end}} |
| {{define "ending"}} |
| return true |
| } |
| {{end}}`)) |
| ) |
| |
// tplSSAData is the data passed to the "case" template: Cases is the
// already-joined list of op constants for one case clause, and Helper is the
// name of the function that clause calls.
type tplSSAData struct {
	Cases, Helper string
}
| |
| // writeSIMDSSA generates the ssa to prog lowering codes and writes it to simdssa.go |
| // within the specified directory. |
| func writeSIMDSSA(ops []Operation) *bytes.Buffer { |
| var ZeroingMask []string |
| regInfoKeys := []string{ |
| "v11", |
| "v21", |
| "v2k", |
| "v2kv", |
| "v2kk", |
| "vkv", |
| "v31", |
| "v3kv", |
| "v11Imm8", |
| "vkvImm8", |
| "v21Imm8", |
| "v2kImm8", |
| "v2kkImm8", |
| "v31ResultInArg0", |
| "v3kvResultInArg0", |
| "vfpv", |
| "vfpkv", |
| "vgpvImm8", |
| "vgpImm8", |
| "v2kvImm8", |
| "vkvload", |
| "v21load", |
| "v31loadResultInArg0", |
| "v3kvloadResultInArg0", |
| "v2kvload", |
| "v2kload", |
| "v11load", |
| "v11loadImm8", |
| "vkvloadImm8", |
| "v21loadImm8", |
| "v2kloadImm8", |
| "v2kkloadImm8", |
| "v2kvloadImm8", |
| "v31ResultInArg0Imm8", |
| "v31loadResultInArg0Imm8", |
| "v21ResultInArg0", |
| "v21ResultInArg0Imm8", |
| "v31x0AtIn2ResultInArg0", |
| "v2kvResultInArg0", |
| } |
| regInfoSet := map[string][]string{} |
| for _, key := range regInfoKeys { |
| regInfoSet[key] = []string{} |
| } |
| |
| seen := map[string]struct{}{} |
| allUnseen := make(map[string][]Operation) |
| allUnseenCaseStr := make(map[string][]string) |
| classifyOp := func(op Operation, maskType maskShape, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error { |
| regShape, err := op.regShape(mem) |
| if err != nil { |
| return err |
| } |
| if regShape == "v01load" { |
| regShape = "vload" |
| } |
| if shapeOut == OneVregOutAtIn { |
| regShape += "ResultInArg0" |
| } |
| if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn { |
| regShape += "Imm8" |
| } |
| regShape, err = rewriteVecAsScalarRegInfo(op, regShape) |
| if err != nil { |
| return err |
| } |
| if _, ok := regInfoSet[regShape]; !ok { |
| allUnseen[regShape] = append(allUnseen[regShape], op) |
| allUnseenCaseStr[regShape] = append(allUnseenCaseStr[regShape], caseStr) |
| } |
| regInfoSet[regShape] = append(regInfoSet[regShape], caseStr) |
| if mem == NoMem && op.hasMaskedMerging(maskType, shapeOut) { |
| regShapeMerging := regShape |
| if shapeOut != OneVregOutAtIn { |
| // We have to copy the slice here becasue the sort will be visible from other |
| // aliases when no reslicing is happening. |
| newIn := make([]Operand, len(op.In), len(op.In)+1) |
| copy(newIn, op.In) |
| op.In = newIn |
| op.In = append(op.In, op.Out[0]) |
| op.sortOperand() |
| regShapeMerging, err = op.regShape(mem) |
| regShapeMerging += "ResultInArg0" |
| } |
| if err != nil { |
| return err |
| } |
| if _, ok := regInfoSet[regShapeMerging]; !ok { |
| allUnseen[regShapeMerging] = append(allUnseen[regShapeMerging], op) |
| allUnseenCaseStr[regShapeMerging] = append(allUnseenCaseStr[regShapeMerging], caseStr+"Merging") |
| } |
| regInfoSet[regShapeMerging] = append(regInfoSet[regShapeMerging], caseStr+"Merging") |
| } |
| return nil |
| } |
| for _, op := range ops { |
| shapeIn, shapeOut, maskType, _, gOp := op.shape() |
| asm := machineOpName(maskType, gOp) |
| if _, ok := seen[asm]; ok { |
| continue |
| } |
| seen[asm] = struct{}{} |
| caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm) |
| isZeroMasking := false |
| if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn { |
| if gOp.Zeroing == nil || *gOp.Zeroing { |
| ZeroingMask = append(ZeroingMask, caseStr) |
| isZeroMasking = true |
| } |
| } |
| if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr, NoMem); err != nil { |
| panic(err) |
| } |
| if op.MemFeatures != nil && *op.MemFeatures == "vbcst" { |
| // Make a full vec memory variant |
| op = rewriteLastVregToMem(op) |
| // Ignore the error |
| // an error could be triggered by [checkVecAsScalar]. |
| // TODO: make [checkVecAsScalar] aware of mem ops. |
| if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil { |
| if *Verbose { |
| log.Printf("Seen error: %e", err) |
| } |
| } else if isZeroMasking { |
| ZeroingMask = append(ZeroingMask, caseStr+"load") |
| } |
| } |
| } |
| if len(allUnseen) != 0 { |
| allKeys := make([]string, 0) |
| for k := range allUnseen { |
| allKeys = append(allKeys, k) |
| } |
| panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v\n, cases: %v\n", allUnseen, allKeys, allUnseenCaseStr)) |
| } |
| |
| buffer := new(bytes.Buffer) |
| |
| if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil { |
| panic(fmt.Errorf("failed to execute header template: %w", err)) |
| } |
| |
| for _, regShape := range regInfoKeys { |
| // Stable traversal of regInfoSet |
| cases := regInfoSet[regShape] |
| if len(cases) == 0 { |
| continue |
| } |
| data := tplSSAData{ |
| Cases: strings.Join(cases, ",\n\t\t"), |
| Helper: "simd" + capitalizeFirst(regShape), |
| } |
| if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil { |
| panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err)) |
| } |
| } |
| |
| if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil { |
| panic(fmt.Errorf("failed to execute footer template: %w", err)) |
| } |
| |
| if len(ZeroingMask) != 0 { |
| if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil { |
| panic(fmt.Errorf("failed to execute footer template: %w", err)) |
| } |
| } |
| |
| if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil { |
| panic(fmt.Errorf("failed to execute footer template: %w", err)) |
| } |
| |
| return buffer |
| } |