cmd/compile: optimize unsigned comparisons to 0/1 on arm64

For an unsigned integer, it is useful to convert an order test against
0 or 1 into an equality test against 0. On arm64 this saves a comparison
instruction followed by a conditional branch, since the architecture
provides compare-with-zero-and-branch instructions. For example, for

  if x > 0 { ... } else { ... }

the original version:
  CMP $0, R0
  BLS 9

the optimized version:
  CBZ R0, 8

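A minimal Go sketch of source code that benefits (the function name
nonZero is illustrative, not part of this CL): the unsigned test x > 0
is rewritten to x != 0, so the backend can branch with CBZ/CBNZ instead
of a CMP followed by BLS/BHI.

  func nonZero(x uint64) int {
          if x > 0 { // rewritten to x != 0 by the new rules
                  return 1
          }
          return 0
  }
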
Updates #21439

Change-Id: Id1de6f865f6aa72c5d45b29f7894818857288425
Reviewed-on: https://go-review.googlesource.com/c/go/+/246857
Reviewed-by: Keith Randall <khr@golang.org>
diff --git a/src/cmd/compile/internal/ssa/gen/ARM64.rules b/src/cmd/compile/internal/ssa/gen/ARM64.rules
index 442d769..27959d0 100644
--- a/src/cmd/compile/internal/ssa/gen/ARM64.rules
+++ b/src/cmd/compile/internal/ssa/gen/ARM64.rules
@@ -279,6 +279,16 @@
 (Less32F x y) => (LessThanF (FCMPS x y))
 (Less64F x y) => (LessThanF (FCMPD x y))
 
+// For an unsigned integer x, the following rules are useful when combined with a conditional branch:
+// 0 <  x  =>  x != 0
+// x <= 0  =>  x == 0
+// x <  1  =>  x == 0
+// 1 <= x  =>  x != 0
+(Less(8U|16U|32U|64U) zero:(MOVDconst [0]) x) => (Neq(8|16|32|64) zero x)
+(Leq(8U|16U|32U|64U) x zero:(MOVDconst [0]))  => (Eq(8|16|32|64) x zero)
+(Less(8U|16U|32U|64U) x (MOVDconst [1])) => (Eq(8|16|32|64) x (MOVDconst [0]))
+(Leq(8U|16U|32U|64U) (MOVDconst [1]) x)  => (Neq(8|16|32|64) (MOVDconst [0]) x)
+
 (Less8U x y)  => (LessThanU (CMPW (ZeroExt8to32 x) (ZeroExt8to32 y)))
 (Less16U x y) => (LessThanU (CMPW (ZeroExt16to32 x) (ZeroExt16to32 y)))
 (Less32U x y) => (LessThanU (CMPW x y))
diff --git a/src/cmd/compile/internal/ssa/rewriteARM64.go b/src/cmd/compile/internal/ssa/rewriteARM64.go
index 8e48b33..023d990 100644
--- a/src/cmd/compile/internal/ssa/rewriteARM64.go
+++ b/src/cmd/compile/internal/ssa/rewriteARM64.go
@@ -21976,6 +21976,31 @@
 	v_0 := v.Args[0]
 	b := v.Block
 	typ := &b.Func.Config.Types
+	// match: (Leq16U x zero:(MOVDconst [0]))
+	// result: (Eq16 x zero)
+	for {
+		x := v_0
+		zero := v_1
+		if zero.Op != OpARM64MOVDconst || auxIntToInt64(zero.AuxInt) != 0 {
+			break
+		}
+		v.reset(OpEq16)
+		v.AddArg2(x, zero)
+		return true
+	}
+	// match: (Leq16U (MOVDconst [1]) x)
+	// result: (Neq16 (MOVDconst [0]) x)
+	for {
+		if v_0.Op != OpARM64MOVDconst || auxIntToInt64(v_0.AuxInt) != 1 {
+			break
+		}
+		x := v_1
+		v.reset(OpNeq16)
+		v0 := b.NewValue0(v.Pos, OpARM64MOVDconst, typ.UInt64)
+		v0.AuxInt = int64ToAuxInt(0)
+		v.AddArg2(v0, x)
+		return true
+	}
 	// match: (Leq16U x y)
 	// result: (LessEqualU (CMPW (ZeroExt16to32 x) (ZeroExt16to32 y)))
 	for {
@@ -22028,6 +22053,32 @@
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
 	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Leq32U x zero:(MOVDconst [0]))
+	// result: (Eq32 x zero)
+	for {
+		x := v_0
+		zero := v_1
+		if zero.Op != OpARM64MOVDconst || auxIntToInt64(zero.AuxInt) != 0 {
+			break
+		}
+		v.reset(OpEq32)
+		v.AddArg2(x, zero)
+		return true
+	}
+	// match: (Leq32U (MOVDconst [1]) x)
+	// result: (Neq32 (MOVDconst [0]) x)
+	for {
+		if v_0.Op != OpARM64MOVDconst || auxIntToInt64(v_0.AuxInt) != 1 {
+			break
+		}
+		x := v_1
+		v.reset(OpNeq32)
+		v0 := b.NewValue0(v.Pos, OpARM64MOVDconst, typ.UInt64)
+		v0.AuxInt = int64ToAuxInt(0)
+		v.AddArg2(v0, x)
+		return true
+	}
 	// match: (Leq32U x y)
 	// result: (LessEqualU (CMPW x y))
 	for {
@@ -22076,6 +22127,32 @@
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
 	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Leq64U x zero:(MOVDconst [0]))
+	// result: (Eq64 x zero)
+	for {
+		x := v_0
+		zero := v_1
+		if zero.Op != OpARM64MOVDconst || auxIntToInt64(zero.AuxInt) != 0 {
+			break
+		}
+		v.reset(OpEq64)
+		v.AddArg2(x, zero)
+		return true
+	}
+	// match: (Leq64U (MOVDconst [1]) x)
+	// result: (Neq64 (MOVDconst [0]) x)
+	for {
+		if v_0.Op != OpARM64MOVDconst || auxIntToInt64(v_0.AuxInt) != 1 {
+			break
+		}
+		x := v_1
+		v.reset(OpNeq64)
+		v0 := b.NewValue0(v.Pos, OpARM64MOVDconst, typ.UInt64)
+		v0.AuxInt = int64ToAuxInt(0)
+		v.AddArg2(v0, x)
+		return true
+	}
 	// match: (Leq64U x y)
 	// result: (LessEqualU (CMP x y))
 	for {
@@ -22114,6 +22191,31 @@
 	v_0 := v.Args[0]
 	b := v.Block
 	typ := &b.Func.Config.Types
+	// match: (Leq8U x zero:(MOVDconst [0]))
+	// result: (Eq8 x zero)
+	for {
+		x := v_0
+		zero := v_1
+		if zero.Op != OpARM64MOVDconst || auxIntToInt64(zero.AuxInt) != 0 {
+			break
+		}
+		v.reset(OpEq8)
+		v.AddArg2(x, zero)
+		return true
+	}
+	// match: (Leq8U (MOVDconst [1]) x)
+	// result: (Neq8 (MOVDconst [0]) x)
+	for {
+		if v_0.Op != OpARM64MOVDconst || auxIntToInt64(v_0.AuxInt) != 1 {
+			break
+		}
+		x := v_1
+		v.reset(OpNeq8)
+		v0 := b.NewValue0(v.Pos, OpARM64MOVDconst, typ.UInt64)
+		v0.AuxInt = int64ToAuxInt(0)
+		v.AddArg2(v0, x)
+		return true
+	}
 	// match: (Leq8U x y)
 	// result: (LessEqualU (CMPW (ZeroExt8to32 x) (ZeroExt8to32 y)))
 	for {
@@ -22156,6 +22258,31 @@
 	v_0 := v.Args[0]
 	b := v.Block
 	typ := &b.Func.Config.Types
+	// match: (Less16U zero:(MOVDconst [0]) x)
+	// result: (Neq16 zero x)
+	for {
+		zero := v_0
+		if zero.Op != OpARM64MOVDconst || auxIntToInt64(zero.AuxInt) != 0 {
+			break
+		}
+		x := v_1
+		v.reset(OpNeq16)
+		v.AddArg2(zero, x)
+		return true
+	}
+	// match: (Less16U x (MOVDconst [1]))
+	// result: (Eq16 x (MOVDconst [0]))
+	for {
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst || auxIntToInt64(v_1.AuxInt) != 1 {
+			break
+		}
+		v.reset(OpEq16)
+		v0 := b.NewValue0(v.Pos, OpARM64MOVDconst, typ.UInt64)
+		v0.AuxInt = int64ToAuxInt(0)
+		v.AddArg2(x, v0)
+		return true
+	}
 	// match: (Less16U x y)
 	// result: (LessThanU (CMPW (ZeroExt16to32 x) (ZeroExt16to32 y)))
 	for {
@@ -22208,6 +22335,32 @@
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
 	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Less32U zero:(MOVDconst [0]) x)
+	// result: (Neq32 zero x)
+	for {
+		zero := v_0
+		if zero.Op != OpARM64MOVDconst || auxIntToInt64(zero.AuxInt) != 0 {
+			break
+		}
+		x := v_1
+		v.reset(OpNeq32)
+		v.AddArg2(zero, x)
+		return true
+	}
+	// match: (Less32U x (MOVDconst [1]))
+	// result: (Eq32 x (MOVDconst [0]))
+	for {
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst || auxIntToInt64(v_1.AuxInt) != 1 {
+			break
+		}
+		v.reset(OpEq32)
+		v0 := b.NewValue0(v.Pos, OpARM64MOVDconst, typ.UInt64)
+		v0.AuxInt = int64ToAuxInt(0)
+		v.AddArg2(x, v0)
+		return true
+	}
 	// match: (Less32U x y)
 	// result: (LessThanU (CMPW x y))
 	for {
@@ -22256,6 +22409,32 @@
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
 	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Less64U zero:(MOVDconst [0]) x)
+	// result: (Neq64 zero x)
+	for {
+		zero := v_0
+		if zero.Op != OpARM64MOVDconst || auxIntToInt64(zero.AuxInt) != 0 {
+			break
+		}
+		x := v_1
+		v.reset(OpNeq64)
+		v.AddArg2(zero, x)
+		return true
+	}
+	// match: (Less64U x (MOVDconst [1]))
+	// result: (Eq64 x (MOVDconst [0]))
+	for {
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst || auxIntToInt64(v_1.AuxInt) != 1 {
+			break
+		}
+		v.reset(OpEq64)
+		v0 := b.NewValue0(v.Pos, OpARM64MOVDconst, typ.UInt64)
+		v0.AuxInt = int64ToAuxInt(0)
+		v.AddArg2(x, v0)
+		return true
+	}
 	// match: (Less64U x y)
 	// result: (LessThanU (CMP x y))
 	for {
@@ -22294,6 +22473,31 @@
 	v_0 := v.Args[0]
 	b := v.Block
 	typ := &b.Func.Config.Types
+	// match: (Less8U zero:(MOVDconst [0]) x)
+	// result: (Neq8 zero x)
+	for {
+		zero := v_0
+		if zero.Op != OpARM64MOVDconst || auxIntToInt64(zero.AuxInt) != 0 {
+			break
+		}
+		x := v_1
+		v.reset(OpNeq8)
+		v.AddArg2(zero, x)
+		return true
+	}
+	// match: (Less8U x (MOVDconst [1]))
+	// result: (Eq8 x (MOVDconst [0]))
+	for {
+		x := v_0
+		if v_1.Op != OpARM64MOVDconst || auxIntToInt64(v_1.AuxInt) != 1 {
+			break
+		}
+		v.reset(OpEq8)
+		v0 := b.NewValue0(v.Pos, OpARM64MOVDconst, typ.UInt64)
+		v0.AuxInt = int64ToAuxInt(0)
+		v.AddArg2(x, v0)
+		return true
+	}
 	// match: (Less8U x y)
 	// result: (LessThanU (CMPW (ZeroExt8to32 x) (ZeroExt8to32 y)))
 	for {
diff --git a/test/codegen/comparisons.go b/test/codegen/comparisons.go
index f3c1553..3c2dcb7 100644
--- a/test/codegen/comparisons.go
+++ b/test/codegen/comparisons.go
@@ -424,3 +424,35 @@
 	}
 	return 0
 }
+
+func UintGtZero(a uint8, b uint16, c uint32, d uint64) int {
+	// arm64: `CBZW`, `CBNZW`, `CBNZ`, -`(CMPW|CMP|BLS|BHI)`
+	if a > 0 || b > 0 || c > 0 || d > 0 {
+		return 1
+	}
+	return 0
+}
+
+func UintLeqZero(a uint8, b uint16, c uint32, d uint64) int {
+	// arm64: `CBNZW`, `CBZW`, `CBZ`, -`(CMPW|CMP|BHI|BLS)`
+	if a <= 0 || b <= 0 || c <= 0 || d <= 0 {
+		return 1
+	}
+	return 0
+}
+
+func UintLtOne(a uint8, b uint16, c uint32, d uint64) int {
+	// arm64: `CBNZW`, `CBZW`, `CBZW`, `CBZ`, -`(CMPW|CMP|BHS|BLO)`
+	if a < 1 || b < 1 || c < 1 || d < 1 {
+		return 1
+	}
+	return 0
+}
+
+func UintGeqOne(a uint8, b uint16, c uint32, d uint64) int {
+	// arm64: `CBZW`, `CBNZW`, `CBNZ`, -`(CMPW|CMP|BLO|BHS)`
+	if a >= 1 || b >= 1 || c >= 1 || d >= 1 {
+		return 1
+	}
+	return 0
+}