blob: c70cc21f5710421a890fd462ceb1863782aaa0aa [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"strings"
"text/template"
)
// WireType names a protobuf wire encoding. Its string value is spliced
// directly into generated identifiers: Expr builds "protowire.<WireType>Type"
// from it, and the templates in this file build "protowire.Consume<WireType>",
// "protowire.Append<WireType>" and "protowire.Size<WireType>".
type WireType string
const (
WireVarint WireType = "Varint"
WireFixed32 WireType = "Fixed32"
WireFixed64 WireType = "Fixed64"
WireBytes WireType = "Bytes"
// WireGroup is special-cased by Expr, which maps it to
// protowire.StartGroupType instead of deriving the name from the string.
WireGroup WireType = "Group"
)
// Expr returns the Go expression that names the protowire.Type constant
// for w.
//
// WireGroup maps to protowire.StartGroupType; every other wire type maps to
// "protowire.<w>Type".
func (w WireType) Expr() Expr {
	switch w {
	case WireGroup:
		return "protowire.StartGroupType"
	default:
		return Expr("protowire." + string(w) + "Type")
	}
}
// Packable reports whether w is the varint or one of the fixed-width wire
// types. The list-decoding template uses it to decide whether to emit the
// branch that reads a length-delimited run of values.
func (w WireType) Packable() bool {
	switch w {
	case WireVarint, WireFixed32, WireFixed64:
		return true
	}
	return false
}
// ConstSize reports whether w is one of the fixed-width wire types
// (WireFixed32 or WireFixed64).
func (w WireType) ConstSize() bool {
	switch w {
	case WireFixed32, WireFixed64:
		return true
	default:
		return false
	}
}
// GoType is the spelling of a Go type as it is written in generated source.
type GoType string

// Spellings of the Go types that scalar protobuf kinds are stored as.
//
// These are untyped string constants; they convert implicitly to GoType
// wherever one is expected (for example in GoTypes or in comparisons with a
// GoType value).
const (
	GoBool    = "bool"
	GoInt32   = "int32"
	GoUint32  = "uint32"
	GoInt64   = "int64"
	GoUint64  = "uint64"
	GoFloat32 = "float32"
	GoFloat64 = "float64"
	GoString  = "string"
	GoBytes   = "[]byte"
)

