// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package zstd provides a decompressor for zstd streams,
// described in RFC 8878. It does not support dictionaries.
package zstd

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
)

// fuzzing is a fuzzer hook set to true when fuzzing.
// When set, inputs that the reference zstd implementation would reject
// (such as frames requesting an oversized window) are rejected here too,
// so that fuzz comparisons against zstd stay meaningful.
var fuzzing bool

// Reader implements [io.Reader] to read a zstd compressed stream.
// It also implements [io.ByteReader].
// Create a Reader with [NewReader]; [Reader.Reset] prepares an existing
// Reader to decompress a different input stream.
type Reader struct {
	// The underlying Reader.
	r io.Reader

	// Whether we have read the frame header.
	// This is of interest when buffer is empty.
	// If true we expect to see a new block; if false the next
	// refill starts by reading a frame header.
	sawFrameHeader bool

	// Whether the current frame expects a checksum.
	// When set, the checksum is updated with each decompressed block
	// and compared against the value stored after the last block.
	hasChecksum bool

	// Whether we have read at least one frame.
	// An input that ends before any frame header is reported as
	// an unexpected EOF rather than a clean end of stream.
	readOneFrame bool

	// True if the frame size is not known.
	frameSizeUnknown bool

	// The number of uncompressed bytes remaining in the current frame.
	// If frameSizeUnknown is true, this is not valid.
	remainingFrameSize uint64

	// The number of bytes read from r up to the start of the current
	// block, for error reporting.
	blockOffset int64

	// Buffered decompressed data.
	buffer []byte
	// Current read offset in buffer; bytes before off have already
	// been returned to the caller.
	off int

	// The current repeated offsets.
	// They are reset to 1, 4 and 8 at the start of each frame.
	repeatedOffset1 uint32
	repeatedOffset2 uint32
	repeatedOffset3 uint32

	// The current Huffman tree used for decoding literals.
	huffmanTable     []uint16
	huffmanTableBits int

	// The window for back references.
	// windowSize is capped at 8MB; when it is 0 no window is retained.
	windowSize int    // maximum required window size
	window     []byte // window data

	// A buffer available to hold a compressed block.
	compressedBuf []byte

	// A buffer for literals.
	literals []byte

	// Sequence decode FSE tables.
	// They are cleared at the start of each frame.
	seqTables    [3][]fseBaselineEntry
	seqTableBits [3]uint8

	// Buffers for sequence decode FSE tables.
	seqTableBuffers [3][]fseBaselineEntry

	// Scratch space used for small reads, to avoid allocation.
	scratch [16]byte

	// A scratch table for reading an FSE. Only temporarily valid.
	fseScratch []fseEntry

	// For checksum computation.
	// Only the low 32 bits of the digest are compared with the frame's
	// stored checksum.
	checksum xxhash64
}

// NewReader returns a Reader that decompresses the zstd stream read
// from input. No data is read until the first call to Read or ReadByte.
func NewReader(input io.Reader) *Reader {
	zr := &Reader{}
	zr.Reset(input)
	return zr
}

// Reset discards the current state and starts reading a new stream from input.
// This permits reusing a Reader rather than allocating a new one.
func (r *Reader) Reset(input io.Reader) {
	r.r = input

	// Several fields are preserved to avoid allocation.
	// Others are always set before they are used.
	r.sawFrameHeader = false
	r.hasChecksum = false
	r.readOneFrame = false
	r.frameSizeUnknown = false
	r.remainingFrameSize = 0
	r.blockOffset = 0
	// Empty the decompressed buffer but keep its capacity. Merely
	// resetting off would make the next Read return whatever was left
	// in the buffer from the previous stream.
	r.buffer = r.buffer[:0]
	r.off = 0
	// repeatedOffset1
	// repeatedOffset2
	// repeatedOffset3
	// huffmanTable
	// huffmanTableBits
	// windowSize
	// window
	// compressedBuf
	// literals
	// seqTables
	// seqTableBits
	// seqTableBuffers
	// scratch
	// fseScratch
}

// Read implements [io.Reader]. It copies already-decompressed bytes into
// p, decoding further blocks (and frames) first if none are buffered.
func (r *Reader) Read(p []byte) (int, error) {
	err := r.refillIfNeeded()
	if err != nil {
		return 0, err
	}
	pending := r.buffer[r.off:]
	copied := copy(p, pending)
	r.off += copied
	return copied, nil
}

