cmd/compile: optimize generated struct/array equality code

Use a standard "not-equal" label that we can jump to when we
detect that the arguments are not equal. This prevents the
recombination that was noticed in #39428.

Fixes #39428

Change-Id: Ib7c6b3539f4f6046043fd7148f940fcdcab70427
Reviewed-on: https://go-review.googlesource.com/c/go/+/255317
Trust: Keith Randall <khr@golang.org>
Run-TryBot: Keith Randall <khr@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Josh Bleecher Snyder <josharian@gmail.com>
diff --git a/src/cmd/compile/internal/gc/alg.go b/src/cmd/compile/internal/gc/alg.go
index 2ab69c2..2f7fa27 100644
--- a/src/cmd/compile/internal/gc/alg.go
+++ b/src/cmd/compile/internal/gc/alg.go
@@ -529,6 +529,10 @@
 	fn := dclfunc(sym, tfn)
 	np := asNode(tfn.Type.Params().Field(0).Nname)
 	nq := asNode(tfn.Type.Params().Field(1).Nname)
+	nr := asNode(tfn.Type.Results().Field(0).Nname)
+
+	// Label to jump to if an equality test fails.
+	neq := autolabel(".neq")
 
 	// We reach here only for types that have equality but
 	// cannot be handled by the standard algorithms,
@@ -555,13 +559,13 @@
 		// for i := 0; i < nelem; i++ {
 		//   if eq(p[i], q[i]) {
 		//   } else {
-		//     return
+		//     goto neq
 		//   }
 		// }
 		//
 		// TODO(josharian): consider doing some loop unrolling
 		// for larger nelem as well, processing a few elements at a time in a loop.
