ccitt: factor out a highBits function

Change-Id: I27775c06a0bb95617e0a809e5902461aea0cafde
Reviewed-on: https://go-review.googlesource.com/c/image/+/191939
Reviewed-by: Benny Siegert <bsiegert@gmail.com>
Run-TryBot: Benny Siegert <bsiegert@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ccitt/reader.go b/ccitt/reader.go
index 680f8a6..792a9ce 100644
--- a/ccitt/reader.go
+++ b/ccitt/reader.go
@@ -73,6 +73,43 @@
 	}
 }
 
+// highBits writes to dst (1 bit per pixel, most significant bit first) the
+// high (0x80) bits from src (1 byte per pixel). It returns d, the number of
+// bytes written, and s, the number read: dst[:d] is src[:s] in packed form.
+//
+// For example, if src starts with the 8 bytes [0x7D, 0x7E, 0x7F, 0x80, 0x81,
+// 0x82, 0x00, 0xFF] then 0x1D will be written to dst[0].
+//
+// If src has (8 * len(dst)) or more bytes then only len(dst) bytes are
+// written, (8 * len(dst)) bytes are read, and invert is ignored.
+//
+// Otherwise, if len(src) is not a multiple of 8 then the final byte written to
+// dst is padded with 1 bits (if invert is true) or 0 bits. If inverted, the 1s
+// are typically temporary, e.g. they will be flipped back to 0s by an
+// invertBytes call in the highBits caller, reader.Read.
+func highBits(dst []byte, src []byte, invert bool) (d int, s int) {
+	for d < len(dst) { // each iteration emits one packed byte
+		numToPack := len(src) - s
+		if numToPack <= 0 {
+			break
+		} else if numToPack > 8 {
+			numToPack = 8 // at most 8 pixels fit in one output byte
+		}
+
+		byteValue := byte(0)
+		if invert {
+			byteValue = 0xFF >> uint(numToPack) // only the low padding bits start as 1s
+		}
+		for n := 0; n < numToPack; n++ {
+			byteValue |= (src[s] & 0x80) >> uint(n) // pixel n lands at bit (7 - n)
+			s++
+		}
+		dst[d] = byteValue
+		d++
+	}
+	return d, s
+}
+
 type bitReader struct {
 	r io.Reader
 
@@ -257,31 +294,10 @@
 			z.rowsRemaining--
 		}
 
-		// Pack from z.curr (1 byte per pixel) to p (1 bit per pixel), up to 8
-		// elements per iteration.
-		i := 0
-		for ; i < len(p); i++ {
-			numToPack := len(z.curr) - z.ri
-			if numToPack <= 0 {
-				break
-			} else if numToPack > 8 {
-				numToPack = 8
-			}
-
-			byteValue := byte(0)
-			if z.invert {
-				// Set the end-of-row padding bits to 1 (if inverted) or 0. If inverted, the 1s
-				// are temporary, and will be flipped back to 0s by the invertBytes call below.
-				byteValue = 0xFF >> uint(numToPack)
-			}
-
-			for j := 0; j < numToPack; j++ {
-				byteValue |= (z.curr[z.ri] & 0x80) >> uint(j)
-				z.ri++
-			}
-			p[i] = byteValue
-		}
-		p = p[i:]
+		// Pack from z.curr (1 byte per pixel) to p (1 bit per pixel).
+		packD, packS := highBits(p, z.curr[z.ri:], z.invert)
+		p = p[packD:]
+		z.ri += packS
 
 		// Prepare to decode the next row, if necessary.
 		if z.ri == len(z.curr) {
diff --git a/ccitt/reader_test.go b/ccitt/reader_test.go
index 0224618..80c0e9f 100644
--- a/ccitt/reader_test.go
+++ b/ccitt/reader_test.go
@@ -11,6 +11,7 @@
 	"image/png"
 	"io"
 	"io/ioutil"
+	"math/rand"
 	"os"
 	"path/filepath"
 	"reflect"
@@ -48,6 +49,86 @@
 	return png.Decode(f)
 }
 
+// simpleHB is a simple reference implementation of highBits, used by tests.
+func simpleHB(dst []byte, src []byte, invert bool) (d int, s int) {
+	for d < len(dst) { // each iteration emits one packed byte
+		numToPack := len(src) - s
+		if numToPack <= 0 {
+			break
+		} else if numToPack > 8 {
+			numToPack = 8 // at most 8 pixels fit in one output byte
+		}
+
+		byteValue := byte(0)
+		if invert {
+			byteValue = 0xFF >> uint(numToPack) // only the low padding bits start as 1s
+		}
+		for n := 0; n < numToPack; n++ {
+			byteValue |= (src[s] & 0x80) >> uint(n) // pixel n lands at bit (7 - n)
+			s++
+		}
+		dst[d] = byteValue
+		d++
+	}
+	return d, s
+}
+
+func TestPackBits(t *testing.T) {
+	rng := rand.New(rand.NewSource(1)) // fixed seed keeps the random cases reproducible
+	dst0 := make([]byte, 3)
+	dst1 := make([]byte, 3)
+	src := make([]byte, 20)
+
+	for r := 0; r < 1000; r++ {
+		numDst := rng.Intn(len(dst0) + 1) // 0 through len(dst0), inclusive
+		randomByte := byte(rng.Intn(256))
+		for i := 0; i < numDst; i++ {
+			dst0[i] = randomByte // identical pre-fill, so unwritten bytes also compare equal
+			dst1[i] = randomByte
+		}
+
+		numSrc := rng.Intn(len(src) + 1) // 0 through len(src), inclusive
+		for i := 0; i < numSrc; i++ {
+			src[i] = byte(rng.Intn(256))
+		}
+
+		invert := rng.Intn(2) == 0
+
+		d0, s0 := highBits(dst0[:numDst], src[:numSrc], invert)
+		d1, s1 := simpleHB(dst1[:numDst], src[:numSrc], invert) // simpleHB supplies the "want" values
+
+		if (d0 != d1) || (s0 != s1) || !bytes.Equal(dst0[:numDst], dst1[:numDst]) {
+			srcHighBits := make([]byte, numSrc) // one 0/1 entry per source pixel, for the failure message
+			for i := range srcHighBits {
+				srcHighBits[i] = src[i] >> 7
+			}
+
+			t.Fatalf("r=%d, numDst=%d, numSrc=%d, invert=%t:\nsrcHighBits=%d\n"+
+				"got  d=%d, s=%d, bytes=[% 02X]\n"+
+				"want d=%d, s=%d, bytes=[% 02X]",
+				r, numDst, numSrc, invert, srcHighBits,
+				d0, s0, dst0[:numDst],
+				d1, s1, dst1[:numDst],
+			)
+		}
+	}
+}
+
+func BenchmarkPackBits(b *testing.B) {
+	rng := rand.New(rand.NewSource(1)) // fixed seed keeps the input identical across runs
+	dst := make([]byte, 1024)
+	src := make([]byte, 7777) // 7777 = 972*8 + 1, so the final packed byte is partial
+	for i := range src {
+		src[i] = uint8(rng.Intn(256))
+	}
+
+	b.ResetTimer() // exclude the setup above from the timing
+	for n := 0; n < b.N; n++ {
+		highBits(dst, src, false)
+		highBits(dst, src, true) // time both padding modes
+	}
+}
+
 func TestMaxCodeLength(t *testing.T) {
 	br := bitReader{}
 	size := unsafe.Sizeof(br.bits)