sha3: have ShakeHash extend hash.Hash

Package sha3 recommends the SHAKE functions for new uses, but this is
currently somewhat inconvenient because ShakeHash does not implement
hash.Hash. This is understandable, as SHAKE supports arbitrary-length
outputs whereas hash.Hash only supports fixed-length outputs. But
there's a natural fixed-length output to provide: the minimum output
that still provides SHAKE's full-strength generic security.

While here, tweak Sum so that its temporary buffer can be stack
allocated.

Also, tweak the panic message in Write so that the error text is more
readily understandable to Go programmers without needing to be
familiar with crypto jargon, and add a similar check in Sum.

Change-Id: Icf037d3990a71de5630f8825606614443f8c5245
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/526937
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Adam Langley <agl@google.com>
Auto-Submit: Matthew Dempsky <mdempsky@google.com>
diff --git a/sha3/sha3.go b/sha3/sha3.go
index fa182be..4884d17 100644
--- a/sha3/sha3.go
+++ b/sha3/sha3.go
@@ -121,11 +121,11 @@
 	copyOut(d, d.buf)
 }
 
-// Write absorbs more data into the hash's state. It produces an error
-// if more data is written to the ShakeHash after writing
+// Write absorbs more data into the hash's state. It panics if any
+// output has already been read.
 func (d *state) Write(p []byte) (written int, err error) {
 	if d.state != spongeAbsorbing {
-		panic("sha3: write to sponge after read")
+		panic("sha3: Write after Read")
 	}
 	if d.buf == nil {
 		d.buf = d.storage.asBytes()[:0]
@@ -182,12 +182,16 @@
 }
 
 // Sum applies padding to the hash state and then squeezes out the desired
-// number of output bytes.
+// number of output bytes. It panics if any output has already been read.
 func (d *state) Sum(in []byte) []byte {
+	if d.state != spongeAbsorbing {
+		panic("sha3: Sum after Read")
+	}
+
 	// Make a copy of the original hash so that caller can keep writing
 	// and summing.
 	dup := d.clone()
-	hash := make([]byte, dup.outputLen)
+	hash := make([]byte, dup.outputLen, 64) // explicit cap to allow stack allocation
 	dup.Read(hash)
 	return append(in, hash...)
 }
diff --git a/sha3/sha3_s390x.go b/sha3/sha3_s390x.go
index 63a3edb..ec26f14 100644
--- a/sha3/sha3_s390x.go
+++ b/sha3/sha3_s390x.go
@@ -49,7 +49,7 @@
 	buf       []byte          // care must be taken to ensure cap(buf) is a multiple of rate
 	rate      int             // equivalent to block size
 	storage   [3072]byte      // underlying storage for buf
-	outputLen int             // output length if fixed, 0 if not
+	outputLen int             // output length for full security
 	function  code            // KIMD/KLMD function code
 	state     spongeDirection // whether the sponge is absorbing or squeezing
 }
@@ -72,8 +72,10 @@
 		s.outputLen = 64
 	case shake_128:
 		s.rate = 168
+		s.outputLen = 32
 	case shake_256:
 		s.rate = 136
+		s.outputLen = 64
 	default:
 		panic("sha3: unrecognized function code")
 	}
@@ -108,7 +110,7 @@
 // It never returns an error.
 func (s *asmState) Write(b []byte) (int, error) {
 	if s.state != spongeAbsorbing {
-		panic("sha3: write to sponge after read")
+		panic("sha3: Write after Read")
 	}
 	length := len(b)
 	for len(b) > 0 {
@@ -192,8 +194,8 @@
 // Sum appends the current hash to b and returns the resulting slice.
 // It does not change the underlying hash state.
 func (s *asmState) Sum(b []byte) []byte {
-	if s.outputLen == 0 {
-		panic("sha3: cannot call Sum on SHAKE functions")
+	if s.state != spongeAbsorbing {
+		panic("sha3: Sum after Read")
 	}
 
 	// Copy the state to preserve the original.
diff --git a/sha3/shake.go b/sha3/shake.go
index d7be295..bb69984 100644
--- a/sha3/shake.go
+++ b/sha3/shake.go
@@ -17,26 +17,25 @@
 
 import (
 	"encoding/binary"
+	"hash"
 	"io"
 )
 
-// ShakeHash defines the interface to hash functions that
-// support arbitrary-length output.
+// ShakeHash defines the interface to hash functions that support
+// arbitrary-length output. When used as a plain [hash.Hash], it
+// produces minimum-length outputs that provide full-strength generic
+// security.
 type ShakeHash interface {
-	// Write absorbs more data into the hash's state. It panics if input is
-	// written to it after output has been read from it.
-	io.Writer
+	hash.Hash
 
 	// Read reads more output from the hash; reading affects the hash's
 	// state. (ShakeHash.Read is thus very different from Hash.Sum)
-	// It never returns an error.
+	// It never returns an error, but subsequent calls to Write or Sum
+	// will panic.
 	io.Reader
 
 	// Clone returns a copy of the ShakeHash in its current state.
 	Clone() ShakeHash
-
-	// Reset resets the ShakeHash to its initial state.
-	Reset()
 }
 
 // cSHAKE specific context
@@ -81,8 +80,8 @@
 	return b[i-1:]
 }
 
-func newCShake(N, S []byte, rate int, dsbyte byte) ShakeHash {
-	c := cshakeState{state: &state{rate: rate, dsbyte: dsbyte}}
+func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
+	c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}}
 
 	// leftEncode returns max 9 bytes
 	c.initBlock = make([]byte, 0, 9*2+len(N)+len(S))
@@ -119,7 +118,7 @@
 	if h := newShake128Asm(); h != nil {
 		return h
 	}
-	return &state{rate: rate128, dsbyte: dsbyteShake}
+	return &state{rate: rate128, outputLen: 32, dsbyte: dsbyteShake}
 }
 
 // NewShake256 creates a new SHAKE256 variable-output-length ShakeHash.
@@ -129,7 +128,7 @@
 	if h := newShake256Asm(); h != nil {
 		return h
 	}
-	return &state{rate: rate256, dsbyte: dsbyteShake}
+	return &state{rate: rate256, outputLen: 64, dsbyte: dsbyteShake}
 }
 
 // NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash,
@@ -142,7 +141,7 @@
 	if len(N) == 0 && len(S) == 0 {
 		return NewShake128()
 	}
-	return newCShake(N, S, rate128, dsbyteCShake)
+	return newCShake(N, S, rate128, 32, dsbyteCShake)
 }
 
 // NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash,
@@ -155,7 +154,7 @@
 	if len(N) == 0 && len(S) == 0 {
 		return NewShake256()
 	}
-	return newCShake(N, S, rate256, dsbyteCShake)
+	return newCShake(N, S, rate256, 64, dsbyteCShake)
 }
 
 // ShakeSum128 writes an arbitrary-length digest of data into hash.