blob: 60e4845dfb71b82118a50d28f0472b9a896308ea [file] [edit]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build goexperiment.simd && amd64
package simd_test
import (
"simd/archsimd"
"testing"
)
// TestShift exercises the per-element (variable-count) shifts on
// 128-bit vectors: signed left shift, signed (arithmetic) right shift,
// and unsigned (logical) right shift, each checked against the
// corresponding Go shift on every lane.
func TestShift(t *testing.T) {
	if !archsimd.X86.AVX2() {
		t.Skip("requires AVX2")
	}
	// Each lane of x is shifted by the matching lane of y.
	wantShl := map2(func(a, b int32) int32 { return a << uint32(b) })
	testInt32x4Binary(t,
		func(a, b archsimd.Int32x4) archsimd.Int32x4 { return a.ShiftLeft(b.AsUint32x4()) },
		wantShl)
	wantSra := map2(func(a, b int32) int32 { return a >> uint32(b) })
	testInt32x4Binary(t,
		func(a, b archsimd.Int32x4) archsimd.Int32x4 { return a.ShiftRight(b.AsUint32x4()) },
		wantSra)
	wantSrl := map2(func(a, b uint32) uint32 { return a >> b })
	testUint32x4Binary(t,
		func(a, b archsimd.Uint32x4) archsimd.Uint32x4 { return a.ShiftRight(b) },
		wantSrl)
}
// TestShiftAll exercises the shift-by-scalar operations, which shift
// every element of a vector by the same count. It covers left shift
// and both signed and unsigned right shift, with the count supplied
// both as an untyped constant and as a runtime value (via hideConst),
// and with both in-range counts and over-shifts (counts >= 32).
func TestShiftAll(t *testing.T) {
	// Test both const and non-const shifts.
	// Test both regular and over-shifts.
	hide := hideConst[uint64]
	// ShiftAllLeft
	testInt32x4Unary(t,
		func(x archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeft(2) },
		map1(func(x int32) int32 { return x << 2 }))
	testInt32x4Unary(t,
		func(x archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeft(hide(2)) },
		map1(func(x int32) int32 { return x << hide(2) }))
	// Ironically, we have to hide the constant in the want function so the
	// compiler doesn't complain about a silly shift.
	// Over-shift: 0x1000 >= 32, so the expected value is whatever Go's
	// shift by the same count yields (zero for left shifts).
	testInt32x4Unary(t,
		func(x archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeft(0x1000) },
		map1(func(x int32) int32 { return x << hide(0x1000) }))
	testInt32x4Unary(t,
		func(x archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeft(hide(0x1000)) },
		map1(func(x int32) int32 { return x << hide(0x1000) }))
	// Signed ShiftAllRight
	testInt32x4Unary(t,
		func(x archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRight(2) },
		map1(func(x int32) int32 { return x >> 2 }))
	testInt32x4Unary(t,
		func(x archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRight(hide(2)) },
		map1(func(x int32) int32 { return x >> hide(2) }))
	// Over-shift of a signed value: Go's >> by a huge count sign-fills,
	// and the same result is expected from the vector op.
	testInt32x4Unary(t,
		func(x archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRight(0x1000) },
		map1(func(x int32) int32 { return x >> hide(0x1000) }))
	testInt32x4Unary(t,
		func(x archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRight(hide(0x1000)) },
		map1(func(x int32) int32 { return x >> hide(0x1000) }))
	// Unsigned ShiftAllRight
	testUint32x4Unary(t,
		func(x archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRight(2) },
		map1(func(x uint32) uint32 { return x >> 2 }))
	testUint32x4Unary(t,
		func(x archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRight(hide(2)) },
		map1(func(x uint32) uint32 { return x >> hide(2) }))
	testUint32x4Unary(t,
		func(x archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRight(0x1000) },
		map1(func(x uint32) uint32 { return x >> hide(0x1000) }))
	testUint32x4Unary(t,
		func(x archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRight(hide(0x1000)) },
		map1(func(x uint32) uint32 { return x >> hide(0x1000) }))
}
// concatInt32s returns the 64-bit value whose high 32 bits are x and
// whose low 32 bits are y. The uint32 conversion keeps y's sign bit
// from smearing into the high half.
func concatInt32s(x, y int32) int64 {
	hi := int64(x) << 32
	lo := int64(uint32(y))
	return hi | lo
}
// concatUint32s returns the 64-bit value whose high 32 bits are x and
// whose low 32 bits are y.
func concatUint32s(x, y uint32) uint64 {
	hi := uint64(x) << 32
	return hi | uint64(y)
}
// TestShiftAllConcat exercises the concatenating ("funnel") shifts in
// which every element is shifted by the same scalar count and the bits
// shifted in come from the corresponding element of a second vector.
// The count is a uint8 and wraps modulo the element width (32), unlike
// the plain ShiftAll operations.
func TestShiftAllConcat(t *testing.T) {
	if !archsimd.X86.AVX512VBMI2() {
		t.Skip("requires AVX512-VBMI2")
	}
	// Note that unlike their non-Concat counterparts, these wrap the shift count.
	hide := hideConst[uint8]
	// ShiftAllLeftConcat
	// Reference model: x shifted left by shift%32 with the vacated low
	// bits filled from the high bits of y — a 32-bit window of the
	// 64-bit concatenation x:y.
	salc := func(shift uint8) func(x, y int32) int32 {
		return func(x, y int32) int32 {
			return int32(concatInt32s(x, y) >> (32 - shift%32))
		}
	}
	testInt32x4Binary(t,
		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcat(2, y) },
		map2(salc(2)))
	testInt32x4Binary(t,
		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcat(hide(2), y) },
		map2(salc(hide(2))))
	// TODO: If we expand the shift from uint8, add larger cases (e.g., 0x1000).
	// The uint8 conversion is only here so the build will fail if we change it.
	// 128 wraps to 0, so these cases expect x unchanged.
	testInt32x4Binary(t,
		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcat(uint8(128), y) },
		map2(salc(128)))
	testInt32x4Binary(t,
		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcat(hide(128), y) },
		map2(salc(hide(128))))
	// Signed ShiftAllRightConcat
	// Reference model: x shifted right by shift%32 with the vacated high
	// bits filled from the low bits of y (the concatenation y:x shifted
	// right, truncated to 32 bits).
	sarc := func(shift uint8) func(x, y int32) int32 {
		return func(x, y int32) int32 {
			return int32(concatInt32s(y, x) >> (shift % 32))
		}
	}
	testInt32x4Binary(t,
		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcat(2, y) },
		map2(sarc(2)))
	testInt32x4Binary(t,
		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcat(hide(2), y) },
		map2(sarc(hide(2))))
	testInt32x4Binary(t,
		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcat(128, y) },
		map2(sarc(128)))
	testInt32x4Binary(t,
		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcat(hide(128), y) },
		map2(sarc(hide(128))))
	// Unsigned ShiftAllRightConcat
	// Same window as the signed case; since shift%32 <= 31, the bits
	// taken lie entirely inside the 64-bit concatenation, so the
	// signedness of the wide shift never shows in the result.
	usarc := func(shift uint8) func(x, y uint32) uint32 {
		return func(x, y uint32) uint32 {
			return uint32(concatUint32s(y, x) >> (shift % 32))
		}
	}
	testUint32x4Binary(t,
		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcat(2, y) },
		map2(usarc(2)))
	testUint32x4Binary(t,
		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcat(hide(2), y) },
		map2(usarc(hide(2))))
	testUint32x4Binary(t,
		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcat(128, y) },
		map2(usarc(128)))
	testUint32x4Binary(t,
		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcat(hide(128), y) },
		map2(usarc(hide(128))))
}
// TestShiftConcat exercises the variable-count concatenating ("funnel")
// shifts, where each lane's shift count comes from a third vector and
// the bits shifted in come from a second vector.
func TestShiftConcat(t *testing.T) {
	if !archsimd.X86.AVX512VBMI2() {
		t.Skip("requires AVX512-VBMI2")
	}
	// Note that unlike their non-Concat counterparts, these wrap the shift count.
	testInt32x4Ternary(t,
		func(a, b, c archsimd.Int32x4) archsimd.Int32x4 { return a.ShiftLeftConcat(b, c.AsUint32x4()) },
		map3(func(a, b, c int32) int32 {
			wide := concatInt32s(a, b)
			return int32(wide >> (32 - uint32(c)%32))
		}))
	testInt32x4Ternary(t,
		func(a, b, c archsimd.Int32x4) archsimd.Int32x4 { return a.ShiftRightConcat(b, c.AsUint32x4()) },
		map3(func(a, b, c int32) int32 {
			wide := concatInt32s(b, a)
			return int32(wide >> (uint32(c) % 32))
		}))
	testUint32x4Ternary(t,
		func(a, b, c archsimd.Uint32x4) archsimd.Uint32x4 { return a.ShiftRightConcat(b, c) },
		map3(func(a, b, c uint32) uint32 {
			wide := concatUint32s(b, a)
			return uint32(wide >> (c % 32))
		}))
}
// TestConcatShiftBytesRight exercises the byte-granularity funnel
// shifts with both constant and hidden (runtime) counts. The Grouped
// variants on the wider vectors apply the operation independently to
// each 16-byte group (via the grouped2 reference adapter).
func TestConcatShiftBytesRight(t *testing.T) {
	hide := hideConst[uint8]
	// csbr builds the reference for one 16-byte group: result byte i is
	// byte i+count of the concatenation (second operand first), or zero
	// once the index runs past both operands.
	csbr := func(count uint8) func(a, b []uint8) []uint8 {
		return func(a, b []uint8) []uint8 {
			out := make([]uint8, len(a))
			for i := range out {
				switch src := i + int(count); {
				case src < 16:
					out[i] = b[src]
				case src < 32:
					out[i] = a[src-16]
				}
			}
			return out
		}
	}
	t.Run("Uint8x16", func(t *testing.T) {
		if !archsimd.X86.AVX() {
			t.Skip("requires AVX")
		}
		// TODO: If we expand the shift from uint8, add larger cases (e.g., 0x1000)
		for _, count := range []uint8{0, 2, 16, 20, 32, 128} {
			t.Log("shift", count)
			testUint8x16Binary(t,
				func(x, y archsimd.Uint8x16) archsimd.Uint8x16 { return x.ConcatShiftBytesRight(count, y) },
				csbr(count))
			testUint8x16Binary(t,
				func(x, y archsimd.Uint8x16) archsimd.Uint8x16 { return x.ConcatShiftBytesRight(hide(count), y) },
				csbr(hide(count)))
		}
	})
	t.Run("Uint8x32", func(t *testing.T) {
		if !archsimd.X86.AVX2() {
			t.Skip("requires AVX2")
		}
		for _, count := range []uint8{0, 2, 16, 20, 32, 128} {
			t.Log("shift", count)
			testUint8x32Binary(t,
				func(x, y archsimd.Uint8x32) archsimd.Uint8x32 { return x.ConcatShiftBytesRightGrouped(count, y) },
				grouped2(csbr(count)))
			testUint8x32Binary(t,
				func(x, y archsimd.Uint8x32) archsimd.Uint8x32 { return x.ConcatShiftBytesRightGrouped(hide(count), y) },
				grouped2(csbr(hide(count))))
		}
	})
	t.Run("Uint8x64", func(t *testing.T) {
		if !archsimd.X86.AVX512() {
			t.Skip("requires AVX512")
		}
		for _, count := range []uint8{0, 2, 16, 20, 32, 128} {
			t.Log("shift", count)
			testUint8x64Binary(t,
				func(x, y archsimd.Uint8x64) archsimd.Uint8x64 { return x.ConcatShiftBytesRightGrouped(count, y) },
				grouped2(csbr(count)))
			testUint8x64Binary(t,
				func(x, y archsimd.Uint8x64) archsimd.Uint8x64 { return x.ConcatShiftBytesRightGrouped(hide(count), y) },
				grouped2(csbr(hide(count))))
		}
	})
}