sha3: make APIs usable with zero allocations

The "buf points into storage" pattern is nice, but causes the whole
state struct to escape, since escape analysis can't track the pointer
once it's assigned to buf.

Change-Id: I31c0e83f946d66bedb5a180e96ab5d5e936eb322
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/544817
Reviewed-by: Cherry Mui <cherryyz@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Mauri de Souza Meneguzzo <mauri870@gmail.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
diff --git a/sha3/allocations_test.go b/sha3/allocations_test.go
new file mode 100644
index 0000000..c925099
--- /dev/null
+++ b/sha3/allocations_test.go
@@ -0,0 +1,53 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !noopt
+
+package sha3_test
+
+import (
+	"testing"
+
+	"golang.org/x/crypto/sha3"
+)
+
+var sink byte
+
+func TestAllocations(t *testing.T) {
+	t.Run("New", func(t *testing.T) {
+		if allocs := testing.AllocsPerRun(10, func() {
+			h := sha3.New256()
+			b := []byte("ABC")
+			h.Write(b)
+			out := make([]byte, 0, 32)
+			out = h.Sum(out)
+			sink ^= out[0]
+		}); allocs > 0 {
+			t.Errorf("expected zero allocations, got %0.1f", allocs)
+		}
+	})
+	t.Run("NewShake", func(t *testing.T) {
+		if allocs := testing.AllocsPerRun(10, func() {
+			h := sha3.NewShake128()
+			b := []byte("ABC")
+			h.Write(b)
+			out := make([]byte, 0, 32)
+			out = h.Sum(out)
+			sink ^= out[0]
+			h.Read(out)
+			sink ^= out[0]
+		}); allocs > 0 {
+			t.Errorf("expected zero allocations, got %0.1f", allocs)
+		}
+	})
+	t.Run("Sum", func(t *testing.T) {
+		if allocs := testing.AllocsPerRun(10, func() {
+			b := []byte("ABC")
+			out := sha3.Sum256(b)
+			sink ^= out[0]
+		}); allocs > 0 {
+			t.Errorf("expected zero allocations, got %0.1f", allocs)
+		}
+	})
+}
diff --git a/sha3/sha3.go b/sha3/sha3.go
index 33bd73b..afedde5 100644
--- a/sha3/sha3.go
+++ b/sha3/sha3.go
@@ -23,7 +23,6 @@
 type state struct {
 	// Generic sponge components.
 	a    [25]uint64 // main state of the hash
-	buf  []byte     // points into storage
 	rate int        // the number of bytes of state to use
 
 	// dsbyte contains the "domain separation" bits and the first bit of
@@ -40,6 +39,7 @@
 	//      Extendable-Output Functions (May 2014)"
 	dsbyte byte
 
+	i, n    int // storage[i:n] is the buffer, i is only used while squeezing
 	storage [maxRate]byte
 
 	// Specific to SHA-3 and SHAKE.
@@ -54,24 +54,18 @@
 func (d *state) Size() int { return d.outputLen }
 
 // Reset clears the internal state by zeroing the sponge state and
-// the byte buffer, and setting Sponge.state to absorbing.
+// the buffer indexes, and setting Sponge.state to absorbing.
 func (d *state) Reset() {
 	// Zero the permutation's state.
 	for i := range d.a {
 		d.a[i] = 0
 	}
 	d.state = spongeAbsorbing
-	d.buf = d.storage[:0]
+	d.i, d.n = 0, 0
 }
 
 func (d *state) clone() *state {
 	ret := *d
-	if ret.state == spongeAbsorbing {
-		ret.buf = ret.storage[:len(ret.buf)]
-	} else {
-		ret.buf = ret.storage[d.rate-cap(d.buf) : d.rate]
-	}
-
 	return &ret
 }
 
@@ -82,43 +76,40 @@
 	case spongeAbsorbing:
 		// If we're absorbing, we need to xor the input into the state
 		// before applying the permutation.
-		xorIn(d, d.buf)
-		d.buf = d.storage[:0]
+		xorIn(d, d.storage[:d.rate])
+		d.n = 0
 		keccakF1600(&d.a)
 	case spongeSqueezing:
 		// If we're squeezing, we need to apply the permutation before
 		// copying more output.
 		keccakF1600(&d.a)
-		d.buf = d.storage[:d.rate]
-		copyOut(d, d.buf)
+		d.i = 0
+		copyOut(d, d.storage[:d.rate])
 	}
 }
 
 // pads appends the domain separation bits in dsbyte, applies
 // the multi-bitrate 10..1 padding rule, and permutes the state.
