| // Copyright 2023 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package zstd |
| |
| import ( |
| "io" |
| ) |
| |
| // debug can be set in the source to print debug info using println. |
| const debug = false |
| |
| // compressedBlock decompresses a compressed block, storing the decompressed |
| // data in r.buffer. The blockSize argument is the compressed size. |
| // RFC 3.1.1.3. |
| func (r *Reader) compressedBlock(blockSize int) error { |
| if len(r.compressedBuf) >= blockSize { |
| r.compressedBuf = r.compressedBuf[:blockSize] |
| } else { |
| // We know that blockSize <= 128K, |
| // so this won't allocate an enormous amount. |
| need := blockSize - len(r.compressedBuf) |
| r.compressedBuf = append(r.compressedBuf, make([]byte, need)...) |
| } |
| |
| if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil { |
| return r.wrapNonEOFError(0, err) |
| } |
| |
| data := block(r.compressedBuf) |
| off := 0 |
| r.buffer = r.buffer[:0] |
| |
| litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0]) |
| if err != nil { |
| return err |
| } |
| r.literals = litbuf |
| |
| off = litoff |
| |
| seqCount, off, err := r.initSeqs(data, off) |
| if err != nil { |
| return err |
| } |
| |
| if seqCount == 0 { |
| // No sequences, just literals. |
| if off < len(data) { |
| return r.makeError(off, "extraneous data after no sequences") |
| } |
| |
| r.buffer = append(r.buffer, litbuf...) |
| |
| return nil |
| } |
| |
| return r.execSeqs(data, off, litbuf, seqCount) |
| } |
| |
| // seqCode is the kind of sequence codes we have to handle. |
| type seqCode int |
| |
| const ( |
| seqLiteral seqCode = iota |
| seqOffset |
| seqMatch |
| ) |
| |
| // seqCodeInfoData is the information needed to set up seqTables and |
| // seqTableBits for a particular kind of sequence code. |
| type seqCodeInfoData struct { |
| predefTable []fseBaselineEntry // predefined FSE |
| predefTableBits int // number of bits in predefTable |
| maxSym int // max symbol value in FSE |
| maxBits int // max bits for FSE |
| |
| // toBaseline converts from an FSE table to an FSE baseline table. |
| toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error |
| } |
| |
| // seqCodeInfo is the seqCodeInfoData for each kind of sequence code. |
| var seqCodeInfo = [3]seqCodeInfoData{ |
| seqLiteral: { |
| predefTable: predefinedLiteralTable[:], |
| predefTableBits: 6, |
| maxSym: 35, |
| maxBits: 9, |
| toBaseline: (*Reader).makeLiteralBaselineFSE, |
| }, |
| seqOffset: { |
| predefTable: predefinedOffsetTable[:], |
| predefTableBits: 5, |
| maxSym: 31, |
| maxBits: 8, |
| toBaseline: (*Reader).makeOffsetBaselineFSE, |
| }, |
| seqMatch: { |
| predefTable: predefinedMatchTable[:], |
| predefTableBits: 6, |
| maxSym: 52, |
| maxBits: 9, |
| toBaseline: (*Reader).makeMatchBaselineFSE, |
| }, |
| } |
| |
| // initSeqs reads the Sequences_Section_Header and sets up the FSE |
| // tables used to read the sequence codes. It returns the number of |
| // sequences and the new offset. RFC 3.1.1.3.2.1. |
| func (r *Reader) initSeqs(data block, off int) (int, int, error) { |
| if off >= len(data) { |
| return 0, 0, r.makeEOFError(off) |
| } |
| |
| seqHdr := data[off] |
| off++ |
| if seqHdr == 0 { |
| return 0, off, nil |
| } |
| |
| var seqCount int |
| if seqHdr < 128 { |
| seqCount = int(seqHdr) |
| } else if seqHdr < 255 { |
| if off >= len(data) { |
| return 0, 0, r.makeEOFError(off) |
| } |
| seqCount = ((int(seqHdr) - 128) << 8) + int(data[off]) |
| off++ |
| } else { |
| if off+1 >= len(data) { |
| return 0, 0, r.makeEOFError(off) |
| } |
| seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00 |
| off += 2 |
| } |
| |
| // Read the Symbol_Compression_Modes byte. |
| |
| if off >= len(data) { |
| return 0, 0, r.makeEOFError(off) |
| } |
| symMode := data[off] |
| if symMode&3 != 0 { |
| return 0, 0, r.makeError(off, "invalid symbol compression mode") |
| } |
| off++ |
| |
| // Set up the FSE tables used to decode the sequence codes. |
| |
| var err error |
| off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3) |
| if err != nil { |
| return 0, 0, err |
| } |
| |
| off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3) |
| if err != nil { |
| return 0, 0, err |
| } |
| |
| off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3) |
| if err != nil { |
| return 0, 0, err |
| } |
| |
| return seqCount, off, nil |
| } |
| |
| // setSeqTable uses the Compression_Mode in mode to set up r.seqTables and |
| // r.seqTableBits for kind. We store these in the Reader because one of |
| // the modes simply reuses the value from the last block in the frame. |
| func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) { |
| info := &seqCodeInfo[kind] |
| switch mode { |
| case 0: |
| // Predefined_Mode |
| r.seqTables[kind] = info.predefTable |
| r.seqTableBits[kind] = uint8(info.predefTableBits) |
| return off, nil |
| |
| case 1: |
| // RLE_Mode |
| if off >= len(data) { |
| return 0, r.makeEOFError(off) |
| } |
| rle := data[off] |
| off++ |
| |
| // Build a simple baseline table that always returns rle. |
| |
| entry := []fseEntry{ |
| { |
| sym: rle, |
| bits: 0, |
| base: 0, |
| }, |
| } |
| if cap(r.seqTableBuffers[kind]) == 0 { |
| r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits) |
| } |
| r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1] |
| if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil { |
| return 0, err |
| } |
| |
| r.seqTables[kind] = r.seqTableBuffers[kind] |
| r.seqTableBits[kind] = 0 |
| return off, nil |
| |
| case 2: |
| // FSE_Compressed_Mode |
| if cap(r.fseScratch) < 1<<info.maxBits { |
| r.fseScratch = make([]fseEntry, 1<<info.maxBits) |
| } |
| r.fseScratch = r.fseScratch[:1<<info.maxBits] |
| |
| tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch) |
| if err != nil { |
| return 0, err |
| } |
| r.fseScratch = r.fseScratch[:1<<tableBits] |
| |
| if cap(r.seqTableBuffers[kind]) == 0 { |
| r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits) |
| } |
| r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits] |
| |
| if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil { |
| return 0, err |
| } |
| |
| r.seqTables[kind] = r.seqTableBuffers[kind] |
| r.seqTableBits[kind] = uint8(tableBits) |
| return roff, nil |
| |
| case 3: |
| // Repeat_Mode |
| if len(r.seqTables[kind]) == 0 { |
| return 0, r.makeError(off, "missing repeat sequence FSE table") |
| } |
| return off, nil |
| } |
| panic("unreachable") |
| } |
| |
| // execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2. |
| func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error { |
| // Set up the initial states for the sequence code readers. |
| |
| rbr, err := r.makeReverseBitReader(data, len(data)-1, off) |
| if err != nil { |
| return err |
| } |
| |
| literalState, err := rbr.val(r.seqTableBits[seqLiteral]) |
| if err != nil { |
| return err |
| } |
| |
| offsetState, err := rbr.val(r.seqTableBits[seqOffset]) |
| if err != nil { |
| return err |
| } |
| |
| matchState, err := rbr.val(r.seqTableBits[seqMatch]) |
| if err != nil { |
| return err |
| } |
| |
| // Read and perform all the sequences. RFC 3.1.1.4. |
| |
| seq := 0 |
| for seq < seqCount { |
| if len(r.buffer)+len(litbuf) > 128<<10 { |
| return rbr.makeError("uncompressed size too big") |
| } |
| |
| ptoffset := &r.seqTables[seqOffset][offsetState] |
| ptmatch := &r.seqTables[seqMatch][matchState] |
| ptliteral := &r.seqTables[seqLiteral][literalState] |
| |
| add, err := rbr.val(ptoffset.basebits) |
| if err != nil { |
| return err |
| } |
| offset := ptoffset.baseline + add |
| |
| add, err = rbr.val(ptmatch.basebits) |
| if err != nil { |
| return err |
| } |
| match := ptmatch.baseline + add |
| |
| add, err = rbr.val(ptliteral.basebits) |
| if err != nil { |
| return err |
| } |
| literal := ptliteral.baseline + add |
| |
| // Handle repeat offsets. RFC 3.1.1.5. |
| // See the comment in makeOffsetBaselineFSE. |
| if ptoffset.basebits > 1 { |
| r.repeatedOffset3 = r.repeatedOffset2 |
| r.repeatedOffset2 = r.repeatedOffset1 |
| r.repeatedOffset1 = offset |
| } else { |
| if literal == 0 { |
| offset++ |
| } |
| switch offset { |
| case 1: |
| offset = r.repeatedOffset1 |
| case 2: |
| offset = r.repeatedOffset2 |
| r.repeatedOffset2 = r.repeatedOffset1 |
| r.repeatedOffset1 = offset |
| case 3: |
| offset = r.repeatedOffset3 |
| r.repeatedOffset3 = r.repeatedOffset2 |
| r.repeatedOffset2 = r.repeatedOffset1 |
| r.repeatedOffset1 = offset |
| case 4: |
| offset = r.repeatedOffset1 - 1 |
| r.repeatedOffset3 = r.repeatedOffset2 |
| r.repeatedOffset2 = r.repeatedOffset1 |
| r.repeatedOffset1 = offset |
| } |
| } |
| |
| seq++ |
| if seq < seqCount { |
| // Update the states. |
| add, err = rbr.val(ptliteral.bits) |
| if err != nil { |
| return err |
| } |
| literalState = uint32(ptliteral.base) + add |
| |
| add, err = rbr.val(ptmatch.bits) |
| if err != nil { |
| return err |
| } |
| matchState = uint32(ptmatch.base) + add |
| |
| add, err = rbr.val(ptoffset.bits) |
| if err != nil { |
| return err |
| } |
| offsetState = uint32(ptoffset.base) + add |
| } |
| |
| // The next sequence is now in literal, offset, match. |
| |
| if debug { |
| println("literal", literal, "offset", offset, "match", match) |
| } |
| |
| // Copy literal bytes from litbuf. |
| if literal > uint32(len(litbuf)) { |
| return rbr.makeError("literal byte overflow") |
| } |
| if literal > 0 { |
| r.buffer = append(r.buffer, litbuf[:literal]...) |
| litbuf = litbuf[literal:] |
| } |
| |
| if match > 0 { |
| if err := r.copyFromWindow(&rbr, offset, match); err != nil { |
| return err |
| } |
| } |
| } |
| |
| r.buffer = append(r.buffer, litbuf...) |
| |
| if rbr.cnt != 0 { |
| return r.makeError(off, "extraneous data after sequences") |
| } |
| |
| return nil |
| } |
| |
| // Copy match bytes from the decoded output, or the window, at offset. |
| func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error { |
| if offset == 0 { |
| return rbr.makeError("invalid zero offset") |
| } |
| |
| // Offset may point into the buffer or the window and |
| // match may extend past the end of the initial buffer. |
| // |--r.window--|--r.buffer--| |
| // |<-----offset------| |
| // |------match----------->| |
| bufferOffset := uint32(0) |
| lenBlock := uint32(len(r.buffer)) |
| if lenBlock < offset { |
| lenWindow := r.window.len() |
| copy := offset - lenBlock |
| if copy > lenWindow { |
| return rbr.makeError("offset past window") |
| } |
| windowOffset := lenWindow - copy |
| if copy > match { |
| copy = match |
| } |
| r.buffer = r.window.appendTo(r.buffer, windowOffset, windowOffset+copy) |
| match -= copy |
| } else { |
| bufferOffset = lenBlock - offset |
| } |
| |
| // We are being asked to copy data that we are adding to the |
| // buffer in the same copy. |
| for match > 0 { |
| copy := uint32(len(r.buffer)) - bufferOffset |
| if copy > match { |
| copy = match |
| } |
| r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+copy]...) |
| match -= copy |
| } |
| return nil |
| } |