sha3: add cSHAKE support

This patch implements the 128- and 256-bit versions of cSHAKE, the
customizable variant of the SHAKE function.

* Implementation is based on NIST SP 800-185 (cSHAKE) and FIPS 202 (SHAKE)
* The test data file has been updated with cSHAKE KATs. I've copied the
  examples from the NIST document available here:
  https://csrc.nist.gov/csrc/media/projects/cryptographic-standards-and
  -guidelines/documents/examples/cshake_samples.pdf

Fixes #25395

Change-Id: Icbbc4232f3d9a28b3d6ead51937c2e60c00e5d8c
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/111281
Reviewed-by: Filippo Valsorda <filippo@golang.org>
diff --git a/sha3/sha3_test.go b/sha3/sha3_test.go
index 26d1549..c4e4498 100644
--- a/sha3/sha3_test.go
+++ b/sha3/sha3_test.go
@@ -27,14 +27,6 @@
 	katFilename = "testdata/keccakKats.json.deflate"
 )
 
-// Internal-use instances of SHAKE used to test against KATs.
-func newHashShake128() hash.Hash {
-	return &state{rate: 168, dsbyte: 0x1f, outputLen: 512}
-}
-func newHashShake256() hash.Hash {
-	return &state{rate: 136, dsbyte: 0x1f, outputLen: 512}
-}
-
 // testDigests contains functions returning hash.Hash instances
 // with output-length equal to the KAT length for SHA-3, Keccak
 // and SHAKE instances.
@@ -45,15 +37,20 @@
 	"SHA3-512":   New512,
 	"Keccak-256": NewLegacyKeccak256,
 	"Keccak-512": NewLegacyKeccak512,
-	"SHAKE128":   newHashShake128,
-	"SHAKE256":   newHashShake256,
 }
 
-// testShakes contains functions that return ShakeHash instances for
-// testing the ShakeHash-specific interface.
-var testShakes = map[string]func() ShakeHash{
-	"SHAKE128": NewShake128,
-	"SHAKE256": NewShake256,
+// testShakes maps algorithm names to ShakeHash constructors, together with
+// default N and S strings that tests may pass to the constructor.
+var testShakes = map[string]struct {
+	constructor  func(N []byte, S []byte) ShakeHash
+	defAlgoName  string // default value passed as N
+	defCustomStr string // default value passed as S
+}{
+	// NewCShake without customization produces same result as SHAKE
+	"SHAKE128":  {NewCShake128, "", ""},
+	"SHAKE256":  {NewCShake256, "", ""},
+	"cSHAKE128": {NewCShake128, "CSHAKE128", "CustomStrign"},
+	"cSHAKE256": {NewCShake256, "CSHAKE256", "CustomStrign"},
 }
 
 // decodeHex converts a hex-encoded string into a raw byte string.
@@ -71,6 +68,10 @@
 		Digest  string `json:"digest"`
 		Length  int64  `json:"length"`
 		Message string `json:"message"`
+
+		// Defined only for cSHAKE
+		N string `json:"N"`
+		S string `json:"S"`
 	}
 }
 
@@ -103,10 +104,9 @@
 			t.Errorf("error decoding KATs: %s", err)
 		}
 
