cmd/compile: introduce generic ssa intrinsic for fused-multiply-add

In order to make math.FMA a compiler intrinsic for ISAs like ARM64,
PPC64[le], and S390X, a generic 3-argument opcode "Fma" is provided and
rewritten as

    ARM64: (Fma x y z) -> (FMADDD z x y)
    PPC64: (Fma x y z) -> (FMADD x y z)
    S390X: (Fma x y z) -> (FMADD z x y)

Updates #25819.

Change-Id: Ie5bc628311e6feeb28ddf9adaa6e702c8c291efa
Reviewed-on: https://go-review.googlesource.com/c/go/+/131959
Run-TryBot: Akhil Indurti <aindurti@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
diff --git a/src/cmd/compile/internal/gc/ssa.go b/src/cmd/compile/internal/gc/ssa.go
index dd8dacd..0b76ad7 100644
--- a/src/cmd/compile/internal/gc/ssa.go
+++ b/src/cmd/compile/internal/gc/ssa.go
@@ -3321,6 +3321,11 @@
 			return s.newValue2(ssa.OpCopysign, types.Types[TFLOAT64], args[0], args[1])
 		},
 		sys.PPC64, sys.Wasm)
+	addF("math", "Fma",
+		func(s *state, n *Node, args []*ssa.Value) *ssa.Value {
+			return s.newValue3(ssa.OpFma, types.Types[TFLOAT64], args[0], args[1], args[2])
+		},
+		sys.ARM64, sys.PPC64, sys.S390X)
 
 	makeRoundAMD64 := func(op ssa.Op) func(s *state, n *Node, args []*ssa.Value) *ssa.Value {
 		return func(s *state, n *Node, args []*ssa.Value) *ssa.Value {
diff --git a/src/cmd/compile/internal/ssa/gen/ARM64.rules b/src/cmd/compile/internal/ssa/gen/ARM64.rules
index 7edd19e..26ae004 100644
--- a/src/cmd/compile/internal/ssa/gen/ARM64.rules
+++ b/src/cmd/compile/internal/ssa/gen/ARM64.rules
@@ -90,6 +90,7 @@
 (Round x) -> (FRINTAD x)
 (RoundToEven x) -> (FRINTND x)
 (Trunc x) -> (FRINTZD x)
+(Fma x y z) -> (FMADDD z x y)
 
 // lowering rotates
 (RotateLeft8 <t> x (MOVDconst [c])) -> (Or8 (Lsh8x64 <t> x (MOVDconst [c&7])) (Rsh8Ux64 <t> x (MOVDconst [-c&7])))
diff --git a/src/cmd/compile/internal/ssa/gen/PPC64.rules b/src/cmd/compile/internal/ssa/gen/PPC64.rules
index 59cce4e..239414f 100644
--- a/src/cmd/compile/internal/ssa/gen/PPC64.rules
+++ b/src/cmd/compile/internal/ssa/gen/PPC64.rules
@@ -68,6 +68,7 @@
 (Round x) -> (FROUND x)
 (Copysign x y) -> (FCPSGN y x)
 (Abs x) -> (FABS x)
+(Fma x y z) -> (FMADD x y z)
 
 // Lowering constants
 (Const(64|32|16|8)  [val]) -> (MOVDconst [val])
diff --git a/src/cmd/compile/internal/ssa/gen/S390X.rules b/src/cmd/compile/internal/ssa/gen/S390X.rules
index 4e45904..d7cb972 100644
--- a/src/cmd/compile/internal/ssa/gen/S390X.rules
+++ b/src/cmd/compile/internal/ssa/gen/S390X.rules
@@ -139,6 +139,7 @@
 (Trunc       x) -> (FIDBR [5] x)
 (RoundToEven x) -> (FIDBR [4] x)
 (Round       x) -> (FIDBR [1] x)
+(Fma     x y z) -> (FMADD z x y)
 
 // Atomic loads and stores.
 // The SYNC instruction (fast-BCR-serialization) prevents store-load
diff --git a/src/cmd/compile/internal/ssa/gen/genericOps.go b/src/cmd/compile/internal/ssa/gen/genericOps.go
index df0dd8c..7bd7931 100644
--- a/src/cmd/compile/internal/ssa/gen/genericOps.go
+++ b/src/cmd/compile/internal/ssa/gen/genericOps.go
@@ -302,6 +302,20 @@
 	{name: "Abs", argLength: 1},      // absolute value arg0
 	{name: "Copysign", argLength: 2}, // copy sign from arg0 to arg1
 
+	// 3-input opcode.
+	// Fused-multiply-add, float64 only.
+	// When a*b+c is exactly zero (before rounding), then the result is +0 or -0.
+	// The 0's sign is determined according to the standard rules for the
+	// addition (-0 if both a*b and c are -0, +0 otherwise).
+	//
+	// Otherwise, when a*b+c rounds to zero, then the resulting 0's sign is
+	// determined by the sign of the exact result a*b+c.
+	// See section 6.3 in ieee754.
+	//
+	// When the multiply is an infinity times a zero, the result is NaN.
+	// See section 7.2 in ieee754.
+	{name: "Fma", argLength: 3}, // compute (a*b)+c without intermediate rounding
+
 	// Data movement, max argument length for Phi is indefinite so just pick
 	// a really large number
 	{name: "Phi", argLength: -1, zeroWidth: true}, // select an argument based on which predecessor block we came from
diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go
index b7e6517..7f9fb4e 100644
--- a/src/cmd/compile/internal/ssa/opGen.go
+++ b/src/cmd/compile/internal/ssa/opGen.go
@@ -2419,6 +2419,7 @@
 	OpRoundToEven
 	OpAbs
 	OpCopysign
+	OpFma
 	OpPhi
 	OpCopy
 	OpConvert
@@ -30594,6 +30595,11 @@
 		generic: true,
 	},
 	{
+		name:    "Fma",
+		argLen:  3,
+		generic: true,
+	},
+	{
 		name:      "Phi",
 		argLen:    -1,
 		zeroWidth: true,
diff --git a/src/cmd/compile/internal/ssa/rewriteARM64.go b/src/cmd/compile/internal/ssa/rewriteARM64.go
index dfb5554..2aa38f5 100644
--- a/src/cmd/compile/internal/ssa/rewriteARM64.go
+++ b/src/cmd/compile/internal/ssa/rewriteARM64.go
@@ -571,6 +571,8 @@
 		return rewriteValueARM64_OpEqPtr_0(v)
 	case OpFloor:
 		return rewriteValueARM64_OpFloor_0(v)
+	case OpFma:
+		return rewriteValueARM64_OpFma_0(v)
 	case OpGeq16:
 		return rewriteValueARM64_OpGeq16_0(v)
 	case OpGeq16U:
@@ -28565,6 +28567,21 @@
 		return true
 	}
 }
