chacha20: don't panic encrypting the final blocks

Certain operations with counter values close to overflowing were causing
an unnecessary panic, which was reachable due to the SetCounter API and
actually observed in QUIC.

Tests by lukechampine <luke.champine@gmail.com> from CL 220591.

Fixes golang/go#37157

Change-Id: Iba52edb1ba36af391b8fe4ee615c5c41d7e64f48
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/224279
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>
diff --git a/chacha20/chacha_generic.go b/chacha20/chacha_generic.go
index 7c498e9..a6fb59f 100644
--- a/chacha20/chacha_generic.go
+++ b/chacha20/chacha_generic.go
@@ -42,10 +42,14 @@
 
 	// The last len bytes of buf are leftover key stream bytes from the previous
 	// XORKeyStream invocation. The size of buf depends on how many blocks are
-	// computed at a time.
+	// computed at a time by xorKeyStreamBlocks.
 	buf [bufSize]byte
 	len int
 
+	// overflow is set when the counter overflowed, no more blocks can be
+	// generated, and the next XORKeyStream call should panic.
+	overflow bool
+
 	// The counter-independent results of the first round are cached after they
 	// are computed the first time.
 	precompDone      bool
@@ -139,15 +143,18 @@
 // SetCounter sets the Cipher counter. The next invocation of XORKeyStream will
 // behave as if (64 * counter) bytes had been encrypted so far.
 //
-// To prevent accidental counter reuse, SetCounter panics if counter is
-// less than the current value.
+// To prevent accidental counter reuse, SetCounter panics if counter is less
+// than the current value.
+//
+// Note that the execution time of XORKeyStream is not independent of the
+// counter value.
 func (s *Cipher) SetCounter(counter uint32) {
 	// Internally, s may buffer multiple blocks, which complicates this
 	// implementation slightly. When checking whether the counter has rolled
 	// back, we must use both s.counter and s.len to determine how many blocks
 	// we have already output.
 	outputCounter := s.counter - uint32(s.len)/blockSize
-	if counter < outputCounter {
+	if s.overflow || counter < outputCounter {
 		panic("chacha20: SetCounter attempted to rollback counter")
 	}
 
@@ -196,34 +203,52 @@
 			dst[i] = src[i] ^ b
 		}
 		s.len -= len(keyStream)
-		src = src[len(keyStream):]
-		dst = dst[len(keyStream):]
+		dst, src = dst[len(keyStream):], src[len(keyStream):]
+	}
+	if len(src) == 0 {
+		return
 	}
 
