ccitt: don't advance bitReader after invalid code

Change-Id: I721f86aba22a6506b5b7ac162977cd31dc371b0a
Reviewed-on: https://go-review.googlesource.com/c/image/+/177297
Reviewed-by: Horst Rutter <hhrutter@gmail.com>
Reviewed-by: Benny Siegert <bsiegert@gmail.com>
diff --git a/ccitt/gen.go b/ccitt/gen.go
index 2972acd..c4d4bc9 100644
--- a/ccitt/gen.go
+++ b/ccitt/gen.go
@@ -32,6 +32,7 @@
 			"// whiteTable represents Tables 2 and 3 for a white run.\n")
 		write(w, build(blackCodes[:], 0), "blackTable",
 			"// blackTable represents Tables 2 and 3 for a black run.\n")
+		writeMaxCodeLength(w, modeCodes[:], whiteCodes[:], blackCodes[:])
 		finish(w, "table.go")
 	}
 
@@ -196,6 +197,18 @@
 	w.WriteString("+-+\n")
 }
 
+func writeMaxCodeLength(w *bytes.Buffer, codesList ...[]code) {
+	maxCodeLength := 0
+	for _, codes := range codesList {
+		for _, code := range codes {
+			if n := len(code.str); maxCodeLength < n {
+				maxCodeLength = n
+			}
+		}
+	}
+	fmt.Fprintf(w, "const maxCodeLength = %d\n\n", maxCodeLength)
+}
+
 func finish(w *bytes.Buffer, filename string) {
 	copyPaste(w, filename)
 	if *debug {
diff --git a/ccitt/reader.go b/ccitt/reader.go
index ad5f2fa..bd856a1 100644
--- a/ccitt/reader.go
+++ b/ccitt/reader.go
@@ -54,6 +54,14 @@
 	b.nBits -= n
 }
 
+// nextBitMaxNBits is the maximum possible value of bitReader.nBits after a
+// bitReader.nextBit call, provided that bitReader.nBits was not more than this
+// value before that call.
+//
+// Note that the decode function can unread bits, which can temporarily set the
+// bitReader.nBits value above nextBitMaxNBits.
+const nextBitMaxNBits = 31
+
 func (b *bitReader) nextBit() (uint32, error) {
 	for {
 		if b.nBits > 0 {
@@ -63,10 +71,16 @@
 			return bit, nil
 		}
 
-		if available := b.bw - b.br; available >= 8 {
-			b.bits = binary.LittleEndian.Uint64(b.bytes[b.br:])
-			b.br += 8
-			b.nBits = 64
+		if available := b.bw - b.br; available >= 4 {
+			// Read 32 bits, even though b.bits is a uint64, since the decode
+			// function may need to unread up to maxCodeLength bits, putting
+			// them back in the remaining (64 - 32) bits. TestMaxCodeLength
+			// checks that the generated maxCodeLength constant fits.
+			//
+			// If changing the Uint32 call, also change nextBitMaxNBits.
+			b.bits = uint64(binary.LittleEndian.Uint32(b.bytes[b.br:]))
+			b.br += 4
+			b.nBits = 32
 			continue
 		} else if available > 0 {
 			b.bits = uint64(b.bytes[b.br])
@@ -94,16 +108,22 @@
 }
 
 func decode(b *bitReader, table [][2]int16) (uint32, error) {
-	for state := int32(1); ; {
+	nBitsRead, bitsRead, state := uint32(0), uint32(0), int32(1)
+	for {
 		bit, err := b.nextBit()
 		if err != nil {
 			return 0, err
 		}
+		bitsRead |= bit << nBitsRead
+		nBitsRead++
 		// The "&1" is redundant, but can eliminate a bounds check.
 		state = int32(table[state][bit&1])
 		if state < 0 {
 			return uint32(^state), nil
 		} else if state == 0 {
+			// Unread the bits we've read, then return errInvalidCode.
+			b.bits = (b.bits << nBitsRead) | uint64(bitsRead)
+			b.nBits += nBitsRead
 			return 0, errInvalidCode
 		}
 	}
diff --git a/ccitt/reader_test.go b/ccitt/reader_test.go
index f6b47d8..8564d28 100644
--- a/ccitt/reader_test.go
+++ b/ccitt/reader_test.go
@@ -6,10 +6,37 @@
 
 import (
 	"bytes"
+	"io"
 	"reflect"
 	"testing"
+	"unsafe"
 )
 
+func TestMaxCodeLength(t *testing.T) {
+	br := bitReader{}
+	size := unsafe.Sizeof(br.bits)
+	size *= 8 // Convert from bytes to bits.
+
+	// Check that the size of the bitReader.bits field is large enough to hold
+	// nextBitMaxNBits bits.
+	if size < nextBitMaxNBits {
+		t.Fatalf("size: got %d, want >= %d", size, nextBitMaxNBits)
+	}
+
+	// Check that bitReader.nextBit will always leave enough spare bits in the
+	// bitReader.bits field such that the decode function can unread up to
+	// maxCodeLength bits.
+	if want := size - nextBitMaxNBits; maxCodeLength > want {
+		t.Fatalf("maxCodeLength: got %d, want <= %d", maxCodeLength, want)
+	}
+
+	// The decode function also assumes that, when saving bits to possibly
+	// unread later, those bits fit inside a uint32.
+	if maxCodeLength > 32 {
+		t.Fatalf("maxCodeLength: got %d, want <= %d", maxCodeLength, 32)
+	}
+}
+
 func testTable(t *testing.T, table [][2]int16, codes []code, values []uint32) {
 	// Build a map from values to codes.
 	m := map[uint32]string{}
@@ -92,4 +119,48 @@
 	})
 }
 
+func TestInvalidCode(t *testing.T) {
+	// The bit stream is:
+	// 1 010 000000011011
+	// Packing that LSB-first gives:
+	// 0b_1101_1000_0000_0101
+	src := []byte{0x05, 0xD8}
+
+	table := modeTable[:]
+	r := &bitReader{
+		r: bytes.NewReader(src),
+	}
+
+	// "1" decodes to the value 2.
+	if v, err := decode(r, table); v != 2 || err != nil {
+		t.Fatalf("decode #0: got (%v, %v), want (2, nil)", v, err)
+	}
+
+	// "010" decodes to the value 6.
+	if v, err := decode(r, table); v != 6 || err != nil {
+		t.Fatalf("decode #0: got (%v, %v), want (6, nil)", v, err)
+	}
+
+	// "00000001" is an invalid code.
+	if v, err := decode(r, table); v != 0 || err != errInvalidCode {
+		t.Fatalf("decode #0: got (%v, %v), want (0, %v)", v, err, errInvalidCode)
+	}
+
+	// The bitReader should not have advanced after encountering an invalid
+	// code. The remaining bits should be "000000011011".
+	remaining := []byte(nil)
+	for {
+		bit, err := r.nextBit()
+		if err == io.EOF {
+			break
+		} else if err != nil {
+			t.Fatalf("nextBit: %v", err)
+		}
+		remaining = append(remaining, uint8('0'+bit))
+	}
+	if got, want := string(remaining), "000000011011"; got != want {
+		t.Fatalf("remaining bits: got %q, want %q", got, want)
+	}
+}
+
 // TODO: more tests.
diff --git a/ccitt/table.go b/ccitt/table.go
index 266dfcf..6a8cfe1 100644
--- a/ccitt/table.go
+++ b/ccitt/table.go
@@ -719,6 +719,8 @@
 	104: {^1152, ^1216},
 }
 
+const maxCodeLength = 13
+
 // COPY PASTE table.go BEGIN
 
 const (