-		// Do the KATs.
-		for functionName, kats := range katSet.Kats {
-			d := testDigests[functionName]()
-			for _, kat := range kats {
+		for algo, function := range testDigests {
+			d := function()
+			for _, kat := range katSet.Kats[algo] {
 				d.Reset()
 				in, err := hex.DecodeString(kat.Message)
 				if err != nil {
@@ -115,8 +115,39 @@
 				d.Write(in[:kat.Length/8])
 				got := strings.ToUpper(hex.EncodeToString(d.Sum(nil)))
 				if got != kat.Digest {
-					t.Errorf("function=%s, implementation=%s, length=%d\nmessage:\n  %s\ngot:\n  %s\nwanted:\n %s",
-						functionName, impl, kat.Length, kat.Message, got, kat.Digest)
+					t.Errorf("function=%s, implementation=%s, length=%d\nmessage:\n %s\ngot:\n  %s\nwanted:\n %s",
+						algo, impl, kat.Length, kat.Message, got, kat.Digest)
+					t.Logf("wanted %+v", kat)
+					t.FailNow()
+				}
+				continue
+			}
+		}
+
+		for algo, v := range testShakes {
+			for _, kat := range katSet.Kats[algo] {
+				N, err := hex.DecodeString(kat.N)
+				if err != nil {
+					t.Errorf("error decoding KAT: %s", err)
+				}
+
+				S, err := hex.DecodeString(kat.S)
+				if err != nil {
+					t.Errorf("error decoding KAT: %s", err)
+				}
+				d := v.constructor(N, S)
+				in, err := hex.DecodeString(kat.Message)
+				if err != nil {
+					t.Errorf("error decoding KAT: %s", err)
+				}
+
+				d.Write(in[:kat.Length/8])
+				out := make([]byte, len(kat.Digest)/2)
+				d.Read(out)
+				got := strings.ToUpper(hex.EncodeToString(out))
+				if got != kat.Digest {
+					t.Errorf("function=%s, implementation=%s, length=%d N:%s\n S:%s\nmessage:\n %s \ngot:\n  %s\nwanted:\n %s",
+						algo, impl, kat.Length, kat.N, kat.S, kat.Message, got, kat.Digest)
 					t.Logf("wanted %+v", kat)
 					t.FailNow()
 				}
@@ -184,6 +215,34 @@
 				t.Errorf("Unaligned writes, implementation=%s, alg=%s\ngot %q, want %q", impl, alg, got, want)
 			}
 		}
+
+		// Same for SHAKE
+		for alg, df := range testShakes {
+			want := make([]byte, 16)
+			got := make([]byte, 16)
+			d := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr))
+
+			d.Reset()
+			d.Write(buf)
+			d.Read(want)
+			d.Reset()
+			for i := 0; i < len(buf); {
+				// Cycle through offsets which make a 137 byte sequence.
+				// Because 137 is prime this sequence should exercise all corner cases.
+				offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1}
+				for _, j := range offsets {
+					if v := len(buf) - i; v < j {
+						j = v
+					}
+					d.Write(buf[i : i+j])
+					i += j
+				}
+			}
+			d.Read(got)
+			if !bytes.Equal(got, want) {
+				t.Errorf("Unaligned writes, implementation=%s, alg=%s\ngot %q, want %q", impl, alg, got, want)
+			}
+		}
 	})
 }
 
