Revert "chacha20: don't panic encrypting the final blocks"

This reverts CL 224279.

Reason for revert: broken on arm64, ppc64le and s390x 😢

Change-Id: I8632ee78a79696a3117c81729904797233e0071d
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/229118
Reviewed-by: Katie Hockman <katie@golang.org>
diff --git a/chacha20/chacha_generic.go b/chacha20/chacha_generic.go
index a6fb59f..7c498e9 100644
--- a/chacha20/chacha_generic.go
+++ b/chacha20/chacha_generic.go
@@ -42,14 +42,10 @@
 
 	// The last len bytes of buf are leftover key stream bytes from the previous
 	// XORKeyStream invocation. The size of buf depends on how many blocks are
-	// computed at a time by xorKeyStreamBlocks.
+	// computed at a time.
 	buf [bufSize]byte
 	len int
 
-	// overflow is set when the counter overflowed, no more blocks can be
-	// generated, and the next XORKeyStream call should panic.
-	overflow bool
-
 	// The counter-independent results of the first round are cached after they
 	// are computed the first time.
 	precompDone      bool
@@ -143,18 +139,15 @@
 // SetCounter sets the Cipher counter. The next invocation of XORKeyStream will
 // behave as if (64 * counter) bytes had been encrypted so far.
 //
-// To prevent accidental counter reuse, SetCounter panics if counter is less
-// than the current value.
-//
-// Note that the execution time of XORKeyStream is not independent of the
-// counter value.
+// To prevent accidental counter reuse, SetCounter panics if counter is
+// less than the current value.
 func (s *Cipher) SetCounter(counter uint32) {
 	// Internally, s may buffer multiple blocks, which complicates this
 	// implementation slightly. When checking whether the counter has rolled
 	// back, we must use both s.counter and s.len to determine how many blocks
 	// we have already output.
 	outputCounter := s.counter - uint32(s.len)/blockSize
-	if s.overflow || counter < outputCounter {
+	if counter < outputCounter {
 		panic("chacha20: SetCounter attempted to rollback counter")
 	}
 
@@ -203,52 +196,34 @@
 			dst[i] = src[i] ^ b
 		}
 		s.len -= len(keyStream)
-		dst, src = dst[len(keyStream):], src[len(keyStream):]
-	}
-	if len(src) == 0 {
-		return
+		src = src[len(keyStream):]
+		dst = dst[len(keyStream):]
 	}
 
