blob: a1b94510b57d678e1c82e4acc0930574b564075d [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ccitt
import (
	"bytes"
	"io"
	"reflect"
	"testing"
	"unsafe"
)
func TestMaxCodeLength(t *testing.T) {
br := bitReader{}
size := unsafe.Sizeof(br.bits)
size *= 8 // Convert from bytes to bits.
// Check that the size of the bitReader.bits field is large enough to hold
// nextBitMaxNBits bits.
if size < nextBitMaxNBits {
t.Fatalf("size: got %d, want >= %d", size, nextBitMaxNBits)
// Check that bitReader.nextBit will always leave enough spare bits in the
// bitReader.bits field such that the decode function can unread up to
// maxCodeLength bits.
if want := size - nextBitMaxNBits; maxCodeLength > want {
t.Fatalf("maxCodeLength: got %d, want <= %d", maxCodeLength, want)
// The decode function also assumes that, when saving bits to possibly
// unread later, those bits fit inside a uint32.
if maxCodeLength > 32 {
t.Fatalf("maxCodeLength: got %d, want <= %d", maxCodeLength, 32)
func testDecodeTable(t *testing.T, decodeTable [][2]int16, codes []code, values []uint32) {
// Build a map from values to codes.
m := map[uint32]string{}
for _, code := range codes {
m[code.val] = code.str
// Build the encoded form of those values.
enc := []byte(nil)
bits := uint8(0)
nBits := uint32(0)
for _, v := range values {
code := m[v]
if code == "" {
panic("unmapped code")
for _, c := range code {
bits |= uint8(c&1) << nBits
if nBits == 8 {
enc = append(enc, bits)
bits = 0
nBits = 0
if nBits > 0 {
enc = append(enc, bits)
// Decode that encoded form.
got := []uint32(nil)
r := &bitReader{
r: bytes.NewReader(enc),
finalValue := values[len(values)-1]
for {
v, err := decode(r, decodeTable)
if err != nil {
t.Fatalf("after got=%d: %v", got, err)
got = append(got, v)
if v == finalValue {
// Check that the round-tripped values were unchanged.
if !reflect.DeepEqual(got, values) {
t.Fatalf("\ngot: %v\nwant: %v", got, values)
// TestModeDecodeTable round-trips a sequence of values through
// modeDecodeTable, encoding them with modeCodes, via testDecodeTable.
func TestModeDecodeTable(t *testing.T) {
	testDecodeTable(t, modeDecodeTable[:], modeCodes, []uint32{
func TestWhiteDecodeTable(t *testing.T) {
testDecodeTable(t, whiteDecodeTable[:], whiteCodes, []uint32{
0, 1, 256, 7, 128, 3, 2560,
func TestBlackDecodeTable(t *testing.T) {
testDecodeTable(t, blackDecodeTable[:], blackCodes, []uint32{
63, 64, 63, 64, 64, 63, 22, 1088, 2048, 7, 6, 5, 4, 3, 2, 1, 0,
func TestDecodeInvalidCode(t *testing.T) {
// The bit stream is:
// 1 010 000000011011
// Packing that LSB-first gives:
// 0b_1101_1000_0000_0101
src := []byte{0x05, 0xD8}
decodeTable := modeDecodeTable[:]
r := &bitReader{
r: bytes.NewReader(src),
// "1" decodes to the value 2.
if v, err := decode(r, decodeTable); v != 2 || err != nil {
t.Fatalf("decode #0: got (%v, %v), want (2, nil)", v, err)
// "010" decodes to the value 6.
if v, err := decode(r, decodeTable); v != 6 || err != nil {
t.Fatalf("decode #0: got (%v, %v), want (6, nil)", v, err)
// "00000001" is an invalid code.
if v, err := decode(r, decodeTable); v != 0 || err != errInvalidCode {
t.Fatalf("decode #0: got (%v, %v), want (0, %v)", v, err, errInvalidCode)
// The bitReader should not have advanced after encountering an invalid
// code. The remaining bits should be "000000011011".
remaining := []byte(nil)
for {
bit, err := r.nextBit()
if err == io.EOF {
} else if err != nil {
t.Fatalf("nextBit: %v", err)
remaining = append(remaining, uint8('0'+bit))
if got, want := string(remaining), "000000011011"; got != want {
t.Fatalf("remaining bits: got %q, want %q", got, want)
// TODO: more tests.