-	const blocksPerBuf = bufSize / blockSize
-	numBufs := (uint64(len(src)) + bufSize - 1) / bufSize
-	if uint64(s.counter)+numBufs*blocksPerBuf >= 1<<32 {
+	// If we'd need to let the counter overflow and keep generating output,
+	// panic immediately. If instead we'd only reach the last block, remember
+	// not to generate any more output after the buffer is drained.
+	numBlocks := (uint64(len(src)) + blockSize - 1) / blockSize
+	if s.overflow || uint64(s.counter)+numBlocks > 1<<32 {
 		panic("chacha20: counter overflow")
+	} else if uint64(s.counter)+numBlocks == 1<<32 {
+		s.overflow = true
 	}
 
 	// xorKeyStreamBlocks implementations expect input lengths that are a
 	// multiple of bufSize. Platform-specific ones process multiple blocks at a
 	// time, so have bufSizes that are a multiple of blockSize.
 
-	rem := len(src) % bufSize
-	full := len(src) - rem
-
+	full := len(src) - len(src)%bufSize
 	if full > 0 {
 		s.xorKeyStreamBlocks(dst[:full], src[:full])
 	}
+	dst, src = dst[full:], src[full:]
+
+	// If using a multi-block xorKeyStreamBlocks would overflow, use the generic
+	// one that does one block at a time.
+	const blocksPerBuf = bufSize / blockSize
+	if uint64(s.counter)+blocksPerBuf > 1<<32 {
+		s.buf = [bufSize]byte{}
+		numBlocks := (len(src) + blockSize - 1) / blockSize
+		buf := s.buf[bufSize-numBlocks*blockSize:]
+		copy(buf, src)
+		s.xorKeyStreamBlocksGeneric(buf, buf)
+		s.len = bufSize - copy(dst, buf)
+		return
+	}
 
 	// If we have a partial (multi-)block, pad it for xorKeyStreamBlocks, and
 	// keep the leftover keystream for the next XORKeyStream invocation.
-	if rem > 0 {
+	if len(src) > 0 {
 		s.buf = [bufSize]byte{}
-		copy(s.buf[:], src[full:])
+		copy(s.buf[:], src)
 		s.xorKeyStreamBlocks(s.buf[:], s.buf[:])
-		s.len = bufSize - copy(dst[full:], s.buf[:])
+		s.len = bufSize - copy(dst, s.buf[:])
 	}
 }
 
@@ -304,9 +329,6 @@
 		x15 += c15
 
 		s.counter += 1
-		if s.counter == 0 {
-			panic("chacha20: internal error: counter overflow")
-		}
 
 		in, out := src[i:], dst[i:]
 		in, out = in[:blockSize], out[:blockSize] // bounds check elimination hint
diff --git a/chacha20/chacha_test.go b/chacha20/chacha_test.go
index 554afbf..e150182 100644
--- a/chacha20/chacha_test.go
+++ b/chacha20/chacha_test.go
@@ -148,11 +148,54 @@
 	if !panics(func() { s.SetCounter(0) }) {
 		t.Error("counter decreasing should trigger a panic")
 	}
-	// advancing to ^uint32(0) and then calling XORKeyStream should cause a panic
-	s = newCipher()
-	s.SetCounter(^uint32(0))
-	if !panics(func() { s.XORKeyStream([]byte{0}, []byte{0}) }) {
-		t.Error("counter overflowing should trigger a panic")
+}
+
+func TestLastBlock(t *testing.T) {
+	panics := func(fn func()) (p bool) {
+		defer func() { p = recover() != nil }()
+		fn()
+		return
+	}
+
+	// setting the counter to 0xffffffff and crypting multiple blocks should
+	// trigger a panic
+	s, _ := NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
+	s.SetCounter(0xffffffff)
+	blocks := make([]byte, blockSize*2)
+	if !panics(func() { s.XORKeyStream(blocks, blocks) }) {
+		t.Error("crypting multiple blocks should trigger a panic")
+	}
+
+	// setting the counter to 0xffffffff - 1 and crypting two blocks should not
+	// trigger a panic
+	s, _ = NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
+	s.SetCounter(0xffffffff - 1)
+	if panics(func() { s.XORKeyStream(blocks, blocks) }) {
+		t.Error("crypting the last blocks should not trigger a panic")
+	}
+	// once all the keystream is spent, setting the counter should panic
+	if !panics(func() { s.SetCounter(0xffffffff) }) {
+		t.Error("setting the counter after overflow should trigger a panic")
+	}
+	// crypting a subsequent block *should* panic
+	block := make([]byte, blockSize)
+	if !panics(func() { s.XORKeyStream(block, block) }) {
+		t.Error("crypting after overflow should trigger a panic")
+	}
+
+	// if we crypt less than a full block, we should be able to crypt the rest
+	// in a subsequent call without panicking
+	s, _ = NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
+	s.SetCounter(0xffffffff)
+	if panics(func() { s.XORKeyStream(block[7:], block[7:]) }) {
+		t.Error("crypting part of the last block should not trigger a panic")
+	}
+	if panics(func() { s.XORKeyStream(block[:7], block[:7]) }) {
+		t.Error("crypting part of the last block should not trigger a panic")
+	}
+	// as before, a third call should trigger a panic because all keystream is spent
+	if !panics(func() { s.XORKeyStream(block[:1], block[:1]) }) {
+		t.Error("crypting after overflow should trigger a panic")
 	}
 }