// ReadByte implements [io.ByteReader]. It returns the next decompressed
// byte, decoding another block first if the buffer is exhausted.
func (r *Reader) ReadByte() (byte, error) {
	err := r.refillIfNeeded()
	if err != nil {
		return 0, err
	}
	b := r.buffer[r.off]
	r.off++
	return b, nil
}

// refillIfNeeded ensures at least one unread decompressed byte is
// buffered. It keeps decoding blocks, skipping any that decode to zero
// bytes, until one produces data or an error occurs.
func (r *Reader) refillIfNeeded() error {
	for {
		if r.off < len(r.buffer) {
			return nil
		}
		if err := r.refill(); err != nil {
			return err
		}
		r.off = 0
	}
}

// refill decodes the next block into r.buffer. When the previous frame
// has finished, the next frame header is read first.
func (r *Reader) refill() error {
	if r.sawFrameHeader {
		return r.readBlock()
	}
	if err := r.readFrameHeader(); err != nil {
		return err
	}
	return r.readBlock()
}

// readFrameHeader reads the frame header and prepares to read a block.
// Any skippable frames that precede the next zstd frame are skipped.
// If the input ends cleanly before a new magic number, it returns io.EOF,
// unless no frame has been read yet, in which case the input is reported
// as truncated.
func (r *Reader) readFrameHeader() error {
retry:
	relativeOffset := 0

	// Read magic number. RFC 3.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
		// We require that the stream contain at least one frame.
		if err == io.EOF && !r.readOneFrame {
			err = io.ErrUnexpectedEOF
		}
		return r.wrapError(relativeOffset, err)
	}

	if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
		if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
			// This is a skippable frame. RFC 3.1.2.
			r.blockOffset += int64(relativeOffset) + 4
			if err := r.skipFrame(); err != nil {
				return err
			}
			goto retry
		}

		return r.makeError(relativeOffset, "invalid magic number")
	}

	relativeOffset += 4

	// Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}
	descriptor := r.scratch[0]

	singleSegment := descriptor&(1<<5) != 0

	fcsFieldSize := 1 << (descriptor >> 6)
	if fcsFieldSize == 1 && !singleSegment {
		fcsFieldSize = 0
	}

	var windowDescriptorSize int
	if singleSegment {
		windowDescriptorSize = 0
	} else {
		windowDescriptorSize = 1
	}

	if descriptor&(1<<3) != 0 {
		return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
	}

	r.hasChecksum = descriptor&(1<<2) != 0
	if r.hasChecksum {
		r.checksum.reset()
	}

	if descriptor&3 != 0 {
		return r.makeError(relativeOffset, "dictionaries are not supported")
	}

	relativeOffset++

	headerSize := windowDescriptorSize + fcsFieldSize

	if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	// Figure out the maximum amount of data we need to retain
	// for backreferences.
	var windowSize uint64
	if !singleSegment {
		// Window descriptor. RFC 3.1.1.1.2.
		windowDescriptor := r.scratch[0]
		exponent := uint64(windowDescriptor >> 3)
		mantissa := uint64(windowDescriptor & 7)
		windowLog := exponent + 10
		windowBase := uint64(1) << windowLog
		windowAdd := (windowBase / 8) * mantissa
		windowSize = windowBase + windowAdd

		// Default zstd sets limits on the window size.
		if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
			return r.makeError(relativeOffset, "windowSize too large")
		}
	}

	// Frame_Content_Size. RFC 3.1.1.1.4.
	r.frameSizeUnknown = false
	r.remainingFrameSize = 0
	fb := r.scratch[windowDescriptorSize:]
	switch fcsFieldSize {
	case 0:
		r.frameSizeUnknown = true
	case 1:
		r.remainingFrameSize = uint64(fb[0])
	case 2:
		r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
	case 4:
		r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
	case 8:
		r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
	default:
		panic("unreachable")
	}

	// When Single_Segment_Flag is set there is no Window_Descriptor and
	// Window_Size is Frame_Content_Size (RFC 3.1.1.1.2). The frame may
	// still be split into several blocks, and later blocks may refer back
	// into earlier ones, so the decoded data must be retained; a window
	// size of 0 would make saveWindow discard it.
	if singleSegment {
		windowSize = r.remainingFrameSize
	}

	// RFC 8878 permits us to set an 8M max on window size.
	const maxWindowSize = 8 << 20
	if windowSize > maxWindowSize {
		windowSize = maxWindowSize
	}
	r.windowSize = int(windowSize)

	relativeOffset += headerSize

	r.sawFrameHeader = true
	r.readOneFrame = true
	r.blockOffset += int64(relativeOffset)

	// Prepare to read blocks from the frame.
	r.repeatedOffset1 = 1
	r.repeatedOffset2 = 4
	r.repeatedOffset3 = 8
	r.huffmanTableBits = 0
	r.window = r.window[:0]
	r.seqTables[0] = nil
	r.seqTables[1] = nil
	r.seqTables[2] = nil

	return nil
}