@@ -225,13 +284,13 @@
 // the same output as repeatedly squeezing the instance.
 func TestSqueezing(t *testing.T) {
 	testUnalignedAndGeneric(t, func(impl string) {
-		for functionName, newShakeHash := range testShakes {
-			d0 := newShakeHash()
+		for algo, v := range testShakes {
+			d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
 			d0.Write([]byte(testString))
 			ref := make([]byte, 32)
 			d0.Read(ref)
 
-			d1 := newShakeHash()
+			d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
 			d1.Write([]byte(testString))
 			var multiple []byte
 			for range ref {
@@ -240,7 +299,7 @@
 				multiple = append(multiple, one...)
 			}
 			if !bytes.Equal(ref, multiple) {
-				t.Errorf("%s (%s): squeezing %d bytes one at a time failed", functionName, impl, len(ref))
+				t.Errorf("%s (%s): squeezing %d bytes one at a time failed", algo, impl, len(ref))
 			}
 		}
 	})
@@ -255,6 +314,50 @@
 	return result
 }
 
+// TestReset checks that Reset makes a ShakeHash reproduce its earlier output.
+func TestReset(t *testing.T) {
+	out1 := make([]byte, 32)
+	out2 := make([]byte, 32)
+	for _, v := range testShakes {
+		// Calculate hash for the first time, using an empty N and a two-byte S
+		c := v.constructor(nil, []byte{0x99, 0x98})
+		c.Write(sequentialBytes(0x100))
+		c.Read(out1)
+
+		// Calculate hash again on the same instance after Reset
+		c.Reset()
+		c.Write(sequentialBytes(0x100))
+		c.Read(out2)
+
+		if !bytes.Equal(out1, out2) {
+			t.Error("\nExpected:\n", out1, "\ngot:\n", out2)
+		}
+	}
+}
+
+// TestClone checks that a clone and its source give equal output for equal input.
+func TestClone(t *testing.T) {
+	out1 := make([]byte, 16)
+	out2 := make([]byte, 16)
+	in := sequentialBytes(0x100)
+	for _, v := range testShakes {
+		h1 := v.constructor(nil, []byte{0x01})
+		h1.Write([]byte{0x01})
+
+		h2 := h1.Clone() // taken after one byte has already been written to h1
+
+		h1.Write(in)
+		h1.Read(out1)
+
+		h2.Write(in)
+		h2.Read(out2)
+
+		if !bytes.Equal(out1, out2) {
+			t.Error("\nExpected:\n", hex.EncodeToString(out1), "\ngot:\n", hex.EncodeToString(out2))
+		}
+	}
+}
+
 // BenchmarkPermutationFunction measures the speed of the permutation function
 // with no input data.
 func BenchmarkPermutationFunction(b *testing.B) {
@@ -341,3 +444,37 @@
 	fmt.Printf("%x\n", h)
 	// Output: 78de2974bd2711d5549ffd32b753ef0f5fa80a0db2556db60f0987eb8a9218ff
 }
+
+func ExampleNewCShake256() {
+	out := make([]byte, 32)
+	msg := []byte("The quick brown fox jumps over the lazy dog")
+
+	// Example 1: Simple cSHAKE with a function name N and a customization string S
+	c1 := NewCShake256([]byte("NAME"), []byte("Partition1"))
+	c1.Write(msg)
+	c1.Read(out)
+	fmt.Println(hex.EncodeToString(out))
+
+	// Example 2: A different customization string produces a different digest
+	c1 = NewCShake256([]byte("NAME"), []byte("Partition2"))
+	c1.Write(msg)
+	c1.Read(out)
+	fmt.Println(hex.EncodeToString(out))
+
+	// Example 3: A longer read extends the digest; its first 32 bytes match Example 1
+	out = make([]byte, 64)
+	c1 = NewCShake256([]byte("NAME"), []byte("Partition1"))
+	c1.Write(msg)
+	c1.Read(out)
+	fmt.Println(hex.EncodeToString(out))
+
+	// Example 4: A further read on the same instance produces different output
+	c1.Read(out)
+	fmt.Println(hex.EncodeToString(out))
+
+	// Output:
+	//a90a4c6ca9af2156eba43dc8398279e6b60dcd56fb21837afe6c308fd4ceb05b
+	//a8db03e71f3e4da5c4eee9d28333cdd355f51cef3c567e59be5beb4ecdbb28f0
+	//a90a4c6ca9af2156eba43dc8398279e6b60dcd56fb21837afe6c308fd4ceb05b9dd98c6ee866ca7dc5a39d53e960f400bcd5a19c8a2d6ec6459f63696543a0d8
+	//85e73a72228d08b46515553ca3a29d47df3047e5d84b12d6c2c63e579f4fd1105716b7838e92e981863907f434bfd4443c9e56ea09da998d2f9b47db71988109
+}
diff --git a/sha3/shake.go b/sha3/shake.go
index 97c9b06..a39e5d5 100644
--- a/sha3/shake.go
+++ b/sha3/shake.go
@@ -5,10 +5,18 @@
 package sha3
 
 // This file defines the ShakeHash interface, and provides
-// functions for creating SHAKE instances, as well as utility
+// functions for creating SHAKE and cSHAKE instances, as well as utility
 // functions for hashing bytes to arbitrary-length output.
+//
+//
+// SHAKE implementation is based on FIPS PUB 202 [1]
+// cSHAKE implementation is based on NIST SP 800-185 [2]
+//
+// [1] https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf
+// [2] https://doi.org/10.6028/NIST.SP.800-185
 
 import (
+	"encoding/binary"
 	"io"
 )
 
@@ -31,8 +39,77 @@
 	Reset()
 }
 
