blob: 452e24b760c3e5276a5fdf013a938ab3522e342c [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"io"
"math/bits"
)
// maxHuffmanBits is the largest number of index bits a Huffman table
// may use; readHuff rejects weight sets that would need a wider table.
const maxHuffmanBits = 11
// readHuff decodes a Huffman table description from data, beginning at
// off, and expands it into table.
//
// Every table entry is a pair of bytes: the high byte is the decoded
// value and the low byte is the number of bits used to encode that
// value. The table is indexed with a value that is tableBits wide, so a
// value that requires fewer bits appears in the table multiple times.
//
// It returns the number of bits in the Huffman table and the new offset.
// RFC 4.2.1.
func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) {
	if off >= len(data) {
		return 0, 0, r.makeEOFError(off)
	}

	header := data[off]
	off++

	var (
		weights [256]uint8
		count   int
	)

	if header >= 128 {
		// The table is not compressed. Weights are packed two per
		// byte as 4-bit values, high nibble first.
		count = int(header) - 127
		packed := (count + 1) / 2
		// The comparison is >=, so at least one byte must remain in
		// data after the packed weights.
		if off+packed >= len(data) {
			return 0, 0, io.ErrUnexpectedEOF
		}
		// When count is odd the final low nibble lands in
		// weights[count]; it is ignored by the validation below and
		// then overwritten by the implied last weight.
		for i := 0; i < packed; i++ {
			pair := data[off+i]
			weights[2*i] = pair >> 4
			weights[2*i+1] = pair & 0xf
		}
		off += packed
	} else {
		// The table is compressed using an FSE. RFC 4.2.1.2.
		// The header byte is the size in bytes of the compressed weights.
		if len(r.fseScratch) < 1<<6 {
			r.fseScratch = make([]fseEntry, 1<<6)
		}
		fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch)
		if err != nil {
			return 0, 0, err
		}
		fseTable := r.fseScratch

		end := off + int(header)
		if end > len(data) {
			return 0, 0, r.makeEOFError(off)
		}

		rbr, err := r.makeReverseBitReader(data, end-1, noff)
		if err != nil {
			return 0, 0, err
		}

		first, err := rbr.val(uint8(fseBits))
		if err != nil {
			return 0, 0, err
		}
		second, err := rbr.val(uint8(fseBits))
		if err != nil {
			return 0, 0, err
		}

		// There are two independent FSE streams, decoded
		// alternately; cur selects the state whose turn it is.
		states := [2]uint32{first, second}
		for cur := 0; ; cur ^= 1 {
			entry := &fseTable[states[cur]]
			if !rbr.fetch(entry.bits) {
				// fetch failed (presumably the bit stream has run
				// out): emit the symbol of the current state
				// followed by the symbol of the other state, then
				// stop.
				if count >= 254 {
					return 0, 0, rbr.makeError("Huffman count overflow")
				}
				weights[count] = entry.sym
				weights[count+1] = fseTable[states[cur^1]].sym
				count += 2
				break
			}

			v, err := rbr.val(entry.bits)
			if err != nil {
				return 0, 0, err
			}
			states[cur] = uint32(entry.base) + v

			if count >= 255 {
				return 0, 0, rbr.makeError("Huffman count overflow")
			}
			weights[count] = entry.sym
			count++
		}

		off = end
	}

	// RFC 4.2.1.3.

	// perWeight[w] first counts how many symbols carry weight w.
	// weightSum accumulates 1<<(w-1) for every non-zero weight.
	var perWeight [13]uint32
	var weightSum uint32
	for _, w := range weights[:count] {
		if w > 12 {
			return 0, 0, r.makeError(off, "Huffman weight overflow")
		}
		perWeight[w]++
		if w == 0 {
			continue
		}
		weightSum += 1 << (w - 1)
	}
	if weightSum == 0 {
		return 0, 0, r.makeError(off, "bad Huffman weights")
	}

	tableBits = bits.Len32(weightSum)
	if tableBits > maxHuffmanBits {
		return 0, 0, r.makeError(off, "bad Huffman weights")
	}

	if len(table) < 1<<tableBits {
		return 0, 0, r.makeError(off, "Huffman table too small")
	}

	// Work out the last weight value, which is omitted because
	// the weights must sum to a power of two. The remainder must
	// itself be a single power of two.
	left := (uint32(1) << tableBits) - weightSum
	if left == 0 || left&(left-1) != 0 {
		return 0, 0, r.makeError(off, "bad Huffman weights")
	}
	if count >= 256 {
		return 0, 0, r.makeError(off, "Huffman weight overflow")
	}
	last := bits.Len32(left)
	weights[count] = uint8(last)
	count++
	perWeight[last]++

	// The number of weight-1 symbols must be even and at least two.
	if n := perWeight[1]; n < 2 || n%2 != 0 {
		return 0, 0, r.makeError(off, "bad Huffman weights")
	}

	// Change perWeight from a count of weights to the index of the
	// first table slot for that weight. The fill loop below advances
	// each index as slots are consumed.
	var slot uint32
	for w := 1; w <= tableBits; w++ {
		span := perWeight[w] << (w - 1)
		perWeight[w] = slot
		slot += span
	}

	for sym, w := range weights[:count] {
		if w == 0 {
			continue
		}
		// A symbol of weight w fills 1<<(w-1) consecutive slots and
		// is encoded in tableBits+1-w bits.
		n := uint32(1) << (w - 1)
		codeLen := uint16(tableBits) + 1 - uint16(w)
		entry := uint16(sym)<<8 | codeLen
		lo := perWeight[w]
		hi := lo + n
		for k := lo; k < hi; k++ {
			table[k] = entry
		}
		perWeight[w] = hi
	}

	return tableBits, off, nil
}