-	// If we'd need to let the counter overflow and keep generating output,
-	// panic immediately. If instead we'd only reach the last block, remember
-	// not to generate any more output after the buffer is drained.
-	numBlocks := (uint64(len(src)) + blockSize - 1) / blockSize
-	if s.overflow || uint64(s.counter)+numBlocks > 1<<32 {
+	const blocksPerBuf = bufSize / blockSize
+	numBufs := (uint64(len(src)) + bufSize - 1) / bufSize
+	if uint64(s.counter)+numBufs*blocksPerBuf >= 1<<32 {
 		panic("chacha20: counter overflow")
-	} else if uint64(s.counter)+numBlocks == 1<<32 {
-		s.overflow = true
 	}
 
 	// xorKeyStreamBlocks implementations expect input lengths that are a
 	// multiple of bufSize. Platform-specific ones process multiple blocks at a
 	// time, so have bufSizes that are a multiple of blockSize.
 
-	full := len(src) - len(src)%bufSize
+	rem := len(src) % bufSize
+	full := len(src) - rem
+
 	if full > 0 {
 		s.xorKeyStreamBlocks(dst[:full], src[:full])
 	}
-	dst, src = dst[full:], src[full:]
-
-	// If using a multi-block xorKeyStreamBlocks would overflow, use the generic
-	// one that does one block at a time.
-	const blocksPerBuf = bufSize / blockSize
-	if uint64(s.counter)+blocksPerBuf > 1<<32 {
-		s.buf = [bufSize]byte{}
-		numBlocks := (len(src) + blockSize - 1) / blockSize
-		buf := s.buf[bufSize-numBlocks*blockSize:]
-		copy(buf, src)
-		s.xorKeyStreamBlocksGeneric(buf, buf)
-		s.len = bufSize - copy(dst, buf)
-		return
-	}
 
 	// If we have a partial (multi-)block, pad it for xorKeyStreamBlocks, and
 	// keep the leftover keystream for the next XORKeyStream invocation.
-	if len(src) > 0 {
+	if rem > 0 {
 		s.buf = [bufSize]byte{}
-		copy(s.buf[:], src)
+		copy(s.buf[:], src[full:])
 		s.xorKeyStreamBlocks(s.buf[:], s.buf[:])
-		s.len = bufSize - copy(dst, s.buf[:])
+		s.len = bufSize - copy(dst[full:], s.buf[:])
 	}
 }
 
@@ -329,6 +304,9 @@
 		x15 += c15
 
 		s.counter += 1
+		if s.counter == 0 {
+			panic("chacha20: internal error: counter overflow")
+		}
 
 		in, out := src[i:], dst[i:]
 		in, out = in[:blockSize], out[:blockSize] // bounds check elimination hint
diff --git a/chacha20/chacha_test.go b/chacha20/chacha_test.go
index e150182..554afbf 100644
--- a/chacha20/chacha_test.go
+++ b/chacha20/chacha_test.go
@@ -148,54 +148,11 @@
 	if !panics(func() { s.SetCounter(0) }) {
 		t.Error("counter decreasing should trigger a panic")
 	}
-}
-
-func TestLastBlock(t *testing.T) {
-	panics := func(fn func()) (p bool) {
-		defer func() { p = recover() != nil }()
-		fn()
-		return
-	}
-
-	// setting the counter to 0xffffffff and crypting multiple blocks should
-	// trigger a panic
-	s, _ := NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
-	s.SetCounter(0xffffffff)
-	blocks := make([]byte, blockSize*2)
-	if !panics(func() { s.XORKeyStream(blocks, blocks) }) {
-		t.Error("crypting multiple blocks should trigger a panic")
-	}
-
-	// setting the counter to 0xffffffff - 1 and crypting two blocks should not
-	// trigger a panic
-	s, _ = NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
-	s.SetCounter(0xffffffff - 1)
-	if panics(func() { s.XORKeyStream(blocks, blocks) }) {
-		t.Error("crypting the last blocks should not trigger a panic")
-	}
-	// once all the keystream is spent, setting the counter should panic
-	if !panics(func() { s.SetCounter(0xffffffff) }) {
-		t.Error("setting the counter after overflow should trigger a panic")
-	}
-	// crypting a subsequent block *should* panic
-	block := make([]byte, blockSize)
-	if !panics(func() { s.XORKeyStream(block, block) }) {
-		t.Error("crypting after overflow should trigger a panic")
-	}
-
-	// if we crypt less than a full block, we should be able to crypt the rest
-	// in a subsequent call without panicking
-	s, _ = NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
-	s.SetCounter(0xffffffff)
-	if panics(func() { s.XORKeyStream(block[7:], block[7:]) }) {
-		t.Error("crypting part of the last block should not trigger a panic")
-	}
-	if panics(func() { s.XORKeyStream(block[:7], block[:7]) }) {
-		t.Error("crypting part of the last block should not trigger a panic")
-	}
-	// as before, a third call should trigger a panic because all keystream is spent
-	if !panics(func() { s.XORKeyStream(block[:1], block[:1]) }) {
-		t.Error("crypting after overflow should trigger a panic")
+	// advancing to ^uint32(0) and then calling XORKeyStream should cause a panic
+	s = newCipher()
+	s.SetCounter(^uint32(0))
+	if !panics(func() { s.XORKeyStream([]byte{0}, []byte{0}) }) {
+		t.Error("counter overflowing should trigger a panic")
 	}
 }