// skipFrame skips the remainder of a skippable frame. RFC 3.1.2.
// The caller has already consumed the 4-byte magic number and advanced
// r.blockOffset past it. skipFrame reads the 4-byte little-endian
// Frame_Size and then discards that many bytes of user data.
//
// If the underlying reader can seek, the data is skipped by seeking;
// otherwise it is read and discarded. User data that is cut short by
// the end of the input is reported as io.ErrUnexpectedEOF.
func (r *Reader) skipFrame() error {
	// Read Frame_Size.
	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
		return r.wrapNonEOFError(0, err)
	}
	r.blockOffset += 4

	// Track the size as int64 so offsets cannot overflow a 32-bit int.
	size := int64(binary.LittleEndian.Uint32(r.scratch[:4]))
	if size == 0 {
		return nil
	}

	if seeker, ok := r.r.(io.Seeker); ok {
		// Some values implement io.Seeker without actually supporting
		// it, such as an *os.File that refers to a pipe. Querying the
		// current position fails for those, and we fall back to reading.
		if cur, err := seeker.Seek(0, io.SeekCurrent); err == nil {
			// Seeking past the end is not an error for many Seekers,
			// so detect a truncated frame by comparing with the end.
			end, err := seeker.Seek(0, io.SeekEnd)
			if err != nil {
				return r.wrapError(0, err)
			}
			if avail := end - cur; avail < size {
				if avail > 0 {
					r.blockOffset += avail
				}
				return r.makeEOFError(0)
			}
			if _, err := seeker.Seek(cur+size, io.SeekStart); err != nil {
				return r.wrapError(0, err)
			}
			r.blockOffset += size
			return nil
		}
	}

	// Read and discard exactly size bytes. io.Discard reuses internal
	// buffers, so this does not allocate in proportion to size.
	n, err := io.CopyN(io.Discard, r.r, size)
	r.blockOffset += n
	if err != nil {
		return r.wrapNonEOFError(0, err)
	}
	return nil
}

// readBlock reads the next block from a frame and leaves its decompressed
// contents in r.buffer. RFC 3.1.1.2.
// After the last block of a frame it checks the frame content size and,
// if the frame has one, the content checksum. It then arranges for the
// next refill to start with a new frame header.
func (r *Reader) readBlock() error {
	relativeOffset := 0

	// Read Block_Header. RFC 3.1.1.2.
	if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	relativeOffset += 3

	header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)

	lastBlock := header&1 != 0
	blockType := (header >> 1) & 3
	blockSize := int(header >> 3)

	// Maximum block size is smaller of window size and 128K.
	// If no window size is recorded (r.windowSize == 0), only the
	// 128K limit applies. RFC 3.1.1.2.3, 3.1.1.2.4.
	if blockSize > 128<<10 || (r.windowSize > 0 && blockSize > r.windowSize) {
		return r.makeError(relativeOffset, "block size too large")
	}

	// Handle different block types. RFC 3.1.1.2.2.
	switch blockType {
	case 0:
		// Raw_Block: blockSize bytes stored uncompressed.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.buffer); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset += blockSize
		r.blockOffset += int64(relativeOffset)
	case 1:
		// RLE_Block: a single byte repeated blockSize times.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset++
		v := r.scratch[0]
		for i := range r.buffer {
			r.buffer[i] = v
		}
		r.blockOffset += int64(relativeOffset)
	case 2:
		// Compressed_Block.
		r.blockOffset += int64(relativeOffset)
		if err := r.compressedBlock(blockSize); err != nil {
			return err
		}
		r.blockOffset += int64(blockSize)
	case 3:
		return r.makeError(relativeOffset, "invalid block type")
	}

	// r.blockOffset now points just past this block for every block type.
	// Errors from here on are therefore reported at offset 0. Passing
	// relativeOffset would count the header (and, for raw and RLE blocks,
	// the payload) a second time.
	if !r.frameSizeUnknown {
		if uint64(len(r.buffer)) > r.remainingFrameSize {
			return r.makeError(0, "too many uncompressed bytes in frame")
		}
		r.remainingFrameSize -= uint64(len(r.buffer))
	}

	if r.hasChecksum {
		r.checksum.update(r.buffer)
	}

	if !lastBlock {
		r.saveWindow(r.buffer)
	} else {
		if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
			return r.makeError(0, "not enough uncompressed bytes for frame")
		}
		// Check for checksum at end of frame. RFC 3.1.1.
		if r.hasChecksum {
			if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
				return r.wrapNonEOFError(0, err)
			}

			inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
			dataChecksum := uint32(r.checksum.digest())
			if inputChecksum != dataChecksum {
				return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
			}

			r.blockOffset += 4
		}
		r.sawFrameHeader = false
	}

	return nil
}

