chacha20: don't panic encrypting the final blocks

Certain operations with counter values close to overflowing were causing
an unnecessary panic, which was reachable due to the SetCounter API and
actually observed in QUIC.

Tests by lukechampine <luke.champine@gmail.com> from CL 220591.

Fixes golang/go#37157

Relanding of CL 224279, which was broken on multi-block buffers.
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/224279
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>

Change-Id: Ia382c6f62ae49ffe257b67f7b794e8d7124d981e
(cherry picked from commit 1c2c788b11ecf76cd7fbd7bba62146eb7082bdd8)
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/229119
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>
diff --git a/chacha20/chacha_generic.go b/chacha20/chacha_generic.go
index 18c8bc0..a2ecf5c 100644
--- a/chacha20/chacha_generic.go
+++ b/chacha20/chacha_generic.go
@@ -42,10 +42,14 @@
 
 	// The last len bytes of buf are leftover key stream bytes from the previous
 	// XORKeyStream invocation. The size of buf depends on how many blocks are
-	// computed at a time.
+	// computed at a time by xorKeyStreamBlocks.
 	buf [bufSize]byte
 	len int
 
+	// overflow is set once the last block has been generated: leftover key
+	// stream in buf can still be used, but producing more blocks must panic.
+	overflow bool
+
 	// The counter-independent results of the first round are cached after they
 	// are computed the first time.
 	precompDone      bool
@@ -140,15 +144,18 @@
 // SetCounter sets the Cipher counter. The next invocation of XORKeyStream will
 // behave as if (64 * counter) bytes had been encrypted so far.
 //
-// To prevent accidental counter reuse, SetCounter panics if counter is
-// less than the current value.
+// To prevent accidental counter reuse, SetCounter panics if counter is less
+// than the current value.
+//
+// Note that the execution time of XORKeyStream is not independent of the
+// counter value.
 func (s *Cipher) SetCounter(counter uint32) {
 	// Internally, s may buffer multiple blocks, which complicates this
 	// implementation slightly. When checking whether the counter has rolled
 	// back, we must use both s.counter and s.len to determine how many blocks
 	// we have already output.
 	outputCounter := s.counter - uint32(s.len)/blockSize
-	if counter < outputCounter {
+	if s.overflow || counter < outputCounter {
 		panic("chacha20: SetCounter attempted to rollback counter")
 	}
 
@@ -197,34 +204,52 @@
 			dst[i] = src[i] ^ b
 		}
 		s.len -= len(keyStream)
-		src = src[len(keyStream):]
-		dst = dst[len(keyStream):]
+		dst, src = dst[len(keyStream):], src[len(keyStream):]
+	}
+	if len(src) == 0 {
+		return
 	}
 
-	const blocksPerBuf = bufSize / blockSize
-	numBufs := (uint64(len(src)) + bufSize - 1) / bufSize
-	if uint64(s.counter)+numBufs*blocksPerBuf >= 1<<32 {
+	// If we'd need to let the counter overflow and keep generating output,
+	// panic immediately. If instead we'd only reach the last block, remember
+	// not to generate any more output after the buffer is drained.
+	numBlocks := (uint64(len(src)) + blockSize - 1) / blockSize
+	if s.overflow || uint64(s.counter)+numBlocks > 1<<32 {
 		panic("chacha20: counter overflow")
+	} else if uint64(s.counter)+numBlocks == 1<<32 {
+		s.overflow = true
 	}
 
 	// xorKeyStreamBlocks implementations expect input lengths that are a
 	// multiple of bufSize. Platform-specific ones process multiple blocks at a
 	// time, so have bufSizes that are a multiple of blockSize.
 
-	rem := len(src) % bufSize
-	full := len(src) - rem
-
+	full := len(src) - len(src)%bufSize
 	if full > 0 {
 		s.xorKeyStreamBlocks(dst[:full], src[:full])
 	}
+	dst, src = dst[full:], src[full:]
+
+	// If using a multi-block xorKeyStreamBlocks would overflow, use the generic
+	// one that does one block at a time.
+	const blocksPerBuf = bufSize / blockSize
+	if uint64(s.counter)+blocksPerBuf > 1<<32 {
+		s.buf = [bufSize]byte{}
+		numBlocks := (len(src) + blockSize - 1) / blockSize
+		buf := s.buf[bufSize-numBlocks*blockSize:]
+		copy(buf, src)
+		s.xorKeyStreamBlocksGeneric(buf, buf)
+		s.len = len(buf) - copy(dst, buf)
+		return
+	}
 
 	// If we have a partial (multi-)block, pad it for xorKeyStreamBlocks, and
 	// keep the leftover keystream for the next XORKeyStream invocation.
-	if rem > 0 {
+	if len(src) > 0 {
 		s.buf = [bufSize]byte{}
-		copy(s.buf[:], src[full:])
+		copy(s.buf[:], src)
 		s.xorKeyStreamBlocks(s.buf[:], s.buf[:])
-		s.len = bufSize - copy(dst[full:], s.buf[:])
+		s.len = bufSize - copy(dst, s.buf[:])
 	}
 }
 
@@ -308,9 +333,6 @@
 		addXor(dst[60:64], src[60:64], x15, c15)
 
 		s.counter += 1
-		if s.counter == 0 {
-			panic("chacha20: internal error: counter overflow")
-		}
 
 		src, dst = src[blockSize:], dst[blockSize:]
 	}