-func (d *state) Clone() ShakeHash {
-	return d.clone()
+// cshakeState is the cSHAKE specific context.
+type cshakeState struct {
+	state // SHA-3 state context and Read/Write operations
+
+	// initBlock is the cSHAKE specific initialization set of bytes. It is built
+	// by the newCShake function as the left-encoded bit length of N, N, the
+	// left-encoded bit length of S and S, following section 3.3 of [2].
+	// It is stored here so that Reset() can put the context back into its
+	// initial state.
+	initBlock []byte
+}
+
+// Constants for configuring the initial SHA-3 state.
+const (
+	dsbyteShake  = 0x1f // dsbyte used by NewShake128 and NewShake256
+	dsbyteCShake = 0x04 // dsbyte passed to newCShake by NewCShake128 and NewCShake256
+	rate128      = 168  // rate for SHAKE128 and cSHAKE128
+	rate256      = 136  // rate for SHAKE256 and cSHAKE256
+)
+
+func bytepad(input []byte, w int) []byte {
+	// Prepend leftEncode(w) (at most 9 bytes), then zero-pad to a multiple of w.
+	buf := make([]byte, 0, 9+len(input)+w)
+	buf = append(buf, leftEncode(uint64(w))...)
+	buf = append(buf, input...)
+	padlen := (w - len(buf)%w) % w // 0 when already aligned, per bytepad in [2]
+	return append(buf, make([]byte, padlen)...)
+}
+
+// leftEncode returns value as big-endian bytes with leading zero bytes
+// trimmed (at least one byte kept), prefixed by the count of encoded bytes.
+func leftEncode(value uint64) []byte {
+	var enc [9]byte
+	binary.BigEndian.PutUint64(enc[1:], value)
+	first := 1 // index of the first byte that is kept
+	for first < 8 && enc[first] == 0 {
+		first++
+	}
+	enc[first-1] = byte(9 - first) // length prefix goes right before the kept bytes
+	return enc[first-1:]
+}
+
+// newCShake returns a cSHAKE context that has already absorbed the padded N and S.
+func newCShake(N, S []byte, rate int, dsbyte byte) ShakeHash {
+	// Two leftEncode results (max 9 bytes each) plus N and S themselves
+	prefix := make([]byte, 0, 9*2+len(N)+len(S))
+	prefix = append(prefix, leftEncode(uint64(len(N)*8))...)
+	prefix = append(prefix, N...)
+	prefix = append(prefix, leftEncode(uint64(len(S)*8))...)
+	prefix = append(prefix, S...)
+	h := &cshakeState{state: state{rate: rate, dsbyte: dsbyte}, initBlock: prefix}
+	h.Write(bytepad(prefix, rate))
+	return h
+}
+
+// Reset restores the initial state, including the absorbed N and S block.
+func (c *cshakeState) Reset() {
+	c.state.Reset()
+	c.Write(bytepad(c.initBlock, c.rate)) // re-absorb the padded initBlock, as newCShake does
+}
+
+// Clone returns a copy of the cSHAKE context in its current state.
+func (c *cshakeState) Clone() ShakeHash {
+	b := make([]byte, len(c.initBlock)) // private copy so the clone does not share initBlock
+	copy(b, c.initBlock)
+	return &cshakeState{state: *c.clone(), initBlock: b}
+}
+
+// Clone returns a copy of the SHAKE context in its current state.
+func (c *state) Clone() ShakeHash {
+	return c.clone()
 }
 
 // NewShake128 creates a new SHAKE128 variable-output-length ShakeHash.
@@ -42,7 +119,7 @@
 	if h := newShake128Asm(); h != nil {
 		return h
 	}
-	return &state{rate: 168, dsbyte: 0x1f}
+	return &state{rate: rate128, dsbyte: dsbyteShake}
 }
 
 // NewShake256 creates a new SHAKE256 variable-output-length ShakeHash.
@@ -52,7 +129,33 @@
 	if h := newShake256Asm(); h != nil {
 		return h
 	}
-	return &state{rate: 136, dsbyte: 0x1f}
+	return &state{rate: rate256, dsbyte: dsbyteShake}
+}
+
+// NewCShake128 creates a new cSHAKE128 variable-output-length ShakeHash, the
+// customizable variant of SHAKE128.
+// N names a function built on top of cSHAKE and may be left empty for plain
+// cSHAKE. S is a customization byte string for domain separation: hashing the
+// same input under two different S values gives unrelated outputs.
+// If both N and S are empty, the result is the same as NewShake128.
+func NewCShake128(N, S []byte) ShakeHash {
+	if len(N) > 0 || len(S) > 0 {
+		return newCShake(N, S, rate128, dsbyteCShake)
+	}
+	return NewShake128()
+}
+
+// NewCShake256 creates a new cSHAKE256 variable-output-length ShakeHash, the
+// customizable variant of SHAKE256.
+// N names a function built on top of cSHAKE and may be left empty for plain
+// cSHAKE. S is a customization byte string for domain separation: hashing the
+// same input under two different S values gives unrelated outputs.
+// If both N and S are empty, the result is the same as NewShake256.
+func NewCShake256(N, S []byte) ShakeHash {
+	if len(N) > 0 || len(S) > 0 {
+		return newCShake(N, S, rate256, dsbyteCShake)
+	}
+	return NewShake256()
 }
 
 // ShakeSum128 writes an arbitrary-length digest of data into hash.
diff --git a/sha3/testdata/keccakKats.json.deflate b/sha3/testdata/keccakKats.json.deflate
index 62e85ae..7a94c2f 100644
--- a/sha3/testdata/keccakKats.json.deflate
+++ b/sha3/testdata/keccakKats.json.deflate
Binary files differ