sha3: fix Sum results for SHAKE functions on s390x

Sum was taking the digest from the state which is correct for SHA-3
functions but not for SHAKE functions.

Updates golang/go#66804

Change-Id: If782464d773262075950e3168128c0d46e4a6530
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/578715
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Than McIntosh <thanm@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Michael Munday <mike.munday@lowrisc.org>
diff --git a/sha3/sha3_s390x.go b/sha3/sha3_s390x.go
index d861bca..b4fbbf8 100644
--- a/sha3/sha3_s390x.go
+++ b/sha3/sha3_s390x.go
@@ -143,6 +143,12 @@
 
 // Read squeezes an arbitrary number of bytes from the sponge.
 func (s *asmState) Read(out []byte) (n int, err error) {
+	// The 'compute last message digest' instruction only stores the digest
+	// at the first operand (dst) for SHAKE functions.
+	if s.function != shake_128 && s.function != shake_256 {
+		panic("sha3: can only call Read for SHAKE functions")
+	}
+
 	n = len(out)
 
 	// need to pad if we were absorbing
@@ -202,8 +208,17 @@
 
 	// Hash the buffer. Note that we don't clear it because we
 	// aren't updating the state.
-	klmd(s.function, &a, nil, s.buf)
-	return append(b, a[:s.outputLen]...)
+	switch s.function {
+	case sha3_224, sha3_256, sha3_384, sha3_512:
+		klmd(s.function, &a, nil, s.buf)
+		return append(b, a[:s.outputLen]...)
+	case shake_128, shake_256:
+		d := make([]byte, s.outputLen, 64)
+		klmd(s.function, &a, d, s.buf)
+		return append(b, d[:s.outputLen]...)
+	default:
+		panic("sha3: unknown function")
+	}
 }
 
 // Reset resets the Hash to its initial state.
diff --git a/sha3/sha3_test.go b/sha3/sha3_test.go
index 83bd619..afcb722 100644
--- a/sha3/sha3_test.go
+++ b/sha3/sha3_test.go
@@ -188,6 +188,34 @@
 	}
 }
 
+// TestShakeSum tests that the output of Sum matches the output of Read.
+func TestShakeSum(t *testing.T) {
+	tests := [...]struct {
+		name        string
+		hash        ShakeHash
+		expectedLen int
+	}{
+		{"SHAKE128", NewShake128(), 32},
+		{"SHAKE256", NewShake256(), 64},
+		{"cSHAKE128", NewCShake128([]byte{'X'}, nil), 32},
+		{"cSHAKE256", NewCShake256([]byte{'X'}, nil), 64},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			s := test.hash.Sum(nil)
+			if len(s) != test.expectedLen {
+				t.Errorf("Unexpected digest length: got %d, want %d", len(s), test.expectedLen)
+			}
+			r := make([]byte, test.expectedLen)
+			test.hash.Read(r)
+			if !bytes.Equal(s, r) {
+				t.Errorf("Mismatch between Sum and Read:\nSum:  %s\nRead: %s", hex.EncodeToString(s), hex.EncodeToString(r))
+			}
+		})
+	}
+}
+
 // TestUnalignedWrite tests that writing data in an arbitrary pattern with
 // small input buffers.
 func TestUnalignedWrite(t *testing.T) {