cmd/compile: use jump tables for large type switches
For large interface -> concrete type switches, we can use a jump
table on some bits of the type hash instead of a binary search on
the type hash.
name old time/op new time/op delta
SwitchTypePredictable-24 1.99ns ± 2% 1.78ns ± 5% -10.87% (p=0.000 n=10+10)
SwitchTypeUnpredictable-24 11.0ns ± 1% 9.1ns ± 2% -17.55% (p=0.000 n=7+9)
Change-Id: Ida4768e5d62c3ce1c2701288b72664aaa9e64259
Reviewed-on: https://go-review.googlesource.com/c/go/+/521497
Reviewed-by: Keith Randall <khr@google.com>
Auto-Submit: Keith Randall <khr@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Run-TryBot: Keith Randall <khr@golang.org>
diff --git a/src/cmd/compile/internal/test/switch_test.go b/src/cmd/compile/internal/test/switch_test.go
index 30dee62..ddb39bf 100644
--- a/src/cmd/compile/internal/test/switch_test.go
+++ b/src/cmd/compile/internal/test/switch_test.go
@@ -120,6 +120,48 @@
sink = n
}
+func BenchmarkSwitchTypePredictable(b *testing.B) {
+ benchmarkSwitchType(b, true)
+}
+func BenchmarkSwitchTypeUnpredictable(b *testing.B) {
+ benchmarkSwitchType(b, false)
+}
+func benchmarkSwitchType(b *testing.B, predictable bool) {
+ a := []any{
+ int8(1),
+ int16(2),
+ int32(3),
+ int64(4),
+ uint8(5),
+ uint16(6),
+ uint32(7),
+ uint64(8),
+ }
+ n := 0
+ rng := newRNG()
+ for i := 0; i < b.N; i++ {
+ rng = rng.next(predictable)
+ switch a[rng.value()&7].(type) {
+ case int8:
+ n += 1
+ case int16:
+ n += 2
+ case int32:
+ n += 3
+ case int64:
+ n += 4
+ case uint8:
+ n += 5
+ case uint16:
+ n += 6
+ case uint32:
+ n += 7
+ case uint64:
+ n += 8
+ }
+ }
+}
+
// A simple random number generator used to make switches conditionally predictable.
type rng uint64
diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go
index ebd3128..67ccb2e 100644
--- a/src/cmd/compile/internal/walk/switch.go
+++ b/src/cmd/compile/internal/walk/switch.go
@@ -7,6 +7,7 @@
import (
"go/constant"
"go/token"
+ "math/bits"
"sort"
"cmd/compile/internal/base"
@@ -617,7 +618,9 @@
}
cc = merged
- // TODO: figure out if we could use a jump table using some low bits of the type hashes.
+ if s.tryJumpTable(cc, &s.done) {
+ return
+ }
binarySearch(len(cc), &s.done,
func(i int) ir.Node {
return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(base.Pos, int64(cc[i-1].hash)))
@@ -632,6 +635,83 @@
)
}
+// Try to implement the clauses with a jump table. Returns true if successful.
+func (s *typeSwitch) tryJumpTable(cc []typeClause, out *ir.Nodes) bool {
+ const minCases = 5 // have at least minCases cases in the switch
+ if base.Flag.N != 0 || !ssagen.Arch.LinkArch.CanJumpTable || base.Ctxt.Retpoline {
+ return false
+ }
+ if len(cc) < minCases {
+ return false // not enough cases for it to be worth it
+ }
+ hashes := make([]uint32, len(cc))
+ // b = # of bits to use. Start with the minimum number of
+ // bits possible, but try a few larger sizes if needed.
+ b0 := bits.Len(uint(len(cc) - 1))
+ for b := b0; b < b0+3; b++ {
+ pickI:
+ for i := 0; i <= 32-b; i++ { // starting bit position
+ // Compute the hash we'd get from all the cases,
+ // selecting b bits starting at bit i.
+ hashes = hashes[:0]
+ for _, c := range cc {
+ h := c.hash >> i & (1<<b - 1)
+ hashes = append(hashes, h)
+ }
+ // Order by increasing hash.
+ sort.Slice(hashes, func(j, k int) bool {
+ return hashes[j] < hashes[k]
+ })
+ for j := 1; j < len(hashes); j++ {
+ if hashes[j] == hashes[j-1] {
+ // There is a duplicate hash; try a different b/i pair.
+ continue pickI
+ }
+ }
+
+ // All hashes are distinct. Use these values of b and i.
+ h := s.hashname
+ if i != 0 {
+ h = ir.NewBinaryExpr(base.Pos, ir.ORSH, h, ir.NewInt(base.Pos, int64(i)))
+ }
+ h = ir.NewBinaryExpr(base.Pos, ir.OAND, h, ir.NewInt(base.Pos, int64(1<<b-1)))
+ h = typecheck.Expr(h)
+
+ // Build jump table.
+ jt := ir.NewJumpTableStmt(base.Pos, h)
+ jt.Cases = make([]constant.Value, 1<<b)
+ jt.Targets = make([]*types.Sym, 1<<b)
+ out.Append(jt)
+
+ // Start with all hashes going to the didn't-match target.
+ noMatch := typecheck.AutoLabel(".s")
+ for j := 0; j < 1<<b; j++ {
+ jt.Cases[j] = constant.MakeInt64(int64(j))
+ jt.Targets[j] = noMatch
+ }
+ // This statement is not reachable, but it will make it obvious that we don't
+ // fall through to the first case.
+ out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
+
+ // Emit each of the actual cases.
+ for _, c := range cc {
+ h := c.hash >> i & (1<<b - 1)
+ label := typecheck.AutoLabel(".s")
+ jt.Targets[h] = label
+ out.Append(ir.NewLabelStmt(base.Pos, label))
+ out.Append(c.body...)
+ // We reach here if the hash matches but the type equality test fails.
+ out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
+ }
+ // Emit point to go to if type doesn't match any case.
+ out.Append(ir.NewLabelStmt(base.Pos, noMatch))
+ return true
+ }
+ }
+ // Couldn't find a perfect hash. Fall back to binary search.
+ return false
+}
+
// binarySearch constructs a binary search tree for handling n cases,
// and appends it to out. It's used for efficiently implementing
// switch statements.
diff --git a/test/codegen/switch.go b/test/codegen/switch.go
index 603e0be..556d02a 100644
--- a/test/codegen/switch.go
+++ b/test/codegen/switch.go
@@ -99,3 +99,22 @@
return ""
}
}
+
+// use jump tables for type switches to concrete types.
+func typeSwitch(x any) int {
+ // amd64:`JMP\s\(.*\)\(.*\)$`
+ // arm64:`MOVD\s\(R.*\)\(R.*<<3\)`,`JMP\s\(R.*\)$`
+ switch x.(type) {
+ case int:
+ return 0
+ case int8:
+ return 1
+ case int16:
+ return 2
+ case int32:
+ return 3
+ case int64:
+ return 4
+ }
+ return 7
+}