// setBufferSize resizes r.buffer to exactly size bytes. Storage is grown
// (never shrunk) as needed, so later blocks can reuse the allocation.
// When this is called the buffer is empty.
func (r *Reader) setBufferSize(size int) {
	if extra := size - cap(r.buffer); extra > 0 {
		full := r.buffer[:cap(r.buffer)]
		r.buffer = append(full, make([]byte, extra)...)
	}
	r.buffer = r.buffer[:size]
}

// saveWindow appends buf to the backreference window. Only the most
// recent r.windowSize bytes of decompressed data are kept. If
// r.windowSize is 0, nothing is saved.
// TODO: use a circular buffer for less data movement.
func (r *Reader) saveWindow(buf []byte) {
	if r.windowSize == 0 {
		return
	}

	// buf alone fills the window: keep only its tail.
	if len(buf) >= r.windowSize {
		from := len(buf) - r.windowSize
		r.window = append(r.window[:0], buf[from:]...)
		return
	}

	// Drop the oldest bytes so that, once buf is appended, the window
	// holds at most r.windowSize bytes.
	keep := r.windowSize - len(buf) // positive, since len(buf) < r.windowSize
	if keep < len(r.window) {
		remove := len(r.window) - keep
		copy(r.window, r.window[remove:])
		// copy moved the retained tail to the front. Truncate so the
		// stale bytes after it do not stay in the window; otherwise the
		// window grows without bound and back-references read old data.
		r.window = r.window[:keep]
	}

	r.window = append(r.window, buf...)
}

// zstdError is an error while decompressing.
// offset is the byte position in the compressed input where the problem
// was detected. wrapError computes it as the start of the current block
// plus a block-relative offset. err is the underlying cause, exposed
// through Unwrap.
type zstdError struct {
	offset int64
	err    error
}

// Error reports the input offset at which decompression failed together
// with the underlying error.
func (ze *zstdError) Error() string {
	const format = "zstd decompression error at %d: %v"
	return fmt.Sprintf(format, ze.offset, ze.err)
}

// Unwrap returns the underlying cause, so that [errors.Is] and
// [errors.As] can see through a zstdError.
func (ze *zstdError) Unwrap() error {
	cause := ze.err
	return cause
}

// makeEOFError reports that the input ended unexpectedly, off bytes past
// the start of the current block.
func (r *Reader) makeEOFError(off int) error {
	err := io.ErrUnexpectedEOF
	return r.wrapError(off, err)
}

// wrapNonEOFError wraps err with the current input offset. A bare io.EOF
// means the input ended in the middle of a structure, so it is reported
// as io.ErrUnexpectedEOF instead.
func (r *Reader) wrapNonEOFError(off int, err error) error {
	if err != io.EOF {
		return r.wrapError(off, err)
	}
	return r.wrapError(off, io.ErrUnexpectedEOF)
}

// makeError creates a decompression error with message msg, located off
// bytes past the start of the current block.
func (r *Reader) makeError(off int, msg string) error {
	cause := errors.New(msg)
	return r.wrapError(off, cause)
}

// wrapError attaches the input offset (the start of the current block
// plus off) to err. io.EOF is returned unchanged so that callers can
// detect a clean end of stream with a plain comparison.
func (r *Reader) wrapError(off int, err error) error {
	if err != io.EOF {
		return &zstdError{offset: r.blockOffset + int64(off), err: err}
	}
	return err
}
