blob: 3d99dd2a81a1a4344c46061af82e45532edbbb11 [file] [log] [blame]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"log"
"sort"
"strings"
)
// simdMachineOpsTmpl is the template source used by writeSIMDMachineOps to
// render the generated simdAMD64Ops function, which returns a []opData
// literal. The template ranges over six lists of the data passed to it:
//
//   - OpsData:           plain ops.
//   - OpsDataImm:        ops emitted with aux "UInt8".
//   - OpsDataLoad:       ops emitted with aux "SymOff" and symEffect "Read".
//   - OpsDataImmLoad:    ops emitted with aux "SymValAndOff" and symEffect "Read".
//   - OpsDataMerging:    ops emitted with a "Merging" name suffix; commutative
//     is hard-coded to false and resultInArg0 to true, regardless of the
//     corresponding fields of the entry.
//   - OpsDataImmMerging: like OpsDataMerging, additionally with aux "UInt8".
//
// The regInfo parameter names of the generated function are the register
// constraint names the emitted "reg:" fields may refer to; writeSIMDMachineOps
// validates constraints against its own regInfoSet, so the two lists have to
// be kept in agreement when a new constraint is introduced.
const simdMachineOpsTmpl = `
package main
func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vfpkv, w11, w21, w2k, wkw, w2kw, w2kk, w31, w3kw, wgpw, wgp, wfpw, wfpkw,
wkwload, v21load, v31load, v11load, w21load, w31load, w2kload, w2kwload, w11load, w3kwload, w2kkload, v31x0AtIn2 regInfo) []opData {
return []opData{
{{- range .OpsData }}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
{{- range .OpsDataImm }}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
{{- range .OpsDataLoad}}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
{{- end}}
{{- range .OpsDataImmLoad}}
{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymValAndOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
{{- end}}
{{- range .OpsDataMerging }}
{name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: false, typ: "{{.Type}}", resultInArg0: true},
{{- end }}
{{- range .OpsDataImmMerging }}
{name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: false, typ: "{{.Type}}", resultInArg0: true},
{{- end }}
}
}
`
// writeSIMDMachineOps generates the SIMD machine op table (the contents
// destined for simdAMD64ops.go) and returns the rendered source in a buffer.
// It does not write any file itself; that is left to the caller.
//
// The function works in three phases:
//
//  1. Among all ops that map to the same machine op name it picks one "best"
//     representative: a non-commutative op replaces a commutative one, and
//     otherwise the op with fewer operands carrying an OverwriteBase wins.
//     The first-seen order of machine op names is remembered.
//  2. For every selected op it derives the register constraint name, the
//     result type, and, where applicable, a memory-load variant (only the
//     "vbcst" memory feature is handled) and a "Merging" variant.
//  3. Each resulting list is sorted by op name with compareNatural and fed
//     to simdMachineOpsTmpl.
//
// The function panics when a register shape cannot be computed, when an
// output shape is not recognized, when any derived register constraint is
// missing from regInfoSet (all offenders are logged first), or when the
// template fails to execute.
func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
	t := templateOf(simdMachineOpsTmpl, "simdAMD64Ops")
	buffer := new(bytes.Buffer)
	buffer.WriteString(generatedHeader)

	// opData is one row of the generated table; the field names are the ones
	// referenced by simdMachineOpsTmpl.
	type opData struct {
		OpName       string
		Asm          string
		OpInLen      int
		RegInfo      string
		Comm         bool
		Type         string
		ResultInArg0 bool
	}
	// machineOpsData is the root object handed to the template.
	type machineOpsData struct {
		OpsData           []opData
		OpsDataImm        []opData
		OpsDataLoad       []opData
		OpsDataImmLoad    []opData
		OpsDataMerging    []opData
		OpsDataImmMerging []opData
	}

	// regInfoSet lists every register constraint name that the generated
	// function receives as a parameter (see simdMachineOpsTmpl).
	regInfoSet := map[string]bool{
		"v11": true, "v21": true, "v2k": true, "v2kv": true, "v2kk": true, "vkv": true, "v31": true, "v3kv": true, "vgpv": true, "vgp": true, "vfpv": true, "vfpkv": true,
		"w11": true, "w21": true, "w2k": true, "w2kw": true, "w2kk": true, "wkw": true, "w31": true, "w3kw": true, "wgpw": true, "wgp": true, "wfpw": true, "wfpkw": true,
		"wkwload": true, "v21load": true, "v31load": true, "v11load": true, "w21load": true, "w31load": true, "w2kload": true, "w2kwload": true, "w11load": true,
		"w3kwload": true, "w2kkload": true, "v31x0AtIn2": true}

	opsData := make([]opData, 0)
	opsDataImm := make([]opData, 0)
	opsDataLoad := make([]opData, 0)
	opsDataImmLoad := make([]opData, 0)
	opsDataMerging := make([]opData, 0)
	opsDataImmMerging := make([]opData, 0)

	// Determine the "best" version of an instruction to use.
	best := make(map[string]Operation)
	var mOpOrder []string
	countOverrides := func(s []Operand) int {
		n := 0
		for _, o := range s {
			if o.OverwriteBase != nil {
				n++
			}
		}
		return n
	}
	for _, op := range ops {
		_, _, maskType, _, gOp := op.shape()
		asm := machineOpName(maskType, gOp)
		other, ok := best[asm]
		if !ok {
			best[asm] = op
			mOpOrder = append(mOpOrder, asm)
			continue
		}
		if !op.Commutative && other.Commutative { // if there's a non-commutative version of the op, it wins.
			best[asm] = op
			continue
		}
		// See if "op" is better than "other": fewer overwritten operands wins.
		if countOverrides(op.In)+countOverrides(op.Out) < countOverrides(other.In)+countOverrides(other.Out) {
			best[asm] = op
		}
	}

	// Unknown register constraints are collected rather than reported one at
	// a time, so that a single run lists everything that needs updating.
	regInfoErrs := make([]error, 0)
	regInfoMissing := make(map[string]bool, 0)

	// makeRegInfo computes the register constraint name for op under the
	// given memory shape. A failure of op.regShape always panics. A failure
	// of rewriteVecAsScalarRegInfo panics for NoMem/InvalidMem and is
	// returned to the caller for the other memory shapes. A constraint that
	// is not in regInfoSet is recorded in regInfoErrs/regInfoMissing but is
	// still returned.
	makeRegInfo := func(op Operation, mem memShape) (string, error) {
		regInfo, err := op.regShape(mem)
		if err != nil {
			panic(err)
		}
		regInfo, err = rewriteVecAsScalarRegInfo(op, regInfo)
		if err != nil {
			if mem == NoMem || mem == InvalidMem {
				panic(err)
			}
			return "", err
		}
		if regInfo == "v01load" {
			regInfo = "vload"
		}
		// Makes AVX512 operations use upper registers
		if strings.Contains(op.CPUFeature, "AVX512") {
			regInfo = strings.ReplaceAll(regInfo, "v", "w")
		}
		if _, ok := regInfoSet[regInfo]; !ok {
			regInfoErrs = append(regInfoErrs, fmt.Errorf("unsupported register constraint, please update the template and AMD64Ops.go: %s. Op is %s", regInfo, op))
			regInfoMissing[regInfo] = true
		}
		return regInfo, nil
	}

	for _, asm := range mOpOrder {
		op := best[asm]
		shapeIn, shapeOut, maskType, _, gOp := op.shape()

		// TODO: all our masked operations are now zeroing, we need to generate machine ops with merging masks, maybe copy
		// one here with a name suffix "Merging". The rewrite rules will need them.
		regInfo, err := makeRegInfo(op, NoMem)
		if err != nil {
			panic(err)
		}

		var outType string
		switch {
		case shapeOut == OneVregOut || shapeOut == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil:
			// If class overwrite is happening, that's not really a mask but a vreg.
			outType = fmt.Sprintf("Vec%d", *gOp.Out[0].Bits)
		case shapeOut == OneGregOut:
			outType = gOp.GoType() // this is a straight Go type, not a VecNNN type
		case shapeOut == OneKmaskOut:
			outType = "Mask"
		default:
			panic(fmt.Errorf("simdgen does not recognize this output shape: %d", shapeOut))
		}
		resultInArg0 := shapeOut == OneVregOutAtIn

		// Memory-load variant. This must be derived before op.In is extended
		// for the merging variant below.
		var memOpData *opData
		if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
			// Right now we only have the vbcst case.
			// Make a full vec memory variant.
			opMem := rewriteLastVregToMem(op)
			memRegInfo, memErr := makeRegInfo(opMem, VregMemIn)
			if memErr != nil {
				// Just skip the load variant if the error is non-nil;
				// an error could be triggered by [checkVecAsScalar].
				// TODO: make [checkVecAsScalar] aware of mem ops.
				if *Verbose {
					log.Printf("Seen error: %v", memErr)
				}
			} else {
				memOpData = &opData{
					OpName:       asm + "load",
					Asm:          gOp.Asm,
					OpInLen:      len(gOp.In) + 1,
					RegInfo:      memRegInfo,
					Comm:         false,
					Type:         outType,
					ResultInArg0: resultInArg0,
				}
			}
		}
		if memOpData != nil && *op.MemFeatures != "vbcst" {
			// Defensive invariant: load variants are only built for vbcst above.
			panic("simdgen only knows vbcst for mem ops for now")
		}

		// Merging variant. When the result does not already live in arg0, the
		// output operand is added as an extra input, which changes both the
		// register constraint and the argument count.
		regInfoMerging := regInfo
		mergingLen := len(gOp.In)
		hasMerging := gOp.hasMaskedMerging(maskType, shapeOut)
		if hasMerging && !resultInArg0 {
			// We have to copy the slice here because the sort will be visible from other
			// aliases when no reslicing is happening.
			newIn := make([]Operand, len(op.In), len(op.In)+1)
			copy(newIn, op.In)
			op.In = newIn
			op.In = append(op.In, op.Out[0])
			op.sortOperand()
			regInfoMerging, err = makeRegInfo(op, NoMem)
			if err != nil {
				panic(err)
			}
			mergingLen++
		}

		plain := opData{
			OpName:       asm,
			Asm:          gOp.Asm,
			OpInLen:      len(gOp.In),
			RegInfo:      regInfo,
			Comm:         gOp.Commutative,
			Type:         outType,
			ResultInArg0: resultInArg0,
		}
		// The template overrides commutative/resultInArg0 for merging entries;
		// the values stored here are kept for parity with the plain entry.
		merging := opData{
			OpName:       asm,
			Asm:          gOp.Asm,
			OpInLen:      mergingLen,
			RegInfo:      regInfoMerging,
			Comm:         gOp.Commutative,
			Type:         outType,
			ResultInArg0: resultInArg0,
		}

		if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
			opsDataImm = append(opsDataImm, plain)
			if memOpData != nil {
				opsDataImmLoad = append(opsDataImmLoad, *memOpData)
			}
			if hasMerging {
				opsDataImmMerging = append(opsDataImmMerging, merging)
			}
		} else {
			opsData = append(opsData, plain)
			if memOpData != nil {
				opsDataLoad = append(opsDataLoad, *memOpData)
			}
			if hasMerging {
				opsDataMerging = append(opsDataMerging, merging)
			}
		}
	}

	if len(regInfoErrs) != 0 {
		for _, e := range regInfoErrs {
			log.Printf("Errors: %v\n", e)
		}
		panic(fmt.Errorf("these regInfo unseen: %v", regInfoMissing))
	}

	// Sort every list by op name so the generated file has a stable order.
	sortByName := func(s []opData) {
		sort.Slice(s, func(i, j int) bool {
			return compareNatural(s[i].OpName, s[j].OpName) < 0
		})
	}
	sortByName(opsData)
	sortByName(opsDataImm)
	sortByName(opsDataLoad)
	sortByName(opsDataImmLoad)
	sortByName(opsDataMerging)
	sortByName(opsDataImmMerging)

	err := t.Execute(buffer, machineOpsData{
		OpsData:           opsData,
		OpsDataImm:        opsDataImm,
		OpsDataLoad:       opsDataLoad,
		OpsDataImmLoad:    opsDataImmLoad,
		OpsDataMerging:    opsDataMerging,
		OpsDataImmMerging: opsDataImmMerging,
	})
	if err != nil {
		panic(fmt.Errorf("failed to execute template: %w", err))
	}
	return buffer
}