diff --git a/chacha20/chacha_test.go b/chacha20/chacha_test.go
index 554afbf..d75873c 100644
--- a/chacha20/chacha_test.go
+++ b/chacha20/chacha_test.go
@@ -148,11 +148,66 @@
 	if !panics(func() { s.SetCounter(0) }) {
 		t.Error("counter decreasing should trigger a panic")
 	}
-	// advancing to ^uint32(0) and then calling XORKeyStream should cause a panic
-	s = newCipher()
-	s.SetCounter(^uint32(0))
-	if !panics(func() { s.XORKeyStream([]byte{0}, []byte{0}) }) {
-		t.Error("counter overflowing should trigger a panic")
+}
+
+func TestLastBlock(t *testing.T) {
+	panics := func(fn func()) (p bool) {
+		defer func() { p = recover() != nil }()
+		fn()
+		return
+	}
+
+	checkLastBlock := func(b []byte) {
+		t.Helper()
+		// Hardcoded result to check all implementations generate the same output.
+		lastBlock := "ace4cd09e294d1912d4ad205d06f95d9c2f2bfcf453e8753f128765b62215f4d" +
+			"92c74f2f626c6a640c0b1284d839ec81f1696281dafc3e684593937023b58b1d"
+		if got := hex.EncodeToString(b); got != lastBlock {
+			t.Errorf("wrong output for the last block, got %q, want %q", got, lastBlock)
+		}
+	}
+
+	// setting the counter to 0xffffffff and crypting multiple blocks should
+	// trigger a panic
+	s, _ := NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
+	s.SetCounter(0xffffffff)
+	blocks := make([]byte, blockSize*2)
+	if !panics(func() { s.XORKeyStream(blocks, blocks) }) {
+		t.Error("crypting multiple blocks should trigger a panic")
+	}
+
+	// setting the counter to 0xffffffff - 1 and crypting two blocks should not
+	// trigger a panic
+	s, _ = NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
+	s.SetCounter(0xffffffff - 1)
+	if panics(func() { s.XORKeyStream(blocks, blocks) }) {
+		t.Error("crypting the last blocks should not trigger a panic")
+	}
+	checkLastBlock(blocks[blockSize:])
+	// once all the keystream is spent, setting the counter should panic
+	if !panics(func() { s.SetCounter(0xffffffff) }) {
+		t.Error("setting the counter after overflow should trigger a panic")
+	}
+	// crypting a subsequent block *should* panic
+	block := make([]byte, blockSize)
+	if !panics(func() { s.XORKeyStream(block, block) }) {
+		t.Error("crypting after overflow should trigger a panic")
+	}
+
+	// if we crypt less than a full block, we should be able to crypt the rest
+	// in a subsequent call without panicking
+	s, _ = NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
+	s.SetCounter(0xffffffff)
+	if panics(func() { s.XORKeyStream(block[:7], block[:7]) }) {
+		t.Error("crypting part of the last block should not trigger a panic")
+	}
+	if panics(func() { s.XORKeyStream(block[7:], block[7:]) }) {
+		t.Error("crypting part of the last block should not trigger a panic")
+	}
+	checkLastBlock(block)
+	// as before, a third call should trigger a panic because all keystream is spent
+	if !panics(func() { s.XORKeyStream(block[:1], block[:1]) }) {
+		t.Error("crypting after overflow should trigger a panic")
 	}
 }