| // Copyright 2023 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Package zstd provides a decompressor for zstd streams, |
| // described in RFC 8878. It does not support dictionaries. |
| package zstd |
| |
| import ( |
| "encoding/binary" |
| "errors" |
| "fmt" |
| "io" |
| ) |
| |
// fuzzing is a hook for the fuzz tests, which set it to true. While it is
// set, the decoder additionally rejects inputs that the reference zstd
// implementation refuses by default (such as oversized windows), so that
// fuzz results can be compared against zstd.
var fuzzing bool
| |
// Reader implements [io.Reader] to read a zstd compressed stream.
// Create one with [NewReader]; [Reader.Reset] lets an existing Reader be
// reused for a new stream while keeping its internal buffers.
type Reader struct {
	// The underlying Reader that supplies the compressed input.
	r io.Reader

	// Whether we have read the frame header of the current frame.
	// This is of interest when buffer is empty.
	// If true we expect to see a new block; if false the next refill
	// reads a frame header first.
	sawFrameHeader bool

	// Whether the current frame ends with a 4-byte content checksum.
	hasChecksum bool

	// Whether we have read at least one frame (skippable frames count).
	// A clean EOF before any frame has been read is reported as
	// io.ErrUnexpectedEOF rather than io.EOF.
	readOneFrame bool

	// True if the frame header did not include a Frame_Content_Size.
	frameSizeUnknown bool

	// The number of uncompressed bytes remaining in the current frame.
	// If frameSizeUnknown is true, this is not valid.
	remainingFrameSize uint64

	// The number of bytes read from r up to the start of the current
	// block, for error reporting.
	blockOffset int64

	// Buffered decompressed data from the most recently decoded block.
	buffer []byte
	// Current read offset in buffer; bytes before off have already been
	// returned to the caller.
	off int

	// The current repeated offsets.
	// readFrameHeader resets them to 1, 4 and 8 at the start of each frame.
	repeatedOffset1 uint32
	repeatedOffset2 uint32
	repeatedOffset3 uint32

	// The current Huffman table used to decode Huffman-compressed literals.
	// readFrameHeader resets huffmanTableBits to 0 at the start of each
	// frame; presumably 0 means "no table yet" — confirm against its users.
	huffmanTable     []uint16
	huffmanTableBits int

	// The window for back references. readFrameHeader resets it at the
	// start of each frame with the frame's window size, capped at 8 MiB.
	window window

	// A buffer available to hold a compressed block.
	compressedBuf []byte

	// A buffer for literals.
	literals []byte

	// Sequence decode FSE tables.
	// readFrameHeader sets each table to nil at the start of each frame.
	seqTables    [3][]fseBaselineEntry
	seqTableBits [3]uint8

	// Buffers for sequence decode FSE tables.
	seqTableBuffers [3][]fseBaselineEntry

	// Scratch space used for small reads, to avoid allocation.
	// The largest use is the frame header after its descriptor byte:
	// at most 1 (window) + 4 (dictionary ID) + 8 (content size) = 13 bytes.
	scratch [16]byte

	// A scratch table for reading an FSE. Only temporarily valid.
	fseScratch []fseEntry

	// Running checksum of the current frame's decompressed data.
	// It is reset and updated only when hasChecksum is set.
	checksum xxhash64
}
| |
| // NewReader creates a new Reader that decompresses data from the given reader. |
| func NewReader(input io.Reader) *Reader { |
| r := new(Reader) |
| r.Reset(input) |
| return r |
| } |
| |
| // Reset discards the current state and starts reading a new stream from r. |
| // This permits reusing a Reader rather than allocating a new one. |
| func (r *Reader) Reset(input io.Reader) { |
| r.r = input |
| |
| // Several fields are preserved to avoid allocation. |
| // Others are always set before they are used. |
| r.sawFrameHeader = false |
| r.hasChecksum = false |
| r.readOneFrame = false |
| r.frameSizeUnknown = false |
| r.remainingFrameSize = 0 |
| r.blockOffset = 0 |
| r.buffer = r.buffer[:0] |
| r.off = 0 |
| // repeatedOffset1 |
| // repeatedOffset2 |
| // repeatedOffset3 |
| // huffmanTable |
| // huffmanTableBits |
| // window |
| // compressedBuf |
| // literals |
| // seqTables |
| // seqTableBits |
| // seqTableBuffers |
| // scratch |
| // fseScratch |
| } |
| |
| // Read implements [io.Reader]. |
| func (r *Reader) Read(p []byte) (int, error) { |
| if err := r.refillIfNeeded(); err != nil { |
| return 0, err |
| } |
| n := copy(p, r.buffer[r.off:]) |
| r.off += n |
| return n, nil |
| } |
| |
| // ReadByte implements [io.ByteReader]. |
| func (r *Reader) ReadByte() (byte, error) { |
| if err := r.refillIfNeeded(); err != nil { |
| return 0, err |
| } |
| ret := r.buffer[r.off] |
| r.off++ |
| return ret, nil |
| } |
| |
| // refillIfNeeded reads the next block if necessary. |
| func (r *Reader) refillIfNeeded() error { |
| for r.off >= len(r.buffer) { |
| if err := r.refill(); err != nil { |
| return err |
| } |
| r.off = 0 |
| } |
| return nil |
| } |
| |
| // refill reads and decompresses the next block. |
| func (r *Reader) refill() error { |
| if !r.sawFrameHeader { |
| if err := r.readFrameHeader(); err != nil { |
| return err |
| } |
| } |
| return r.readBlock() |
| } |
| |
// readFrameHeader reads the frame header and prepares to read a block.
// RFC 3.1.1.
//
// Any skippable frames (RFC 3.1.2) that precede the header are consumed
// with skipFrame, and they count as frames for EOF handling. If the input
// ends cleanly before a magic number, the result is io.EOF when at least
// one frame has been read and io.ErrUnexpectedEOF otherwise. Nonzero
// dictionary IDs are rejected because dictionaries are not supported.
//
// On success it records the frame's checksum flag and content size. It
// resets the per-frame decoding state: repeated offsets, Huffman table
// bits, sequence tables and window. It also advances r.blockOffset past
// the header.
func (r *Reader) readFrameHeader() error {
retry:
	relativeOffset := 0

	// Read magic number. RFC 3.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
		// We require that the stream contains at least one frame.
		// io.ReadFull returns io.EOF only when no bytes were read, and
		// wrapError passes io.EOF through unwrapped.
		if err == io.EOF && !r.readOneFrame {
			err = io.ErrUnexpectedEOF
		}
		return r.wrapError(relativeOffset, err)
	}

	if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
		if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
			// This is a skippable frame. Account for its magic number,
			// skip its contents, then look for the next frame.
			r.blockOffset += int64(relativeOffset) + 4
			if err := r.skipFrame(); err != nil {
				return err
			}
			r.readOneFrame = true
			goto retry
		}

		return r.makeError(relativeOffset, "invalid magic number")
	}

	relativeOffset += 4

	// Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}
	descriptor := r.scratch[0]

	// Single_Segment_Flag is bit 5.
	singleSegment := descriptor&(1<<5) != 0

	// Frame_Content_Size_Flag is the top two bits: flag values 0..3 select
	// a 1, 2, 4 or 8 byte field. Flag 0 means the field is absent unless
	// Single_Segment_Flag is set.
	fcsFieldSize := 1 << (descriptor >> 6)
	if fcsFieldSize == 1 && !singleSegment {
		fcsFieldSize = 0
	}

	// A single-segment frame has no Window_Descriptor byte.
	var windowDescriptorSize int
	if singleSegment {
		windowDescriptorSize = 0
	} else {
		windowDescriptorSize = 1
	}

	// Bit 3 is reserved and must be zero.
	if descriptor&(1<<3) != 0 {
		return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
	}

	// Content_Checksum_Flag is bit 2.
	r.hasChecksum = descriptor&(1<<2) != 0
	if r.hasChecksum {
		r.checksum.reset()
	}

	// Dictionary_ID_Flag. RFC 3.1.1.1.1.6.
	// The low two bits select a 0, 1, 2 or 4 byte Dictionary_ID field.
	dictionaryIdSize := 0
	if dictIdFlag := descriptor & 3; dictIdFlag != 0 {
		dictionaryIdSize = 1 << (dictIdFlag - 1)
	}

	relativeOffset++

	// The rest of the header, in order: Window_Descriptor, Dictionary_ID and
	// Frame_Content_Size. It is at most 1+4+8 = 13 bytes, so it fits in
	// r.scratch.
	headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize

	if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	// Figure out the maximum amount of data we need to retain
	// for backreferences.
	var windowSize uint64
	if !singleSegment {
		// Window descriptor. RFC 3.1.1.1.2.
		// The size is (1 << (10+exponent)) + mantissa eighths of that.
		windowDescriptor := r.scratch[0]
		exponent := uint64(windowDescriptor >> 3)
		mantissa := uint64(windowDescriptor & 7)
		windowLog := exponent + 10
		windowBase := uint64(1) << windowLog
		windowAdd := (windowBase / 8) * mantissa
		windowSize = windowBase + windowAdd

		// Default zstd sets limits on the window size.
		if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
			return r.makeError(relativeOffset, "windowSize too large")
		}
	}

	// Dictionary_ID. RFC 3.1.1.1.3.
	if dictionaryIdSize != 0 {
		dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize]
		// Allow only zero Dictionary ID.
		for _, b := range dictionaryId {
			if b != 0 {
				return r.makeError(relativeOffset, "dictionaries are not supported")
			}
		}
	}

	// Frame_Content_Size. RFC 3.1.1.1.4.
	// All sizes are little-endian; the 2-byte form is offset by 256.
	r.frameSizeUnknown = false
	r.remainingFrameSize = 0
	fb := r.scratch[windowDescriptorSize+dictionaryIdSize:]
	switch fcsFieldSize {
	case 0:
		r.frameSizeUnknown = true
	case 1:
		r.remainingFrameSize = uint64(fb[0])
	case 2:
		r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
	case 4:
		r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
	case 8:
		r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
	default:
		panic("unreachable")
	}

	// RFC 3.1.1.1.2.
	// When Single_Segment_Flag is set, Window_Descriptor is not present.
	// In this case, Window_Size is Frame_Content_Size.
	if singleSegment {
		windowSize = r.remainingFrameSize
	}

	// RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size.
	// The size is clamped rather than rejected.
	// NOTE(review): presumably back-references beyond the clamped window
	// are caught by the window code — confirm.
	const maxWindowSize = 8 << 20
	if windowSize > maxWindowSize {
		windowSize = maxWindowSize
	}

	relativeOffset += headerSize

	r.sawFrameHeader = true
	r.readOneFrame = true
	r.blockOffset += int64(relativeOffset)

	// Prepare to read blocks from the frame.
	// These are the initial repeated offsets for each frame.
	r.repeatedOffset1 = 1
	r.repeatedOffset2 = 4
	r.repeatedOffset3 = 8
	r.huffmanTableBits = 0
	// windowSize is at most maxWindowSize here, so the int conversion is safe.
	r.window.reset(int(windowSize))
	r.seqTables[0] = nil
	r.seqTables[1] = nil
	r.seqTables[2] = nil

	return nil
}
| |
| // skipFrame skips a skippable frame. RFC 3.1.2. |
| func (r *Reader) skipFrame() error { |
| relativeOffset := 0 |
| |
| if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { |
| return r.wrapNonEOFError(relativeOffset, err) |
| } |
| |
| relativeOffset += 4 |
| |
| size := binary.LittleEndian.Uint32(r.scratch[:4]) |
| if size == 0 { |
| r.blockOffset += int64(relativeOffset) |
| return nil |
| } |
| |
| if seeker, ok := r.r.(io.Seeker); ok { |
| r.blockOffset += int64(relativeOffset) |
| // Implementations of Seeker do not always detect invalid offsets, |
| // so check that the new offset is valid by comparing to the end. |
| prev, err := seeker.Seek(0, io.SeekCurrent) |
| if err != nil { |
| return r.wrapError(0, err) |
| } |
| end, err := seeker.Seek(0, io.SeekEnd) |
| if err != nil { |
| return r.wrapError(0, err) |
| } |
| if prev > end-int64(size) { |
| r.blockOffset += end - prev |
| return r.makeEOFError(0) |
| } |
| |
| // The new offset is valid, so seek to it. |
| _, err = seeker.Seek(prev+int64(size), io.SeekStart) |
| if err != nil { |
| return r.wrapError(0, err) |
| } |
| r.blockOffset += int64(size) |
| return nil |
| } |
| |
| var skip []byte |
| const chunk = 1 << 20 // 1M |
| for size >= chunk { |
| if len(skip) == 0 { |
| skip = make([]byte, chunk) |
| } |
| if _, err := io.ReadFull(r.r, skip); err != nil { |
| return r.wrapNonEOFError(relativeOffset, err) |
| } |
| relativeOffset += chunk |
| size -= chunk |
| } |
| if size > 0 { |
| if len(skip) == 0 { |
| skip = make([]byte, size) |
| } |
| if _, err := io.ReadFull(r.r, skip); err != nil { |
| return r.wrapNonEOFError(relativeOffset, err) |
| } |
| relativeOffset += int(size) |
| } |
| |
| r.blockOffset += int64(relativeOffset) |
| |
| return nil |
| } |
| |
// readBlock reads the next block from a frame and decodes it into
// r.buffer. RFC 3.1.1.2.
//
// It enforces the maximum block size (128K, and also r.window.size when
// that is positive). It tracks the remaining Frame_Content_Size and the
// running checksum, and saves every non-final block into the window for
// later back-references.
//
// After the last block of a frame it checks that exactly
// Frame_Content_Size bytes were produced (when the size is known). If the
// frame has a checksum, it reads the 4-byte trailer and compares it with
// the low 32 bits of the xxhash64 digest. It then clears r.sawFrameHeader
// so the next refill starts a new frame.
func (r *Reader) readBlock() error {
	relativeOffset := 0

	// Read Block_Header. RFC 3.1.1.2.
	if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	relativeOffset += 3

	// The header is a 3-byte little-endian value: bit 0 is Last_Block,
	// bits 1-2 are Block_Type, and the remaining 21 bits are Block_Size.
	header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)

	lastBlock := header&1 != 0
	blockType := (header >> 1) & 3
	blockSize := int(header >> 3)

	// Maximum block size is smaller of window size and 128K.
	// We don't record the window size for a single segment frame,
	// so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4.
	// NOTE(review): readFrameHeader does pass Frame_Content_Size to
	// r.window.reset for single-segment frames. Whether r.window.size
	// reflects it depends on window.reset — confirm and update this comment.
	if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) {
		return r.makeError(relativeOffset, "block size too large")
	}

	// Handle different block types. RFC 3.1.1.2.2.
	switch blockType {
	case 0:
		// Raw block: blockSize bytes copied verbatim.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.buffer); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset += blockSize
		r.blockOffset += int64(relativeOffset)
	case 1:
		// RLE block: a single byte repeated blockSize times.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset++
		v := r.scratch[0]
		for i := range r.buffer {
			r.buffer[i] = v
		}
		r.blockOffset += int64(relativeOffset)
	case 2:
		// Compressed block. Advance blockOffset first so that errors from
		// compressedBlock are positioned relative to the block content.
		// compressedBlock (defined elsewhere) presumably leaves the decoded
		// data in r.buffer; the checks below rely on that — confirm.
		r.blockOffset += int64(relativeOffset)
		if err := r.compressedBlock(blockSize); err != nil {
			return err
		}
		r.blockOffset += int64(blockSize)
	case 3:
		return r.makeError(relativeOffset, "invalid block type")
	}

	if !r.frameSizeUnknown {
		if uint64(len(r.buffer)) > r.remainingFrameSize {
			return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
		}
		r.remainingFrameSize -= uint64(len(r.buffer))
	}

	if r.hasChecksum {
		r.checksum.update(r.buffer)
	}

	if !lastBlock {
		// Later blocks in this frame may refer back into this data.
		r.window.save(r.buffer)
	} else {
		if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
			return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
		}
		// Check for checksum at end of frame. RFC 3.1.1.
		// blockOffset already points past the block, so offsets here are 0.
		if r.hasChecksum {
			if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
				return r.wrapNonEOFError(0, err)
			}

			inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
			dataChecksum := uint32(r.checksum.digest())
			if inputChecksum != dataChecksum {
				return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
			}

			r.blockOffset += 4
		}
		r.sawFrameHeader = false
	}

	return nil
}
| |
| // setBufferSize sets the decompressed buffer size. |
| // When this is called the buffer is empty. |
| func (r *Reader) setBufferSize(size int) { |
| if cap(r.buffer) < size { |
| need := size - cap(r.buffer) |
| r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...) |
| } |
| r.buffer = r.buffer[:size] |
| } |
| |
// zstdError is an error encountered while decompressing. It records the
// byte offset in the compressed input where the problem was detected and
// wraps the underlying cause, so callers can inspect the cause with
// errors.Is and errors.As.
type zstdError struct {
	offset int64
	err    error
}

// Error reports the offset and the underlying cause.
func (e *zstdError) Error() string {
	const format = "zstd decompression error at %d: %v"
	return fmt.Sprintf(format, e.offset, e.err)
}

// Unwrap returns the underlying cause (for example io.ErrUnexpectedEOF).
func (e *zstdError) Unwrap() error {
	return e.err
}
| |
| func (r *Reader) makeEOFError(off int) error { |
| return r.wrapError(off, io.ErrUnexpectedEOF) |
| } |
| |
| func (r *Reader) wrapNonEOFError(off int, err error) error { |
| if err == io.EOF { |
| err = io.ErrUnexpectedEOF |
| } |
| return r.wrapError(off, err) |
| } |
| |
| func (r *Reader) makeError(off int, msg string) error { |
| return r.wrapError(off, errors.New(msg)) |
| } |
| |
| func (r *Reader) wrapError(off int, err error) error { |
| if err == io.EOF { |
| return err |
| } |
| return &zstdError{r.blockOffset + int64(off), err} |
| } |