ccitt: don't advance bitReader after invalid code
Change-Id: I721f86aba22a6506b5b7ac162977cd31dc371b0a
Reviewed-on: https://go-review.googlesource.com/c/image/+/177297
Reviewed-by: Horst Rutter <hhrutter@gmail.com>
Reviewed-by: Benny Siegert <bsiegert@gmail.com>
diff --git a/ccitt/gen.go b/ccitt/gen.go
index 2972acd..c4d4bc9 100644
--- a/ccitt/gen.go
+++ b/ccitt/gen.go
@@ -32,6 +32,7 @@
"// whiteTable represents Tables 2 and 3 for a white run.\n")
write(w, build(blackCodes[:], 0), "blackTable",
"// blackTable represents Tables 2 and 3 for a black run.\n")
+ writeMaxCodeLength(w, modeCodes[:], whiteCodes[:], blackCodes[:])
finish(w, "table.go")
}
@@ -196,6 +197,18 @@
w.WriteString("+-+\n")
}
+func writeMaxCodeLength(w *bytes.Buffer, codesList ...[]code) {
+ maxCodeLength := 0
+ for _, codes := range codesList {
+ for _, code := range codes {
+ if n := len(code.str); maxCodeLength < n {
+ maxCodeLength = n
+ }
+ }
+ }
+ fmt.Fprintf(w, "const maxCodeLength = %d\n\n", maxCodeLength)
+}
+
func finish(w *bytes.Buffer, filename string) {
copyPaste(w, filename)
if *debug {
diff --git a/ccitt/reader.go b/ccitt/reader.go
index ad5f2fa..bd856a1 100644
--- a/ccitt/reader.go
+++ b/ccitt/reader.go
@@ -54,6 +54,14 @@
b.nBits -= n
}
+// nextBitMaxNBits is the maximum possible value of bitReader.nBits after a
+// bitReader.nextBit call, provided that bitReader.nBits was not more than this
+// value before that call.
+//
+// Note that the decode function can unread bits, which can temporarily set the
+// bitReader.nBits value above nextBitMaxNBits.
+const nextBitMaxNBits = 31
+
func (b *bitReader) nextBit() (uint32, error) {
for {
if b.nBits > 0 {
@@ -63,10 +71,16 @@
return bit, nil
}
- if available := b.bw - b.br; available >= 8 {
- b.bits = binary.LittleEndian.Uint64(b.bytes[b.br:])
- b.br += 8
- b.nBits = 64
+ if available := b.bw - b.br; available >= 4 {
+ // Read 32 bits, even though b.bits is a uint64, since the decode
+ // function may need to unread up to maxCodeLength bits, putting
+ // them back in the remaining (64 - 32) bits. TestMaxCodeLength
+ // checks that the generated maxCodeLength constant fits.
+ //
+ // If changing the Uint32 call, also change nextBitMaxNBits.
+ b.bits = uint64(binary.LittleEndian.Uint32(b.bytes[b.br:]))
+ b.br += 4
+ b.nBits = 32
continue
} else if available > 0 {
b.bits = uint64(b.bytes[b.br])
@@ -94,16 +108,22 @@
}
func decode(b *bitReader, table [][2]int16) (uint32, error) {
- for state := int32(1); ; {
+ nBitsRead, bitsRead, state := uint32(0), uint32(0), int32(1)
+ for {
bit, err := b.nextBit()
if err != nil {
return 0, err
}
+ bitsRead |= bit << nBitsRead
+ nBitsRead++
// The "&1" is redundant, but can eliminate a bounds check.
state = int32(table[state][bit&1])
if state < 0 {
return uint32(^state), nil
} else if state == 0 {
+ // Unread the bits we've read, then return errInvalidCode.
+ b.bits = (b.bits << nBitsRead) | uint64(bitsRead)
+ b.nBits += nBitsRead
return 0, errInvalidCode
}
}
diff --git a/ccitt/reader_test.go b/ccitt/reader_test.go
index f6b47d8..8564d28 100644
--- a/ccitt/reader_test.go
+++ b/ccitt/reader_test.go
@@ -6,10 +6,37 @@
import (
"bytes"
+ "io"
"reflect"
"testing"
+ "unsafe"
)
+func TestMaxCodeLength(t *testing.T) {
+ br := bitReader{}
+ size := unsafe.Sizeof(br.bits)
+ size *= 8 // Convert from bytes to bits.
+
+ // Check that the size of the bitReader.bits field is large enough to hold
+ // nextBitMaxNBits bits.
+ if size < nextBitMaxNBits {
+ t.Fatalf("size: got %d, want >= %d", size, nextBitMaxNBits)
+ }
+
+ // Check that bitReader.nextBit will always leave enough spare bits in the
+ // bitReader.bits field such that the decode function can unread up to
+ // maxCodeLength bits.
+ if want := size - nextBitMaxNBits; maxCodeLength > want {
+ t.Fatalf("maxCodeLength: got %d, want <= %d", maxCodeLength, want)
+ }
+
+ // The decode function also assumes that, when saving bits to possibly
+ // unread later, those bits fit inside a uint32.
+ if maxCodeLength > 32 {
+ t.Fatalf("maxCodeLength: got %d, want <= %d", maxCodeLength, 32)
+ }
+}
+
func testTable(t *testing.T, table [][2]int16, codes []code, values []uint32) {
// Build a map from values to codes.
m := map[uint32]string{}
@@ -92,4 +119,48 @@
})
}
+func TestInvalidCode(t *testing.T) {
+ // The bit stream is:
+ // 1 010 000000011011
+ // Packing that LSB-first gives:
+ // 0b_1101_1000_0000_0101
+ src := []byte{0x05, 0xD8}
+
+ table := modeTable[:]
+ r := &bitReader{
+ r: bytes.NewReader(src),
+ }
+
+ // "1" decodes to the value 2.
+ if v, err := decode(r, table); v != 2 || err != nil {
+ t.Fatalf("decode #0: got (%v, %v), want (2, nil)", v, err)
+ }
+
+ // "010" decodes to the value 6.
+ if v, err := decode(r, table); v != 6 || err != nil {
+ t.Fatalf("decode #0: got (%v, %v), want (6, nil)", v, err)
+ }
+
+ // "00000001" is an invalid code.
+ if v, err := decode(r, table); v != 0 || err != errInvalidCode {
+ t.Fatalf("decode #0: got (%v, %v), want (0, %v)", v, err, errInvalidCode)
+ }
+
+ // The bitReader should not have advanced after encountering an invalid
+ // code. The remaining bits should be "000000011011".
+ remaining := []byte(nil)
+ for {
+ bit, err := r.nextBit()
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ t.Fatalf("nextBit: %v", err)
+ }
+ remaining = append(remaining, uint8('0'+bit))
+ }
+ if got, want := string(remaining), "000000011011"; got != want {
+ t.Fatalf("remaining bits: got %q, want %q", got, want)
+ }
+}
+
// TODO: more tests.
diff --git a/ccitt/table.go b/ccitt/table.go
index 266dfcf..6a8cfe1 100644
--- a/ccitt/table.go
+++ b/ccitt/table.go
@@ -719,6 +719,8 @@
104: {^1152, ^1216},
}
+const maxCodeLength = 13
+
// COPY PASTE table.go BEGIN
const (