blob: 83ca3b9fabea14b8037c1888de63f299597d18ea [file] [log] [blame]
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package codec
import (
"encoding/binary"
"fmt"
"math"
"reflect"
)
// An Encoder encodes Go values into a sequence of bytes.
// To use an Encoder:
// - Create one with NewEncoder.
// - Call the Encode method one or more times.
// - Retrieve the resulting bytes by calling Bytes.
type Encoder struct {
buf []byte
typeNums map[reflect.Type]int
seen map[any]uint64 // for references; see StartStruct
}
// NewEncoder returns an Encoder.
func NewEncoder() *Encoder {
return &Encoder{
typeNums: map[reflect.Type]int{},
seen: map[any]uint64{},
}
}
// Encode encodes x.
func (e *Encoder) Encode(x any) (err error) {
defer handlePanic(&err)
e.EncodeAny(x)
return nil
}
func handlePanic(errp *error) {
r := recover()
if r == nil {
// No panic; do nothing.
return
}
// If the panic is not from this package, re-panic.
cerr, ok := r.(codecError)
if !ok {
panic(r)
}
// Otherwise, set errp.
*errp = cerr.err
}
// codecError wraps errors from this package so handlePanic
// can distinguish them.
type codecError struct {
err error
}
func fail(err error) {
panic(codecError{err})
}
func failf(format string, args ...any) {
fail(fmt.Errorf(format, args...))
}
// Bytes returns the encoded byte slice.
func (e *Encoder) Bytes() []byte {
data := e.buf // remember the data
e.buf = nil // start with a fresh buffer
e.encodeInitial() // encode metadata
return append(e.buf, data...) // concatenate metadata and data
}
// A Decoder decodes a Go value encoded by an Encoder.
// To use a Decoder:
// - Pass NewDecoder the return value of Encoder.Bytes.
// - Call the Decode method once for each call to Encoder.Encode.
type Decoder struct {
buf []byte
i int
typeInfos []*typeInfo
refs []any // list of struct pointers, in the order seen
}
// NewDecoder returns a Decoder for the given bytes.
func NewDecoder(data []byte) *Decoder {
return &Decoder{buf: data, i: 0}
}
// Decode decodes a value encoded with Encoder.Encode.
func (d *Decoder) Decode() (_ any, err error) {
defer handlePanic(&err)
if d.typeInfos == nil {
d.decodeInitial()
}
return d.DecodeAny(), nil
}
func badcode(c byte) {
failf("bad code: %d", c)
}
//////////////// Reading From and Writing To the Buffer
func (e *Encoder) writeByte(b byte) {
e.buf = append(e.buf, b)
}
func (d *Decoder) readByte() byte {
b := d.curByte()
d.i++
return b
}
// curByte returns the next byte to be read
// without actually consuming it.
func (d *Decoder) curByte() byte {
return d.buf[d.i]
}
func (e *Encoder) writeBytes(b []byte) {
e.buf = append(e.buf, b...)
}
// readBytes reads and returns the given number of bytes.
// It fails if there are not enough bytes in the input.
func (d *Decoder) readBytes(n int) []byte {
d.i += n
return d.buf[d.i-n : d.i]
}
func (e *Encoder) writeString(s string) {
e.buf = append(e.buf, s...)
}
func (d *Decoder) readString(len int) string {
return string(d.readBytes(len))
}
func (e *Encoder) writeUint32(u uint32) {
var buf [4]byte
binary.LittleEndian.PutUint32(buf[:], u)
e.writeBytes(buf[:])
}
func (d *Decoder) readUint32() uint32 {
return binary.LittleEndian.Uint32(d.readBytes(4))
}
func (e *Encoder) writeUint64(u uint64) {
var buf [8]byte
binary.LittleEndian.PutUint64(buf[:], u)
e.writeBytes(buf[:])
}
func (d *Decoder) readUint64() uint64 {
return binary.LittleEndian.Uint64(d.readBytes(8))
}
//////////////// Encoding Scheme
// Byte codes that begin each encoded value.
// See the package doc for their descriptions.
const (
nilCode = 255 - iota // a nil value
// reserve a few values for future use
reserved1
reserved2
reserved3
reserved4
reserved5
reserved6
reserved7
nBytesCode // uint n follows, then n bytes
nValuesCode // uint n follows, then n values
refCode // uint n follows, referring to a previous value
startCode // start of a value of indeterminate length
endCode // end of a value that began with start
// Bytes less than endCode represent themselves.
)
// EncodeUint encodes a uint64.
func (e *Encoder) EncodeUint(u uint64) {
switch {
case u < endCode:
// u fits into the initial byte.
e.writeByte(byte(u))
case u <= math.MaxUint32:
// Encode as a sequence of 4 bytes, the little-endian representation of
// a uint32.
e.writeByte(nBytesCode)
e.writeByte(4)
e.writeUint32(uint32(u))
default:
// Encode as a sequence of 8 bytes, the little-endian representation of
// a uint64.
e.writeByte(nBytesCode)
e.writeByte(8)
e.writeUint64(u)
}
}
// DecodeUint decodes a uint64.
func (d *Decoder) DecodeUint() uint64 {
b := d.readByte()
switch {
case b < endCode:
return uint64(b)
case b == nBytesCode:
switch n := d.readByte(); n {
case 4:
return uint64(d.readUint32())
case 8:
return d.readUint64()
default:
failf("DecodeUint: bad length %d", n)
}
default:
badcode(b)
}
return 0
}
// EncodeInt encodes a signed integer.
func (e *Encoder) EncodeInt(i int64) {
// Encode small negative as well as small positive integers efficiently.
// Algorithm from gob; see "Encoding Details" at https://pkg.go.dev/encoding/gob.
var u uint64
if i < 0 {
u = (^uint64(i) << 1) | 1 // complement i, bit 0 is 1
} else {
u = (uint64(i) << 1) // do not complement i, bit 0 is 0
}
e.EncodeUint(u)
}
// DecodeInt decodes a signed integer.
func (d *Decoder) DecodeInt() int64 {
u := d.DecodeUint()
if u&1 == 1 {
return int64(^(u >> 1))
}
return int64(u >> 1)
}
// encodeLen encodes the length of a byte sequence.
func (e *Encoder) encodeLen(n int) {
e.writeByte(nBytesCode)
e.EncodeUint(uint64(n))
}
// decodeLen decodes the length of a byte sequence.
func (d *Decoder) decodeLen() int {
if b := d.readByte(); b != nBytesCode {
badcode(b)
}
return int(d.DecodeUint())
}
// EncodeBytes encodes a byte slice.
func (e *Encoder) EncodeBytes(b []byte) {
e.encodeLen(len(b))
e.writeBytes(b)
}
// DecodeBytes decodes a byte slice.
// It does no copying.
func (d *Decoder) DecodeBytes() []byte {
return d.readBytes(d.decodeLen())
}
// EncodeString encodes a string.
func (e *Encoder) EncodeString(s string) {
e.encodeLen(len(s))
e.writeString(s)
}
// DecodeString decodes a string.
func (d *Decoder) DecodeString() string {
return d.readString(d.decodeLen())
}
// EncodeBool encodes a bool.
func (e *Encoder) EncodeBool(b bool) {
if b {
e.writeByte(1)
} else {
e.writeByte(0)
}
}
// DecodeBool decodes a bool.
func (d *Decoder) DecodeBool() bool {
b := d.readByte()
switch b {
case 0:
return false
case 1:
return true
default:
failf("bad bool: %d", b)
return false
}
}
// EncodeFloat encodes a float64.
func (e *Encoder) EncodeFloat(f float64) {
e.EncodeUint(math.Float64bits(f))
}
// DecodeFloat decodes a float64.
func (d *Decoder) DecodeFloat() float64 {
return math.Float64frombits(d.DecodeUint())
}
func (e *Encoder) EncodeNil() {
e.writeByte(nilCode)
}
// StartList should be called before encoding any sequence of variable-length
// values.
func (e *Encoder) StartList(len int) {
e.writeByte(nValuesCode)
e.EncodeUint(uint64(len))
}
// StartList should be called before decoding any sequence of variable-length
// values. It returns -1 if the encoded list was nil. Otherwise, it returns the
// length of the sequence.
func (d *Decoder) StartList() int {
switch b := d.readByte(); b {
case nilCode:
return -1
case nValuesCode:
return int(d.DecodeUint())
default:
badcode(b)
return 0
}
}
//////////////// Struct Support
// StartStruct should be called before encoding a struct pointer. The isNil
// argument says whether the pointer is nil. The p argument is the struct
// pointer. If StartStruct returns false, encoding should not proceed.
func (e *Encoder) StartStruct(isNil bool, p any) bool {
if isNil {
e.EncodeNil()
return false
}
if u, ok := e.seen[p]; ok {
// If we have already seen this struct pointer,
// encode a reference to it.
e.writeByte(refCode)
e.EncodeUint(u)
return false // Caller should not encode the struct.
}
// Note that we have seen this pointer, and assign it
// its position in the encoding.
e.seen[p] = uint64(len(e.seen))
e.writeByte(startCode)
return true
}
// StartStruct should be called before decoding a struct pointer. If it returns
// false, decoding should not proceed. If it returns true and the second return
// value is non-nil, it is a reference to a previous value and should be used
// instead of proceeding with decoding.
func (d *Decoder) StartStruct() (bool, any) {
b := d.readByte()
switch b {
case nilCode: // do not set the pointer
return false, nil
case refCode:
u := d.DecodeUint()
return true, d.refs[u]
case startCode:
return true, nil
default:
badcode(b)
return false, nil // unreached, needed for compiler
}
}
// StoreRef should be called by a struct decoder immediately after it allocates
// a struct pointer.
func (d *Decoder) StoreRef(p any) {
d.refs = append(d.refs, p)
}
// EndStruct should be called after encoding a struct.
func (e *Encoder) EndStruct() {
e.writeByte(endCode)
}
// NextStructField should be called by a struct decoder in a loop.
// It returns the field number of the next encoded field, or -1
// if there are no more fields.
func (d *Decoder) NextStructField() int {
if d.curByte() == endCode {
d.readByte() // consume the end byte
return -1
}
return int(d.DecodeUint())
}
// UnknownField should be called by a struct decoder
// when it sees a field number that it doesn't know.
func (d *Decoder) UnknownField(typeName string, num int) {
d.skip()
}
// skip reads past a value in the input.
func (d *Decoder) skip() {
b := d.readByte()
if b < endCode {
// Small integers represent themselves in a single byte.
return
}
switch b {
case nilCode:
// Nothing follows.
case nBytesCode:
// A uint n and n bytes follow. It is efficient to call readBytes here
// because it does no allocation.
d.readBytes(int(d.DecodeUint()))
case nValuesCode:
// A uint n and n values follow.
n := int(d.DecodeUint())
for i := 0; i < n; i++ {
d.skip()
}
case refCode:
// A uint follows.
d.DecodeUint()
case startCode:
// Skip until we see endCode.
for d.curByte() != endCode {
d.skip()
}
d.readByte() // consume the endCode byte
default:
badcode(b)
}
}
//////////////// Encoding Arbitrary Values
// EncodeAny encodes a Go type. The type must have
// been registered with Register.
func (e *Encoder) EncodeAny(x any) {
// Encode a nil interface value with a zero.
if x == nil {
e.writeByte(0)
return
}
// Find the typeInfo for the type, which has the encoder.
t := reflect.TypeOf(x)
ti := typeInfosByType[t]
if ti == nil {
failf("unregistered type %s", t)
}
// Assign a number to the type if we haven't already.
num, ok := e.typeNums[t]
if !ok {
num = len(e.typeNums)
e.typeNums[t] = num
}
// Encode a pair (2-element list) of the type number and the encoded value.
e.StartList(2)
e.EncodeUint(uint64(num))
ti.encode(e, x)
}
// DecodeAny decodes a value encoded by EncodeAny.
func (d *Decoder) DecodeAny() any {
// If we're looking at a zero, this is a nil interface.
if d.curByte() == 0 {
d.readByte() // consume the byte
return nil
}
// Otherwise, we should have a two-item list: type number and value.
n := d.StartList()
if n != 2 {
failf("DecodeAny: bad list length %d", n)
}
num := d.DecodeUint()
if num >= uint64(len(d.typeInfos)) {
failf("type number %d out of range", num)
}
ti := d.typeInfos[num]
return ti.decode(d)
}
// encodeInitial encodes metadata that appears at the start of the
// encoded byte slice.
func (e *Encoder) encodeInitial() {
// Encode the list of type names we saw, in the order we
// assigned numbers to them.
names := make([]string, len(e.typeNums))
for t, num := range e.typeNums {
names[num] = typeName(t)
}
e.StartList(len(names))
for _, n := range names {
e.EncodeString(n)
}
}
// decodeInitial decodes metadata that appears at the start of the
// encoded byte slice.
func (d *Decoder) decodeInitial() {
// Decode the list of type names. The number of a type is its position in
// the list.
n := d.StartList()
d.typeInfos = make([]*typeInfo, n)
for num := 0; num < n; num++ {
name := d.DecodeString()
ti := typeInfosByName[name]
if ti == nil {
failf("unregistered type: %s", name)
}
d.typeInfos[num] = ti
}
}
//////////////// Type Registry
// All types subject to encoding must be registered, even
// builtin types.
// A typeInfo describes how to encode and decode a type.
type typeInfo struct {
name string // e.g. "go/ast.File"
encode encodeFunc
decode decodeFunc
}
type (
encodeFunc func(*Encoder, any)
decodeFunc func(*Decoder) any
)
var (
typeInfosByName = map[string]*typeInfo{}
typeInfosByType = map[reflect.Type]*typeInfo{}
)
// Register records the type of x for use by Encoders and Decoders.
func Register(x any, enc encodeFunc, dec decodeFunc) {
t := reflect.TypeOf(x)
tn := typeName(t)
if _, ok := typeInfosByName[tn]; ok {
panic(fmt.Sprintf("codec.Register: duplicate type %s (typeName=%q)", t, tn))
}
ti := &typeInfo{
name: tn,
encode: enc,
decode: dec,
}
typeInfosByName[ti.name] = ti
typeInfosByType[t] = ti
}
// typeName returns the full, qualified name for a type.
func typeName(t reflect.Type) string {
if t.PkgPath() == "" {
return t.String()
}
return t.PkgPath() + "." + t.Name()
}
var builtinTypes []reflect.Type
func init() {
Register(int64(0),
func(e *Encoder, x any) { e.EncodeInt(x.(int64)) },
func(d *Decoder) any { return d.DecodeInt() })
Register(uint64(0),
func(e *Encoder, x any) { e.EncodeUint(x.(uint64)) },
func(d *Decoder) any { return d.DecodeUint() })
Register(int(0),
func(e *Encoder, x any) { e.EncodeInt(int64(x.(int))) },
func(d *Decoder) any { return int(d.DecodeInt()) })
Register(float64(0),
func(e *Encoder, x any) { e.EncodeFloat(x.(float64)) },
func(d *Decoder) any { return d.DecodeFloat() })
Register(false,
func(e *Encoder, x any) { e.EncodeBool(x.(bool)) },
func(d *Decoder) any { return d.DecodeBool() })
Register("",
func(e *Encoder, x any) { e.EncodeString(x.(string)) },
func(d *Decoder) any { return d.DecodeString() })
Register([]byte(nil),
func(e *Encoder, x any) { e.EncodeBytes(x.([]byte)) },
func(d *Decoder) any { return d.DecodeBytes() })
for t := range typeInfosByType {
builtinTypes = append(builtinTypes, t)
}
}