// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"strings"
	"text/template"
)

// WireType names a protobuf wire type. Its string value is spliced directly
// into generated identifiers by the templates in this file, for example
// wire.Consume<WireType>, wire.Append<WireType> and wire.Size<WireType>.
type WireType string

// Wire types referenced by the entries of ProtoKinds.
const (
	WireVarint  WireType = "Varint"
	WireFixed32 WireType = "Fixed32"
	WireFixed64 WireType = "Fixed64"
	WireBytes   WireType = "Bytes"
	WireGroup   WireType = "Group"
)

// Expr returns the Go expression naming the wire.Type constant for w.
//
// WireGroup maps to wire.StartGroupType; every other wire type maps to
// wire.<w>Type.
func (w WireType) Expr() Expr {
	switch w {
	case WireGroup:
		return "wire.StartGroupType"
	default:
		return Expr("wire." + string(w) + "Type")
	}
}

// Packable reports whether w is one of WireVarint, WireFixed32 or WireFixed64.
// The list-decoding template uses it to decide whether to emit a branch that
// also accepts the elements wrapped in a single wire.BytesType record.
func (w WireType) Packable() bool {
	switch w {
	case WireVarint, WireFixed32, WireFixed64:
		return true
	}
	return false
}

// ConstSize reports whether w is WireFixed32 or WireFixed64.
func (w WireType) ConstSize() bool {
	switch w {
	case WireFixed32, WireFixed64:
		return true
	}
	return false
}

// GoType is the source-text spelling of a Go type (for example "int32" or
// "[]byte"). The empty string means the kind has no associated Go type.
type GoType string

// Go type spellings used by ProtoKinds.
//
// These are untyped string constants, so they convert implicitly when
// compared with or assigned to a GoType.
const (
	GoBool    = "bool"
	GoInt32   = "int32"
	GoUint32  = "uint32"
	GoInt64   = "int64"
	GoUint64  = "uint64"
	GoFloat32 = "float32"
	GoFloat64 = "float64"
	GoString  = "string"
	GoBytes   = "[]byte"
)

// Zero returns the Go expression for the zero value of g: "false" for bool,
// `""` for string, "nil" for []byte, and "0" for everything else.
func (g GoType) Zero() Expr {
	if g == GoBool {
		return "false"
	}
	if g == GoString {
		return `""`
	}
	if g == GoBytes {
		return "nil"
	}
	return "0"
}

// Kind is the reflect.Kind of the type, spelled as a Go expression: the type
// name with its first byte upper-cased and prefixed with "reflect." (so
// "int32" yields "reflect.Int32"). The empty type and GoBytes yield the empty
// expression.
func (g GoType) Kind() Expr {
	switch g {
	case "", GoBytes:
		return ""
	}
	name := string(g)
	return Expr("reflect." + strings.ToUpper(name[:1]) + name[1:])
}

// PointerMethod is the "internal/impl".pointer method used to access a pointer to this type.
//
// GoBytes maps to "Bytes"; any other non-empty type maps to its own name with
// the first byte upper-cased (for example "int32" becomes "Int32"). An empty
// GoType has no pointer method, so the empty expression is returned, mirroring
// what Kind does for the empty type.
func (g GoType) PointerMethod() Expr {
	switch g {
	case "":
		// Guard the g[:1] slice below, which would panic on an empty string.
		return ""
	case GoBytes:
		return "Bytes"
	}
	return Expr(strings.ToUpper(string(g[:1])) + string(g[1:]))
}

// ProtoKind describes one protobuf field kind for the code generator: its
// name, its wire type, and the Go expression fragments the templates splice
// into generated code to convert values of that kind.
type ProtoKind struct {
	// Name is the kind name; Expr renders it as protoreflect.<Name>Kind.
	Name string
	// WireType is the wire type used to encode values of this kind.
	WireType WireType

	// Conversions to/from protoreflect.Value.
	// In ToValue, v is the value returned by wire.Consume<WireType>;
	// in FromValue, v is a protoreflect.Value.
	ToValue   Expr
	FromValue Expr

	// Conversions to/from generated structures.
	// NOTE(review): none of the fields below are read by the templates in
	// this file; presumably they are consumed by other generator templates —
	// confirm against their users.
	GoType         GoType
	ToGoType       Expr
	ToGoTypeNoZero Expr
	FromGoType     Expr
	NoPointer      bool
	NoValueCodec   bool
}