+func rewriteValueARM64_OpFma_0(v *Value) bool {
+	// match: (Fma x y z)
+	// cond:
+	// result: (FMADDD z x y)
+	for {
+		z := v.Args[2]
+		x := v.Args[0]
+		y := v.Args[1]
+		v.reset(OpARM64FMADDD)
+		v.AddArg(z)
+		v.AddArg(x)
+		v.AddArg(y)
+		return true
+	}
+}
 func rewriteValueARM64_OpGeq16_0(v *Value) bool {
 	b := v.Block
 	typ := &b.Func.Config.Types
diff --git a/src/cmd/compile/internal/ssa/rewritePPC64.go b/src/cmd/compile/internal/ssa/rewritePPC64.go
index 7f49d98..b09bd85 100644
--- a/src/cmd/compile/internal/ssa/rewritePPC64.go
+++ b/src/cmd/compile/internal/ssa/rewritePPC64.go
@@ -181,6 +181,8 @@
 		return rewriteValuePPC64_OpEqPtr_0(v)
 	case OpFloor:
 		return rewriteValuePPC64_OpFloor_0(v)
+	case OpFma:
+		return rewriteValuePPC64_OpFma_0(v)
 	case OpGeq16:
 		return rewriteValuePPC64_OpGeq16_0(v)
 	case OpGeq16U:
@@ -1988,6 +1990,21 @@
 		return true
 	}
 }
+func rewriteValuePPC64_OpFma_0(v *Value) bool {
+	// match: (Fma x y z)
+	// cond:
+	// result: (FMADD x y z)
+	for {
+		z := v.Args[2]
+		x := v.Args[0]
+		y := v.Args[1]
+		v.reset(OpPPC64FMADD)
+		v.AddArg(x)
+		v.AddArg(y)
+		v.AddArg(z)
+		return true
+	}
+}
 func rewriteValuePPC64_OpGeq16_0(v *Value) bool {
 	b := v.Block
 	typ := &b.Func.Config.Types
diff --git a/src/cmd/compile/internal/ssa/rewriteS390X.go b/src/cmd/compile/internal/ssa/rewriteS390X.go
index 72bbdc0..0c03fa2 100644
--- a/src/cmd/compile/internal/ssa/rewriteS390X.go
+++ b/src/cmd/compile/internal/ssa/rewriteS390X.go
@@ -166,6 +166,8 @@
 		return rewriteValueS390X_OpEqPtr_0(v)
 	case OpFloor:
 		return rewriteValueS390X_OpFloor_0(v)
+	case OpFma:
+		return rewriteValueS390X_OpFma_0(v)
 	case OpGeq16:
 		return rewriteValueS390X_OpGeq16_0(v)
 	case OpGeq16U:
@@ -1918,6 +1920,21 @@
 		return true
 	}
 }
+func rewriteValueS390X_OpFma_0(v *Value) bool {
+	// match: (Fma x y z)
+	// cond:
+	// result: (FMADD z x y)
+	for {
+		z := v.Args[2]
+		x := v.Args[0]
+		y := v.Args[1]
+		v.reset(OpS390XFMADD)
+		v.AddArg(z)
+		v.AddArg(x)
+		v.AddArg(y)
+		return true
+	}
+}
 func rewriteValueS390X_OpGeq16_0(v *Value) bool {
 	b := v.Block
 	typ := &b.Func.Config.Types
diff --git a/test/codegen/math.go b/test/codegen/math.go
index 8aa6525..427f305 100644
--- a/test/codegen/math.go
+++ b/test/codegen/math.go
@@ -107,6 +107,14 @@
 	sink64[3] = math.Copysign(-1, c)
 }
 
+func fma(x, y, z float64) float64 {
+	// arm64:"FMADDD"
+	// s390x:"FMADD"
+	// ppc64:"FMADD"
+	// ppc64le:"FMADD"
+	return math.Fma(x, y, z)
+}
+
 func fromFloat64(f64 float64) uint64 {
 	// amd64:"MOVQ\tX.*, [^X].*"
 	// arm64:"FMOVD\tF.*, R.*"