// GoTypes lists each of the Go type spellings above exactly once, in
// declaration order.
var GoTypes = []GoType{
	GoBool,
	GoInt32,
	GoUint32,
	GoInt64,
	GoUint64,
	GoFloat32,
	GoFloat64,
	GoString,
	GoBytes,
}
// Zero returns the Go source expression for the zero value of g: "false" for
// GoBool, `""` for GoString, "nil" for GoBytes, and "0" for everything else.
func (g GoType) Zero() Expr {
	zero := Expr("0") // numeric types, and any type not listed below
	switch g {
	case GoBool:
		zero = "false"
	case GoString:
		zero = `""`
	case GoBytes:
		zero = "nil"
	}
	return zero
}
// Kind is the reflect.Kind of the type, as a Go expression.
//
// The expression is "reflect." followed by g with its first byte upper-cased
// (for example "int32" becomes "reflect.Int32"). The empty GoType and GoBytes
// yield the empty expression.
func (g GoType) Kind() Expr {
	switch g {
	case "", GoBytes:
		return ""
	}
	name := string(g)
	return Expr("reflect." + strings.ToUpper(name[:1]) + name[1:])
}
// PointerMethod is the "internal/impl".pointer method used to access a pointer to this type.
//
// The method name is g with its first byte upper-cased (for example "int32"
// becomes "Int32"), except that GoBytes maps to "Bytes".
//
// The empty GoType, which the Enum, Message and Group entries of ProtoKinds
// carry because they leave GoType unset, yields the empty expression. This
// matches Kind and avoids the out-of-range panic that slicing g[:1] on an
// empty string would otherwise cause.
func (g GoType) PointerMethod() Expr {
	switch g {
	case "":
		return ""
	case GoBytes:
		return "Bytes"
	}
	return Expr(strings.ToUpper(string(g[:1])) + string(g[1:]))
}
// ProtoKind describes one protobuf field kind for the code generators in this
// package. The Expr-typed fields hold Go source expressions written in terms
// of a variable named v; the templates splice them verbatim into generated
// code.
type ProtoKind struct {
// Name is the kind name; Expr turns it into "protoreflect.<Name>Kind".
Name string
// WireType is the wire encoding the generated code expects for this kind.
WireType WireType
// Conversions to/from protoreflect.Value.
// ToValue builds a protoreflect.Value from the decoded wire value v.
ToValue Expr
// FromValue extracts the wire-level value from the protoreflect.Value v.
FromValue Expr
// Conversions to/from generated structures.
// GoType is left empty for the Enum, Message and Group entries of ProtoKinds.
GoType GoType
ToGoType Expr
// ToGoTypeNoZero is set only for Bytes in ProtoKinds, where it appends to a
// nil slice instead of emptyBuf.
// NOTE(review): its consumer is outside this file — confirm intended use.
ToGoTypeNoZero Expr
FromGoType Expr
// NoPointer is set only for Bytes in ProtoKinds.
// NOTE(review): not read in this file; presumably suppresses pointer-based
// codegen for the kind — confirm against the consumer.
NoPointer bool
// NoValueCodec is set only for Message and Group in ProtoKinds.
// NOTE(review): not read in this file — confirm against the consumer.
NoValueCodec bool
}
// Expr returns the Go expression naming the protoreflect.Kind constant for k,
// i.e. "protoreflect.<Name>Kind".
func (k ProtoKind) Expr() Expr {
	const (
		prefix = "protoreflect."
		suffix = "Kind"
	)
	return Expr(prefix + k.Name + suffix)
}
// ProtoKinds is the table that drives every template in this file: one entry
// per protobuf field kind, giving its wire type and the Go expressions used to
// convert between the raw wire value v and either a protoreflect.Value or the
// Go type stored in generated structures.
//
// The expression strings are emitted verbatim into generated code, so they
// must remain valid Go in the context where the templates place them.
var ProtoKinds = []ProtoKind{
{
Name: "Bool",
WireType: WireVarint,
ToValue: "protoreflect.ValueOfBool(protowire.DecodeBool(v))",
FromValue: "protowire.EncodeBool(v.Bool())",
GoType: GoBool,
ToGoType: "protowire.DecodeBool(v)",
FromGoType: "protowire.EncodeBool(v)",
},
// Enum has no GoType or Go-type conversions; only the protoreflect.Value
// conversions are provided.
{
Name: "Enum",
WireType: WireVarint,
ToValue: "protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))",
FromValue: "uint64(v.Enum())",
},
{
Name: "Int32",
WireType: WireVarint,
ToValue: "protoreflect.ValueOfInt32(int32(v))",
FromValue: "uint64(int32(v.Int()))",
GoType: GoInt32,
ToGoType: "int32(v)",
FromGoType: "uint64(v)",
},
// Sint32 masks the decoded varint with math.MaxUint32 so that only its low
// 32 bits are zig-zag decoded.
{
Name: "Sint32",
WireType: WireVarint,
ToValue: "protoreflect.ValueOfInt32(int32(protowire.DecodeZigZag(v & math.MaxUint32)))",
FromValue: "protowire.EncodeZigZag(int64(int32(v.Int())))",
GoType: GoInt32,
ToGoType: "int32(protowire.DecodeZigZag(v & math.MaxUint32))",
FromGoType: "protowire.EncodeZigZag(int64(v))",
},
{
Name: "Uint32",
WireType: WireVarint,
ToValue: "protoreflect.ValueOfUint32(uint32(v))",
FromValue: "uint64(uint32(v.Uint()))",
GoType: GoUint32,
ToGoType: "uint32(v)",
FromGoType: "uint64(v)",
},
{
Name: "Int64",
WireType: WireVarint,
ToValue: "protoreflect.ValueOfInt64(int64(v))",
FromValue: "uint64(v.Int())",
GoType: GoInt64,
ToGoType: "int64(v)",
FromGoType: "uint64(v)",
},
{
Name: "Sint64",
WireType: WireVarint,
ToValue: "protoreflect.ValueOfInt64(protowire.DecodeZigZag(v))",
FromValue: "protowire.EncodeZigZag(v.Int())",
GoType: GoInt64,
ToGoType: "protowire.DecodeZigZag(v)",
FromGoType: "protowire.EncodeZigZag(v)",
},
{
Name: "Uint64",
WireType: WireVarint,
ToValue: "protoreflect.ValueOfUint64(v)",
FromValue: "v.Uint()",
GoType: GoUint64,
ToGoType: "v",
FromGoType: "v",
},
{
Name: "Sfixed32",
WireType: WireFixed32,
ToValue: "protoreflect.ValueOfInt32(int32(v))",
FromValue: "uint32(v.Int())",
GoType: GoInt32,
ToGoType: "int32(v)",
FromGoType: "uint32(v)",
},
{
Name: "Fixed32",
WireType: WireFixed32,
ToValue: "protoreflect.ValueOfUint32(uint32(v))",
FromValue: "uint32(v.Uint())",
GoType: GoUint32,
ToGoType: "v",
FromGoType: "v",
},
{
Name: "Float",
WireType: WireFixed32,
ToValue: "protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))",
FromValue: "math.Float32bits(float32(v.Float()))",
GoType: GoFloat32,
ToGoType: "math.Float32frombits(v)",
FromGoType: "math.Float32bits(v)",
},
{
Name: "Sfixed64",
WireType: WireFixed64,
ToValue: "protoreflect.ValueOfInt64(int64(v))",
FromValue: "uint64(v.Int())",
GoType: GoInt64,
ToGoType: "int64(v)",
FromGoType: "uint64(v)",
},
{
Name: "Fixed64",
WireType: WireFixed64,
ToValue: "protoreflect.ValueOfUint64(v)",
FromValue: "v.Uint()",
GoType: GoUint64,
ToGoType: "v",
FromGoType: "v",
},
{
Name: "Double",
WireType: WireFixed64,
ToValue: "protoreflect.ValueOfFloat64(math.Float64frombits(v))",
FromValue: "math.Float64bits(v.Float())",
GoType: GoFloat64,
ToGoType: "math.Float64frombits(v)",
FromGoType: "math.Float64bits(v)",
},
{
Name: "String",
WireType: WireBytes,
ToValue: "protoreflect.ValueOfString(string(v))",
FromValue: "v.String()",
GoType: GoString,
ToGoType: "string(v)",
FromGoType: "v",
},
// Bytes copies v by appending to emptyBuf (declared in the decode template)
// so that a zero-length input still produces a non-nil slice.
// ToGoTypeNoZero appends to a nil slice instead, so a zero-length input
// stays nil.
{
Name: "Bytes",
WireType: WireBytes,
ToValue: "protoreflect.ValueOfBytes(append(emptyBuf[:], v...))",
FromValue: "v.Bytes()",
GoType: GoBytes,
ToGoType: "append(emptyBuf[:], v...)",
ToGoTypeNoZero: "append(([]byte)(nil), v...)",
FromGoType: "v",
NoPointer: true,
},
// Message and Group wrap the raw bytes v without copying; the templates in
// this file handle these two kinds with dedicated branches.
{
Name: "Message",
WireType: WireBytes,
ToValue: "protoreflect.ValueOfBytes(v)",
FromValue: "v",
NoValueCodec: true,
},
{
Name: "Group",
WireType: WireGroup,
ToValue: "protoreflect.ValueOfBytes(v)",
FromValue: "v",
NoValueCodec: true,
},
}
// generateProtoDecode renders protoDecodeTemplate over the ProtoKinds table
// and returns the resulting source text.
func generateProtoDecode() string {
	src := mustExecute(protoDecodeTemplate, ProtoKinds)
	return src
}
// protoDecodeTemplate generates the reflection-based decoding helpers. It is
// executed with the ProtoKinds slice as its data and emits:
//
//   - UnmarshalOptions.unmarshalScalar: a switch on fd.Kind() with one case per
//     kind that checks the wire type, consumes the value, validates UTF-8 for
//     String, and returns the kind's ToValue expression;
//   - UnmarshalOptions.unmarshalList: the same per-kind switch for repeated
//     fields, with an extra branch for Packable wire types that consumes a
//     length-delimited buffer and appends every value found in it, and with
//     Message/Group elements unmarshaled into a new list element;
//   - the emptyBuf variable referenced by the Bytes conversions.
//
// The "{{-" and "-}}" markers trim adjacent whitespace in the template output,
// so the text of the raw string must not be reflowed.
var protoDecodeTemplate = template.Must(template.New("").Parse(`
// unmarshalScalar decodes a value of the given kind.
//
// Message values are decoded into a []byte which aliases the input data.
func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp protowire.Type, fd protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
switch fd.Kind() {
{{- range .}}
case {{.Expr}}:
if wtyp != {{.WireType.Expr}} {
return val, 0, errUnknown
}
{{if (eq .WireType "Group") -}}
v, n := protowire.ConsumeGroup(fd.Number(), b)
{{- else -}}
v, n := protowire.Consume{{.WireType}}(b)
{{- end}}
if n < 0 {
return val, 0, errDecode
}
{{if (eq .Name "String") -}}
if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
}
{{end -}}
return {{.ToValue}}, n, nil
{{- end}}
default:
return val, 0, errUnknown
}
}
func (o UnmarshalOptions) unmarshalList(b []byte, wtyp protowire.Type, list protoreflect.List, fd protoreflect.FieldDescriptor) (n int, err error) {
switch fd.Kind() {
{{- range .}}
case {{.Expr}}:
{{- if .WireType.Packable}}
if wtyp == protowire.BytesType {
buf, n := protowire.ConsumeBytes(b)
if n < 0 {
return 0, errDecode
}
for len(buf) > 0 {
v, n := protowire.Consume{{.WireType}}(buf)
if n < 0 {
return 0, errDecode
}
buf = buf[n:]
list.Append({{.ToValue}})
}
return n, nil
}
{{- end}}
if wtyp != {{.WireType.Expr}} {
return 0, errUnknown
}
{{if (eq .WireType "Group") -}}
v, n := protowire.ConsumeGroup(fd.Number(), b)
{{- else -}}
v, n := protowire.Consume{{.WireType}}(b)
{{- end}}
if n < 0 {
return 0, errDecode
}
{{if (eq .Name "String") -}}
if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
return 0, errors.InvalidUTF8(string(fd.FullName()))
}
{{end -}}
{{if or (eq .Name "Message") (eq .Name "Group") -}}
m := list.NewElement()
if err := o.unmarshalMessage(v, m.Message()); err != nil {
return 0, err
}
list.Append(m)
{{- else -}}
list.Append({{.ToValue}})
{{- end}}
return n, nil
{{- end}}
default:
return 0, errUnknown
}
}
// We append to an empty array rather than a nil []byte to get non-nil zero-length byte slices.
var emptyBuf [0]byte
`))
// generateProtoEncode renders protoEncodeTemplate over the ProtoKinds table
// and returns the resulting source text.
func generateProtoEncode() string {
	src := mustExecute(protoEncodeTemplate, ProtoKinds)
	return src
}
// protoEncodeTemplate generates the reflection-based encoding helpers. It is
// executed with the ProtoKinds slice as its data and emits:
//
//   - wireTypes: a map from each protoreflect.Kind to its protowire.Type;
//   - MarshalOptions.marshalSingular: a switch on fd.Kind() with one case per
//     kind. String is UTF-8 validated and appended with AppendString; Message
//     is written between appendSpeculativeLength and finishSpeculativeLength;
//     Group is followed by an end-group tag; every other kind is appended with
//     protowire.Append<WireType> using the kind's FromValue expression.
//
// The "{{-" and "-}}" markers trim adjacent whitespace in the template output,
// so the text of the raw string must not be reflowed.
var protoEncodeTemplate = template.Must(template.New("").Parse(`
var wireTypes = map[protoreflect.Kind]protowire.Type{
{{- range .}}
{{.Expr}}: {{.WireType.Expr}},
{{- end}}
}
func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
switch fd.Kind() {
{{- range .}}
case {{.Expr}}:
{{- if (eq .Name "String") }}
if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) {
return b, errors.InvalidUTF8(string(fd.FullName()))
}
b = protowire.AppendString(b, {{.FromValue}})
{{- else if (eq .Name "Message") -}}
var pos int
var err error
b, pos = appendSpeculativeLength(b)
b, err = o.marshalMessage(b, v.Message())
if err != nil {
return b, err
}
b = finishSpeculativeLength(b, pos)
{{- else if (eq .Name "Group") -}}
var err error
b, err = o.marshalMessage(b, v.Message())
if err != nil {
return b, err
}
b = protowire.AppendVarint(b, protowire.EncodeTag(fd.Number(), protowire.EndGroupType))
{{- else -}}
b = protowire.Append{{.WireType}}(b, {{.FromValue}})
{{- end}}
{{- end}}
default:
return b, errors.New("invalid kind %v", fd.Kind())
}
return b, nil
}
`))
// generateProtoSize renders protoSizeTemplate over the ProtoKinds table and
// returns the resulting source text.
func generateProtoSize() string {
	src := mustExecute(protoSizeTemplate, ProtoKinds)
	return src
}
// protoSizeTemplate generates MarshalOptions.sizeSingular. It is executed with
// the ProtoKinds slice as its data and emits a switch on kind with one case
// per entry, choosing the size call as follows:
//
//   - Message: protowire.SizeBytes of the nested message size;
//   - Fixed32/Fixed64 wire types: the argument-less protowire.Size<WireType>();
//   - other Bytes wire types: protowire.SizeBytes of len(FromValue);
//   - Group: protowire.SizeGroup of the field number and nested message size;
//   - everything else: protowire.Size<WireType> of the FromValue expression.
//
// The Message test comes first because Message also has the Bytes wire type.
// Unknown kinds produce a size of 0.
var protoSizeTemplate = template.Must(template.New("").Parse(`
func (o MarshalOptions) sizeSingular(num protowire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
switch kind {
{{- range .}}
case {{.Expr}}:
{{if (eq .Name "Message") -}}
return protowire.SizeBytes(o.size(v.Message()))
{{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}}
return protowire.Size{{.WireType}}()
{{- else if (eq .WireType "Bytes") -}}
return protowire.Size{{.WireType}}(len({{.FromValue}}))
{{- else if (eq .WireType "Group") -}}
return protowire.Size{{.WireType}}(num, o.size(v.Message()))
{{- else -}}
return protowire.Size{{.WireType}}({{.FromValue}})
{{- end}}
{{- end}}
default:
return 0
}
}
`))