// Expr returns the Go expression naming the protoreflect.Kind constant for k,
// i.e. protoreflect.<Name>Kind.
func (k ProtoKind) Expr() Expr {
	return Expr("protoreflect." + k.Name + "Kind")
}

// ProtoKinds is the table of protobuf field kinds fed to the templates in
// this file. Each template ranges over it and emits one switch case (or map
// entry) per element, in this order.
//
// The expression fragments are spliced verbatim into generated code, where
// the identifier v refers to the value being converted.
var ProtoKinds = []ProtoKind{
	{
		Name:       "Bool",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfBool(wire.DecodeBool(v))",
		FromValue:  "wire.EncodeBool(v.Bool())",
		GoType:     GoBool,
		ToGoType:   "wire.DecodeBool(v)",
		FromGoType: "wire.EncodeBool(v)",
	},
	{
		Name:      "Enum",
		WireType:  WireVarint,
		ToValue:   "protoreflect.ValueOfEnum(protoreflect.EnumNumber(v))",
		FromValue: "uint64(v.Enum())",
	},
	{
		Name:       "Int32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt32(int32(v))",
		FromValue:  "uint64(int32(v.Int()))",
		GoType:     GoInt32,
		ToGoType:   "int32(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Sint32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt32(int32(wire.DecodeZigZag(v & math.MaxUint32)))",
		FromValue:  "wire.EncodeZigZag(int64(int32(v.Int())))",
		GoType:     GoInt32,
		ToGoType:   "int32(wire.DecodeZigZag(v & math.MaxUint32))",
		FromGoType: "wire.EncodeZigZag(int64(v))",
	},
	{
		Name:       "Uint32",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfUint32(uint32(v))",
		FromValue:  "uint64(uint32(v.Uint()))",
		GoType:     GoUint32,
		ToGoType:   "uint32(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Int64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt64(int64(v))",
		FromValue:  "uint64(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "int64(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Sint64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfInt64(wire.DecodeZigZag(v))",
		FromValue:  "wire.EncodeZigZag(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "wire.DecodeZigZag(v)",
		FromGoType: "wire.EncodeZigZag(v)",
	},
	{
		Name:       "Uint64",
		WireType:   WireVarint,
		ToValue:    "protoreflect.ValueOfUint64(v)",
		FromValue:  "v.Uint()",
		GoType:     GoUint64,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Sfixed32",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfInt32(int32(v))",
		FromValue:  "uint32(v.Int())",
		GoType:     GoInt32,
		ToGoType:   "int32(v)",
		FromGoType: "uint32(v)",
	},
	{
		Name:       "Fixed32",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfUint32(uint32(v))",
		FromValue:  "uint32(v.Uint())",
		GoType:     GoUint32,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Float",
		WireType:   WireFixed32,
		ToValue:    "protoreflect.ValueOfFloat32(math.Float32frombits(uint32(v)))",
		FromValue:  "math.Float32bits(float32(v.Float()))",
		GoType:     GoFloat32,
		ToGoType:   "math.Float32frombits(v)",
		FromGoType: "math.Float32bits(v)",
	},
	{
		Name:       "Sfixed64",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfInt64(int64(v))",
		FromValue:  "uint64(v.Int())",
		GoType:     GoInt64,
		ToGoType:   "int64(v)",
		FromGoType: "uint64(v)",
	},
	{
		Name:       "Fixed64",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfUint64(v)",
		FromValue:  "v.Uint()",
		GoType:     GoUint64,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:       "Double",
		WireType:   WireFixed64,
		ToValue:    "protoreflect.ValueOfFloat64(math.Float64frombits(v))",
		FromValue:  "math.Float64bits(v.Float())",
		GoType:     GoFloat64,
		ToGoType:   "math.Float64frombits(v)",
		FromGoType: "math.Float64bits(v)",
	},
	{
		Name:       "String",
		WireType:   WireBytes,
		ToValue:    "protoreflect.ValueOfString(string(v))",
		FromValue:  "v.String()",
		GoType:     GoString,
		ToGoType:   "v",
		FromGoType: "v",
	},
	{
		Name:           "Bytes",
		WireType:       WireBytes,
		ToValue:        "protoreflect.ValueOfBytes(append(([]byte)(nil), v...))",
		FromValue:      "v.Bytes()",
		GoType:         GoBytes,
		ToGoType:       "append(emptyBuf[:], v...)",
		ToGoTypeNoZero: "append(([]byte)(nil), v...)",
		FromGoType:     "v",
		NoPointer:      true,
	},
	{
		Name:         "Message",
		WireType:     WireBytes,
		ToValue:      "protoreflect.ValueOfBytes(v)",
		FromValue:    "v",
		NoValueCodec: true,
	},
	{
		Name:         "Group",
		WireType:     WireGroup,
		ToValue:      "protoreflect.ValueOfBytes(v)",
		FromValue:    "v",
		NoValueCodec: true,
	},
}

// generateProtoDecode returns the result of running protoDecodeTemplate over
// ProtoKinds via mustExecute.
func generateProtoDecode() string {
	src := mustExecute(protoDecodeTemplate, ProtoKinds)
	return src
}

// protoDecodeTemplate generates the unmarshalScalar and unmarshalList methods
// of UnmarshalOptions, emitting one switch case per ProtoKind.
//
// In unmarshalList, kinds whose wire type is Packable get an extra branch
// that accepts a wire.BytesType record and decodes elements from it until the
// buffer is exhausted. String kinds get a UTF-8 validity check, and Message
// and Group elements are decoded recursively through o.unmarshalMessage.
var protoDecodeTemplate = template.Must(template.New("").Parse(`
// unmarshalScalar decodes a value of the given kind.
//
// Message values are decoded into a []byte which aliases the input data.
func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, fd protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		if wtyp != {{.WireType.Expr}} {
			return val, 0, errUnknown
		}
		{{if (eq .WireType "Group") -}}
		v, n := wire.ConsumeGroup(fd.Number(), b)
		{{- else -}}
		v, n := wire.Consume{{.WireType}}(b)
		{{- end}}
		if n < 0 {
			return val, 0, wire.ParseError(n)
		}
		{{if (eq .Name "String") -}}
		if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
			return protoreflect.Value{}, 0, errors.InvalidUTF8(string(fd.FullName()))
		}
		{{end -}}
		return {{.ToValue}}, n, nil
	{{- end}}
	default:
		return val, 0, errUnknown
	}
}

func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, list protoreflect.List, fd protoreflect.FieldDescriptor) (n int, err error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		{{- if .WireType.Packable}}
		if wtyp == wire.BytesType {
			buf, n := wire.ConsumeBytes(b)
			if n < 0 {
				return 0, wire.ParseError(n)
			}
			for len(buf) > 0 {
				v, n := wire.Consume{{.WireType}}(buf)
				if n < 0 {
					return 0, wire.ParseError(n)
				}
				buf = buf[n:]
				list.Append({{.ToValue}})
			}
			return n, nil
		}
		{{- end}}
		if wtyp != {{.WireType.Expr}} {
			return 0, errUnknown
		}
		{{if (eq .WireType "Group") -}}
		v, n := wire.ConsumeGroup(fd.Number(), b)
		{{- else -}}
		v, n := wire.Consume{{.WireType}}(b)
		{{- end}}
		if n < 0 {
			return 0, wire.ParseError(n)
		}
		{{if (eq .Name "String") -}}
		if strs.EnforceUTF8(fd) && !utf8.Valid(v) {
			return 0, errors.InvalidUTF8(string(fd.FullName()))
		}
		{{end -}}
		{{if or (eq .Name "Message") (eq .Name "Group") -}}
		m := list.NewElement()
		if err := o.unmarshalMessage(v, m.Message()); err != nil {
			return 0, err
		}
		list.Append(m)
		{{- else -}}
		list.Append({{.ToValue}})
		{{- end}}
		return n, nil
	{{- end}}
	default:
		return 0, errUnknown
	}
}
`))

// generateProtoEncode returns the result of running protoEncodeTemplate over
// ProtoKinds via mustExecute.
func generateProtoEncode() string {
	src := mustExecute(protoEncodeTemplate, ProtoKinds)
	return src
}

// protoEncodeTemplate generates the wireTypes map (protoreflect.Kind to
// wire.Type) and the marshalSingular method of MarshalOptions, with one entry
// and one switch case per ProtoKind.
//
// Per-kind handling in marshalSingular:
//   - String: checks UTF-8 validity when strs.EnforceUTF8(fd) is true, then
//     appends with wire.AppendString.
//   - Message: marshals the nested message between appendSpeculativeLength
//     and finishSpeculativeLength.
//   - Group: marshals the nested message and appends the end-group tag.
//   - Everything else: wire.Append<WireType>(b, <FromValue>).
var protoEncodeTemplate = template.Must(template.New("").Parse(`
var wireTypes = map[protoreflect.Kind]wire.Type{
{{- range .}}
	{{.Expr}}: {{.WireType.Expr}},
{{- end}}
}

func (o MarshalOptions) marshalSingular(b []byte, fd protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
	switch fd.Kind() {
	{{- range .}}
	case {{.Expr}}:
		{{- if (eq .Name "String") }}
		if strs.EnforceUTF8(fd) && !utf8.ValidString(v.String()) {
			return b, errors.InvalidUTF8(string(fd.FullName()))
		}
		b = wire.AppendString(b, {{.FromValue}})
		{{- else if (eq .Name "Message") -}}
		var pos int
		var err error
		b, pos = appendSpeculativeLength(b)
		b, err = o.marshalMessage(b, v.Message())
		if err != nil {
			return b, err
		}
		b = finishSpeculativeLength(b, pos)
		{{- else if (eq .Name "Group") -}}
		var err error
		b, err = o.marshalMessage(b, v.Message())
		if err != nil {
			return b, err
		}
		b = wire.AppendVarint(b, wire.EncodeTag(fd.Number(), wire.EndGroupType))
		{{- else -}}
		b = wire.Append{{.WireType}}(b, {{.FromValue}})
		{{- end}}
	{{- end}}
	default:
		return b, errors.New("invalid kind %v", fd.Kind())
	}
	return b, nil
}
`))

// generateProtoSize returns the result of running protoSizeTemplate over
// ProtoKinds via mustExecute.
func generateProtoSize() string {
	src := mustExecute(protoSizeTemplate, ProtoKinds)
	return src
}

// protoSizeTemplate generates sizeSingular, with one switch case per
// ProtoKind.
//
// The branch order matters: Message is tested by name before the generic
// Bytes wire-type branch, so messages are sized via sizeMessage rather than
// len(). Fixed32/Fixed64 kinds call wire.Size<WireType>() with no argument,
// Bytes kinds pass len(<FromValue>), Group passes the field number and the
// message size, and the remaining kinds pass <FromValue> directly.
var protoSizeTemplate = template.Must(template.New("").Parse(`
func sizeSingular(num wire.Number, kind protoreflect.Kind, v protoreflect.Value) int {
	switch kind {
	{{- range .}}
	case {{.Expr}}:
		{{if (eq .Name "Message") -}}
		return wire.SizeBytes(sizeMessage(v.Message()))
		{{- else if or (eq .WireType "Fixed32") (eq .WireType "Fixed64") -}}
		return wire.Size{{.WireType}}()
		{{- else if (eq .WireType "Bytes") -}}
		return wire.Size{{.WireType}}(len({{.FromValue}}))
		{{- else if (eq .WireType "Group") -}}
		return wire.Size{{.WireType}}(num, sizeMessage(v.Message()))
		{{- else -}}
		return wire.Size{{.WireType}}({{.FromValue}})
		{{- end}}
	{{- end}}
	default:
		return 0
	}
}
`))
