cmd/compile: use jump tables for large type switches

For large interface -> concrete type switches, we can use a jump
table on some bits of the type hash instead of a binary search on
the type hash.

name                        old time/op  new time/op  delta
SwitchTypePredictable-24    1.99ns ± 2%  1.78ns ± 5%  -10.87%  (p=0.000 n=10+10)
SwitchTypeUnpredictable-24  11.0ns ± 1%   9.1ns ± 2%  -17.55%  (p=0.000 n=7+9)

Change-Id: Ida4768e5d62c3ce1c2701288b72664aaa9e64259
Reviewed-on: https://go-review.googlesource.com/c/go/+/521497
Reviewed-by: Keith Randall <khr@google.com>
Auto-Submit: Keith Randall <khr@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Run-TryBot: Keith Randall <khr@golang.org>
diff --git a/src/cmd/compile/internal/test/switch_test.go b/src/cmd/compile/internal/test/switch_test.go
index 30dee62..ddb39bf 100644
--- a/src/cmd/compile/internal/test/switch_test.go
+++ b/src/cmd/compile/internal/test/switch_test.go
@@ -120,6 +120,48 @@
 	sink = n
 }
 
+func BenchmarkSwitchTypePredictable(b *testing.B) {
+	benchmarkSwitchType(b, true)
+}
+func BenchmarkSwitchTypeUnpredictable(b *testing.B) {
+	benchmarkSwitchType(b, false)
+}
+func benchmarkSwitchType(b *testing.B, predictable bool) {
+	a := []any{
+		int8(1),
+		int16(2),
+		int32(3),
+		int64(4),
+		uint8(5),
+		uint16(6),
+		uint32(7),
+		uint64(8),
+	}
+	n := 0
+	rng := newRNG()
+	for i := 0; i < b.N; i++ {
+		rng = rng.next(predictable)
+		switch a[rng.value()&7].(type) {
+		case int8:
+			n += 1
+		case int16:
+			n += 2
+		case int32:
+			n += 3
+		case int64:
+			n += 4
+		case uint8:
+			n += 5
+		case uint16:
+			n += 6
+		case uint32:
+			n += 7
+		case uint64:
+			n += 8
+		}
+	}
+}
+
 // A simple random number generator used to make switches conditionally predictable.
 type rng uint64
 
diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go
index ebd3128..67ccb2e 100644
--- a/src/cmd/compile/internal/walk/switch.go
+++ b/src/cmd/compile/internal/walk/switch.go
@@ -7,6 +7,7 @@
 import (
 	"go/constant"
 	"go/token"
+	"math/bits"
 	"sort"
 
 	"cmd/compile/internal/base"
@@ -617,7 +618,9 @@
 	}
 	cc = merged
 
-	// TODO: figure out if we could use a jump table using some low bits of the type hashes.
+	if s.tryJumpTable(cc, &s.done) {
+		return
+	}
 	binarySearch(len(cc), &s.done,
 		func(i int) ir.Node {
 			return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(base.Pos, int64(cc[i-1].hash)))
@@ -632,6 +635,83 @@
 	)
 }
 
+// Try to implement the clauses with a jump table. Returns true if successful.
+func (s *typeSwitch) tryJumpTable(cc []typeClause, out *ir.Nodes) bool {
+	const minCases = 5 // have at least minCases cases in the switch
+	if base.Flag.N != 0 || !ssagen.Arch.LinkArch.CanJumpTable || base.Ctxt.Retpoline {
+		return false
+	}
+	if len(cc) < minCases {
+		return false // not enough cases for it to be worth it
+	}
+	hashes := make([]uint32, len(cc))
+	// b = # of bits to use. Start with the minimum number of
+	// bits possible, but try a few larger sizes if needed.
+	b0 := bits.Len(uint(len(cc) - 1))
+	for b := b0; b < b0+3; b++ {
+	pickI:
+		for i := 0; i <= 32-b; i++ { // starting bit position
+			// Compute the hash we'd get from all the cases,
+			// selecting b bits starting at bit i.
+			hashes = hashes[:0]
+			for _, c := range cc {
+				h := c.hash >> i & (1<<b - 1)
+				hashes = append(hashes, h)
+			}
+			// Order by increasing hash.
+			sort.Slice(hashes, func(j, k int) bool {
+				return hashes[j] < hashes[k]
+			})
+			for j := 1; j < len(hashes); j++ {
+				if hashes[j] == hashes[j-1] {
+					// There is a duplicate hash; try a different b/i pair.
+					continue pickI
+				}
+			}
+
+			// All hashes are distinct. Use these values of b and i.
+			h := s.hashname
+			if i != 0 {
+				h = ir.NewBinaryExpr(base.Pos, ir.ORSH, h, ir.NewInt(base.Pos, int64(i)))
+			}
+			h = ir.NewBinaryExpr(base.Pos, ir.OAND, h, ir.NewInt(base.Pos, int64(1<<b-1)))
+			h = typecheck.Expr(h)
+
+			// Build jump table.
+			jt := ir.NewJumpTableStmt(base.Pos, h)
+			jt.Cases = make([]constant.Value, 1<<b)
+			jt.Targets = make([]*types.Sym, 1<<b)
+			out.Append(jt)
+
+			// Start with all hashes going to the didn't-match target.
+			noMatch := typecheck.AutoLabel(".s")
+			for j := 0; j < 1<<b; j++ {
+				jt.Cases[j] = constant.MakeInt64(int64(j))
+				jt.Targets[j] = noMatch
+			}
+			// This statement is not reachable, but it will make it obvious that we don't
+			// fall through to the first case.
+			out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
+
+			// Emit each of the actual cases.
+			for _, c := range cc {
+				h := c.hash >> i & (1<<b - 1)
+				label := typecheck.AutoLabel(".s")
+				jt.Targets[h] = label
+				out.Append(ir.NewLabelStmt(base.Pos, label))
+				out.Append(c.body...)
+				// We reach here if the hash matches but the type equality test fails.
+				out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
+			}
+			// Emit point to go to if type doesn't match any case.
+			out.Append(ir.NewLabelStmt(base.Pos, noMatch))
+			return true
+		}
+	}
+	// Couldn't find a perfect hash. Fall back to binary search.
+	return false
+}
+
 // binarySearch constructs a binary search tree for handling n cases,
 // and appends it to out. It's used for efficiently implementing
 // switch statements.
diff --git a/test/codegen/switch.go b/test/codegen/switch.go
index 603e0be..556d02a 100644
--- a/test/codegen/switch.go
+++ b/test/codegen/switch.go
@@ -99,3 +99,22 @@
 		return ""
 	}
 }
+
+// use jump tables for type switches to concrete types.
+func typeSwitch(x any) int {
+	// amd64:`JMP\s\(.*\)\(.*\)$`
+	// arm64:`MOVD\s\(R.*\)\(R.*<<3\)`,`JMP\s\(R.*\)$`
+	switch x.(type) {
+	case int:
+		return 0
+	case int8:
+		return 1
+	case int16:
+		return 2
+	case int32:
+		return 3
+	case int64:
+		return 4
+	}
+	return 7
+}