blob: 0d03a7317e276a98affb02810a79d8cfb6f30ecb [file] [log] [blame]
Nigel Tao8bf58722009-12-17 10:32:17 +11001// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package jpeg
6
7import (
8 "io"
9 "os"
10)
11
// Limits on the size of a Huffman table.
const (
	// maxCodeLength is the maximum length, in bits, of a single Huffman code.
	maxCodeLength = 16
	// maxNumValues bounds the number of values in one table: every decoded
	// value is a uint8, so there can be at most 256 of them.
	maxNumValues = 256
)
17
// bits is the bit stream state for the Huffman decoder.
//
// a is the accumulator: its n least significant bits are the unread bits,
// to be consumed in MSB to LSB order. m is a mask selecting the next bit to
// read; the invariant is m==1<<(n-1) when n>0, and m==0 when n==0.
type bits struct {
	a, n, m int // accumulator, unread-bit count, and next-bit mask.
}
25
26// Huffman table decoder, specified in section C.
27type huffman struct {
28 l [maxCodeLength]int
29 length int // sum of l[i].
30 val [maxNumValues]uint8 // the decoded values, as sorted by their encoding.
31 size [maxNumValues]int // size[i] is the number of bits to encode val[i].
32 code [maxNumValues]int // code[i] is the encoding of val[i].
33 minCode [maxCodeLength]int // min codes of length i, or -1 if no codes of that length.
34 maxCode [maxCodeLength]int // max codes of length i, or -1 if no codes of that length.
35 valIndex [maxCodeLength]int // index into val of minCode[i].
36}
37
38// Reads bytes from the io.Reader to ensure that bits.n is at least n.
39func (d *decoder) ensureNBits(n int) os.Error {
40 for d.b.n < n {
41 c, err := d.r.ReadByte()
42 if err != nil {
43 return err
44 }
45 d.b.a = d.b.a<<8 | int(c)
46 d.b.n += 8
47 if d.b.m == 0 {
48 d.b.m = 1 << 7
49 } else {
50 d.b.m <<= 8
51 }
52 // Byte stuffing, specified in section F.1.2.3.
53 if c == 0xff {
54 c, err = d.r.ReadByte()
55 if err != nil {
56 return err
57 }
58 if c != 0x00 {
59 return FormatError("missing 0xff00 sequence")
60 }
61 }
62 }
63 return nil
64}
65
66// The composition of RECEIVE and EXTEND, specified in section F.2.2.1.
67func (d *decoder) receiveExtend(t uint8) (int, os.Error) {
68 err := d.ensureNBits(int(t))
69 if err != nil {
70 return 0, err
71 }
72 d.b.n -= int(t)
73 d.b.m >>= t
74 s := 1 << t
75 x := (d.b.a >> uint8(d.b.n)) & (s - 1)
76 if x < s>>1 {
77 x += ((-1) << t) + 1
78 }
79 return x, nil
80}
81
82// Processes a Define Huffman Table marker, and initializes a huffman struct from its contents.
83// Specified in section B.2.4.2.
84func (d *decoder) processDHT(n int) os.Error {
85 for n > 0 {
86 if n < 17 {
87 return FormatError("DHT has wrong length")
88 }
89 _, err := io.ReadFull(d.r, d.tmp[0:17])
90 if err != nil {
91 return err
92 }
93 tc := d.tmp[0] >> 4
94 if tc > maxTc {
95 return FormatError("bad Tc value")
96 }
97 th := d.tmp[0] & 0x0f
98 const isBaseline = true // Progressive mode is not yet supported.
99 if th > maxTh || isBaseline && th > 1 {
100 return FormatError("bad Th value")
101 }
102 h := &d.huff[tc][th]
103
104 // Read l and val (and derive length).
105 h.length = 0
106 for i := 0; i < maxCodeLength; i++ {
107 h.l[i] = int(d.tmp[i+1])
108 h.length += h.l[i]
109 }
110 if h.length == 0 {
111 return FormatError("Huffman table has zero length")
112 }
113 if h.length > maxNumValues {
114 return FormatError("Huffman table has excessive length")
115 }
116 n -= h.length + 17
117 if n < 0 {
118 return FormatError("DHT has wrong length")
119 }
120 _, err = io.ReadFull(d.r, h.val[0:h.length])
121 if err != nil {
122 return err
123 }
124
125 // Derive size.
126 k := 0
127 for i := 0; i < maxCodeLength; i++ {
128 for j := 0; j < h.l[i]; j++ {
129 h.size[k] = i + 1
130 k++
131 }
132 }
133
134 // Derive code.
135 code := 0
136 size := h.size[0]
137 for i := 0; i < h.length; i++ {
138 if size != h.size[i] {
139 code <<= uint8(h.size[i] - size)
140 size = h.size[i]
141 }
142 h.code[i] = code
143 code++
144 }
145
146 // Derive minCode, maxCode, and valIndex.
147 k = 0
148 index := 0
149 for i := 0; i < maxCodeLength; i++ {
150 if h.l[i] == 0 {
151 h.minCode[i] = -1
152 h.maxCode[i] = -1
153 h.valIndex[i] = -1
154 } else {
155 h.minCode[i] = k
156 h.maxCode[i] = k + h.l[i] - 1
157 h.valIndex[i] = index
158 k += h.l[i]
159 index += h.l[i]
160 }
161 k <<= 1
162 }
163 }
164 return nil
165}
166
167// Returns the next Huffman-coded value from the bit stream, decoded according to h.
168// TODO(nigeltao): This decoding algorithm is simple, but slow. A lookahead table, instead of always
169// peeling off only 1 bit at at time, ought to be faster.
170func (d *decoder) decodeHuffman(h *huffman) (uint8, os.Error) {
171 if h.length == 0 {
172 return 0, FormatError("uninitialized Huffman table")
173 }
174 for i, code := 0, 0; i < maxCodeLength; i++ {
175 err := d.ensureNBits(1)
176 if err != nil {
177 return 0, err
178 }
179 if d.b.a&d.b.m != 0 {
180 code |= 1
181 }
182 d.b.n--
183 d.b.m >>= 1
184 if code <= h.maxCode[i] {
185 return h.val[h.valIndex[i]+code-h.minCode[i]], nil
186 }
187 code <<= 1
188 }
189 return 0, FormatError("bad Huffman code")
190}