-func (d *state) padAndPermute(dsbyte byte) {
-	if d.buf == nil {
-		d.buf = d.storage[:0]
-	}
+func (d *state) padAndPermute() {
 	// Pad with this instance's domain-separator bits. We know that there's
 	// at least one byte of space in d.buf because, if it were full,
 	// permute would have been called to empty it. dsbyte also contains the
 	// first one bit for the padding. See the comment in the state struct.
-	d.buf = append(d.buf, dsbyte)
-	zerosStart := len(d.buf)
-	d.buf = d.storage[:d.rate]
-	for i := zerosStart; i < d.rate; i++ {
-		d.buf[i] = 0
+	d.storage[d.n] = d.dsbyte
+	d.n++
+	for d.n < d.rate {
+		d.storage[d.n] = 0
+		d.n++
 	}
 	// This adds the final one bit for the padding. Because of the way that
 	// bits are numbered from the LSB upwards, the final bit is the MSB of
 	// the last byte.
-	d.buf[d.rate-1] ^= 0x80
+	d.storage[d.rate-1] ^= 0x80
 	// Apply the permutation
 	d.permute()
 	d.state = spongeSqueezing
-	d.buf = d.storage[:d.rate]
-	copyOut(d, d.buf)
+	d.n = d.rate
+	copyOut(d, d.storage[:d.rate])
 }
 
 // Write absorbs more data into the hash's state. It panics if any
@@ -127,28 +118,25 @@
 	if d.state != spongeAbsorbing {
 		panic("sha3: Write after Read")
 	}
-	if d.buf == nil {
-		d.buf = d.storage[:0]
-	}
 	written = len(p)
 
 	for len(p) > 0 {
-		if len(d.buf) == 0 && len(p) >= d.rate {
+		if d.n == 0 && len(p) >= d.rate {
 			// The fast path; absorb a full "rate" bytes of input and apply the permutation.
 			xorIn(d, p[:d.rate])
 			p = p[d.rate:]
 			keccakF1600(&d.a)
 		} else {
 			// The slow path; buffer the input until we can fill the sponge, and then xor it in.
-			todo := d.rate - len(d.buf)
+			todo := d.rate - d.n
 			if todo > len(p) {
 				todo = len(p)
 			}
-			d.buf = append(d.buf, p[:todo]...)
+			d.n += copy(d.storage[d.n:], p[:todo])
 			p = p[todo:]
 
 			// If the sponge is full, apply the permutation.
-			if len(d.buf) == d.rate {
+			if d.n == d.rate {
 				d.permute()
 			}
 		}
@@ -161,19 +149,19 @@
 func (d *state) Read(out []byte) (n int, err error) {
 	// If we're still absorbing, pad and apply the permutation.
 	if d.state == spongeAbsorbing {
-		d.padAndPermute(d.dsbyte)
+		d.padAndPermute()
 	}
 
 	n = len(out)
 
 	// Now, do the squeezing.
 	for len(out) > 0 {
-		n := copy(out, d.buf)
-		d.buf = d.buf[n:]
+		n := copy(out, d.storage[d.i:d.n])
+		d.i += n
 		out = out[n:]
 
 		// Apply the permutation if we've squeezed the sponge dry.
-		if len(d.buf) == 0 {
+		if d.i == d.rate {
 			d.permute()
 		}
 	}