-		checkAll := func(unroll int64, eq func(pi, qi *Node) *Node) {
+		checkAll := func(unroll int64, last bool, eq func(pi, qi *Node) *Node) {
 			// checkIdx generates a node to check for equality at index i.
 			checkIdx := func(i *Node) *Node {
 				// pi := p[i]
@@ -576,37 +580,38 @@
 			}
 
 			if nelem <= unroll {
-				// Generate a series of checks.
-				var cond *Node
-				for i := int64(0); i < nelem; i++ {
-					c := nodintconst(i)
-					check := checkIdx(c)
-					if cond == nil {
-						cond = check
-						continue
-					}
-					cond = nod(OANDAND, cond, check)
+				if last {
+					// Do last comparison in a different manner.
+					nelem--
 				}
-				nif := nod(OIF, cond, nil)
-				nif.Rlist.Append(nod(ORETURN, nil, nil))
-				fn.Nbody.Append(nif)
-				return
+				// Generate a series of checks.
+				for i := int64(0); i < nelem; i++ {
+					// if check {} else { goto neq }
+					nif := nod(OIF, checkIdx(nodintconst(i)), nil)
+					nif.Rlist.Append(nodSym(OGOTO, nil, neq))
+					fn.Nbody.Append(nif)
+				}
+				if last {
+					fn.Nbody.Append(nod(OAS, nr, checkIdx(nodintconst(nelem))))
+				}
+			} else {
+				// Generate a for loop.
+				// for i := 0; i < nelem; i++
+				i := temp(types.Types[TINT])
+				init := nod(OAS, i, nodintconst(0))
+				cond := nod(OLT, i, nodintconst(nelem))
+				post := nod(OAS, i, nod(OADD, i, nodintconst(1)))
+				loop := nod(OFOR, cond, post)
+				loop.Ninit.Append(init)
+				// if eq(pi, qi) {} else { goto neq }
+				nif := nod(OIF, checkIdx(i), nil)
+				nif.Rlist.Append(nodSym(OGOTO, nil, neq))
+				loop.Nbody.Append(nif)
+				fn.Nbody.Append(loop)
+				if last {
+					fn.Nbody.Append(nod(OAS, nr, nodbool(true)))
+				}
 			}
-
-			// Generate a for loop.
-			// for i := 0; i < nelem; i++
-			i := temp(types.Types[TINT])
-			init := nod(OAS, i, nodintconst(0))
-			cond := nod(OLT, i, nodintconst(nelem))
-			post := nod(OAS, i, nod(OADD, i, nodintconst(1)))
-			loop := nod(OFOR, cond, post)
-			loop.Ninit.Append(init)
-			// if eq(pi, qi) {} else { return }
-			check := checkIdx(i)
-			nif := nod(OIF, check, nil)
-			nif.Rlist.Append(nod(ORETURN, nil, nil))
-			loop.Nbody.Append(nif)
-			fn.Nbody.Append(loop)
 		}
 
 		switch t.Elem().Etype {
@@ -614,32 +619,28 @@
 			// Do two loops. First, check that all the lengths match (cheap).
 			// Second, check that all the contents match (expensive).
 			// TODO: when the array size is small, unroll the length match checks.
-			checkAll(3, func(pi, qi *Node) *Node {
+			checkAll(3, false, func(pi, qi *Node) *Node {
 				// Compare lengths.
 				eqlen, _ := eqstring(pi, qi)
 				return eqlen
 			})
-			checkAll(1, func(pi, qi *Node) *Node {
+			checkAll(1, true, func(pi, qi *Node) *Node {
 				// Compare contents.
 				_, eqmem := eqstring(pi, qi)
 				return eqmem
 			})
 		case TFLOAT32, TFLOAT64:
-			checkAll(2, func(pi, qi *Node) *Node {
+			checkAll(2, true, func(pi, qi *Node) *Node {
 				// p[i] == q[i]
 				return nod(OEQ, pi, qi)
 			})
 		// TODO: pick apart structs, do them piecemeal too
 		default:
-			checkAll(1, func(pi, qi *Node) *Node {
+			checkAll(1, true, func(pi, qi *Node) *Node {
 				// p[i] == q[i]
 				return nod(OEQ, pi, qi)
 			})
 		}
-		// return true
-		ret := nod(ORETURN, nil, nil)
-		ret.List.Append(nodbool(true))
-		fn.Nbody.Append(ret)
 
 	case TSTRUCT:
 		// Build a list of conditions to satisfy.
@@ -717,21 +718,41 @@
 			flatConds = append(flatConds, c...)
 		}
 
-		var cond *Node
 		if len(flatConds) == 0 {
-			cond = nodbool(true)
+			fn.Nbody.Append(nod(OAS, nr, nodbool(true)))
 		} else {
-			cond = flatConds[0]
-			for _, c := range flatConds[1:] {
-				cond = nod(OANDAND, cond, c)
+			for _, c := range flatConds[:len(flatConds)-1] {
+				// if cond {} else { goto neq }
+				n := nod(OIF, c, nil)
+				n.Rlist.Append(nodSym(OGOTO, nil, neq))
+				fn.Nbody.Append(n)
 			}
+			fn.Nbody.Append(nod(OAS, nr, flatConds[len(flatConds)-1]))
 		}
-
-		ret := nod(ORETURN, nil, nil)
-		ret.List.Append(cond)
-		fn.Nbody.Append(ret)
 	}
 
+	// ret:
+	//   return
+	ret := autolabel(".ret")
+	fn.Nbody.Append(nodSym(OLABEL, nil, ret))
+	fn.Nbody.Append(nod(ORETURN, nil, nil))
+
+	// neq:
+	//   r = false
+	//   return (or goto ret)
+	fn.Nbody.Append(nodSym(OLABEL, nil, neq))
+	fn.Nbody.Append(nod(OAS, nr, nodbool(false)))
+	if EqCanPanic(t) || hasCall(fn) {
+		// Epilogue is large, so share it with the equal case.
+		fn.Nbody.Append(nodSym(OGOTO, nil, ret))
+	} else {
+		// Epilogue is small, so don't bother sharing.
+		fn.Nbody.Append(nod(ORETURN, nil, nil))
+	}
+	// TODO(khr): the epilogue size detection condition above isn't perfect.
+	// We should really do a generic CL that shares epilogues across
+	// the board. See #24936.
+
 	if Debug.r != 0 {
 		dumplist("geneq body", fn.Nbody)
 	}
@@ -762,6 +783,39 @@
 	return closure
 }
 
+func hasCall(n *Node) bool {
+	if n.Op == OCALL || n.Op == OCALLFUNC {
+		return true
+	}
+	if n.Left != nil && hasCall(n.Left) {
+		return true
+	}
+	if n.Right != nil && hasCall(n.Right) {
+		return true
+	}
+	for _, x := range n.Ninit.Slice() {
+		if hasCall(x) {
+			return true
+		}
+	}
+	for _, x := range n.Nbody.Slice() {
+		if hasCall(x) {
+			return true
+		}
+	}
+	for _, x := range n.List.Slice() {
+		if hasCall(x) {
+			return true
+		}
+	}
+	for _, x := range n.Rlist.Slice() {
+		if hasCall(x) {
+			return true
+		}
+	}
+	return false
+}
+
 // eqfield returns the node
 // 	p.field == q.field
 func eqfield(p *Node, q *Node, field *types.Sym) *Node {