// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package protobuild constructs messages.
//
// This package is used to construct messages of several different types that
// share a similar shape from a common template.
package protobuild

import (
	"fmt"
	"math"
	"reflect"

	pref "google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
)

// A Value is a value assignable to a field.
// A Value may be a value accepted by protoreflect.ValueOf. In addition:
//
// • An int may be assigned to any numeric field, including float, double,
// and enum (by number) fields. Values out of range for a 32-bit or unsigned
// field cause a panic.
//
// • A float64 may be assigned to a float or double field.
//
// • Either a string or []byte may be assigned to a string or bytes field.
//
// • A string containing the value name may be assigned to an enum field.
//
// • A Message may be assigned to a message field, and is applied to the
// sub-message with Build.
//
// • A slice may be assigned to a list, and a map may be assigned to a map.
// Elements of a message-typed list, and values of a message-typed map, must
// themselves be Messages.
type Value interface{}

// A Message is a template to apply to a message. Keys are field names, including
// extension names.
//
// Extension keys use the extension's short name and are resolved against the
// extensions registered in protoregistry.GlobalTypes for the target message
// type. The special key Unknown supplies the message's unknown fields.
// Apply a template with Build.
type Message map[pref.Name]Value

// Unknown is a key associated with the unknown fields of a message.
// The value should be a []byte; Build stores it as the message's unknown
// fields (as protoreflect.RawFields) and panics on any other value type.
const Unknown = "@unknown"

// Build applies the template to a message.
//
// Each key names a field of m's message type, or an extension of it registered
// in protoregistry.GlobalTypes (matched by the extension's short name). The
// special key Unknown replaces m's unknown fields with a copy of the given
// []byte. Scalar values are converted as described on Value; Message values
// are applied recursively to the corresponding sub-message, list element, or
// map value. An existing message-typed map entry is updated in place rather
// than replaced.
//
// Build panics if a key does not name a known field, or if a value has a type
// that cannot be assigned to the field.
func (template Message) Build(m pref.Message) {
	md := m.Descriptor()
	fields := md.Fields()
	exts := make(map[pref.Name]pref.FieldDescriptor)
	protoregistry.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(xt pref.ExtensionType) bool {
		xd := xt.TypeDescriptor()
		exts[xd.Name()] = xd
		return true
	})
	for name, v := range template {
		if name == Unknown {
			b, ok := v.([]byte)
			if !ok {
				panic(fmt.Sprintf("%v.%v: value must be []byte, got %T", md.FullName(), Unknown, v))
			}
			// Copy so that m never shares a backing array with the template;
			// otherwise appends to m's unknown fields could write into it.
			m.SetUnknown(pref.RawFields(append([]byte{}, b...)))
			continue
		}
		fd := fields.ByName(name)
		if fd == nil {
			fd = exts[name]
		}
		if fd == nil {
			panic(fmt.Sprintf("%v.%v: not found", md.FullName(), name))
		}
		switch {
		case fd.IsList():
			list := m.Mutable(fd).List()
			elems := reflect.ValueOf(v)
			isMessage := fd.Message() != nil
			for i := 0; i < elems.Len(); i++ {
				elem := elems.Index(i).Interface()
				if !isMessage {
					list.Append(fieldValue(fd, elem))
					continue
				}
				e := list.NewElement()
				elem.(Message).Build(e.Message())
				list.Append(e)
			}
		case fd.IsMap():
			mapv := m.Mutable(fd).Map()
			entries := reflect.ValueOf(v)
			valueIsMessage := fd.MapValue().Message() != nil
			for _, rk := range entries.MapKeys() {
				mk := fieldValue(fd.MapKey(), rk.Interface()).MapKey()
				rv := entries.MapIndex(rk).Interface()
				switch {
				case !valueIsMessage:
					mapv.Set(mk, fieldValue(fd.MapValue(), rv))
				case mapv.Has(mk):
					// Merge the template into the existing entry.
					rv.(Message).Build(mapv.Get(mk).Message())
				default:
					mv := mapv.NewValue()
					rv.(Message).Build(mv.Message())
					mapv.Set(mk, mv)
				}
			}
		default:
			if fd.Message() == nil {
				m.Set(fd, fieldValue(fd, v))
			} else {
				v.(Message).Build(m.Mutable(fd).Message())
			}
		}
	}
}

// fieldValue converts the template value v into a protoreflect.Value suitable
// for assignment to the field (or list element, map key, or map value)
// described by fd.
//
// In addition to the values accepted directly by protoreflect.ValueOf, it
// converts:
//   - an int to the field's numeric kind, panicking if the value is out of
//     range for a 32-bit or unsigned field;
//   - a float64 to float32 for float fields;
//   - a string to []byte for bytes fields, or to the named value's number for
//     enum fields (panicking if no value has that name);
//   - a []byte to a string for string fields.
//
// A []byte result is always a copy, so the message never aliases the template.
func fieldValue(fd pref.FieldDescriptor, v interface{}) pref.Value {
	switch o := v.(type) {
	case int:
		switch fd.Kind() {
		case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
			// Compare in int64 so the bounds are representable on every platform.
			if n := int64(o); n < math.MinInt32 || n > math.MaxInt32 {
				panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, int32(math.MinInt32), int32(math.MaxInt32)))
			}
			v = int32(o)
		case pref.Uint32Kind, pref.Fixed32Kind:
			// math.MaxUint32 overflows a 32-bit int, so the comparison (and the
			// printed bound) must not use int or this fails to build there.
			if n := int64(o); n < 0 || n > math.MaxUint32 {
				panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, 0, uint32(math.MaxUint32)))
			}
			v = uint32(o)
		case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
			v = int64(o)
		case pref.Uint64Kind, pref.Fixed64Kind:
			if o < 0 {
				panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, 0, uint64(math.MaxUint64)))
			}
			v = uint64(o)
		case pref.FloatKind:
			v = float32(o)
		case pref.DoubleKind:
			v = float64(o)
		case pref.EnumKind:
			v = pref.EnumNumber(o)
		default:
			panic(fmt.Sprintf("%v: invalid value type int", fd.FullName()))
		}
	case float64:
		switch fd.Kind() {
		case pref.FloatKind:
			v = float32(o)
		}
	case string:
		switch fd.Kind() {
		case pref.BytesKind:
			v = []byte(o)
		case pref.EnumKind:
			ev := fd.Enum().Values().ByName(pref.Name(o))
			if ev == nil {
				panic(fmt.Sprintf("%v: unknown enum value %q", fd.FullName(), o))
			}
			v = ev.Number()
		}
	case []byte:
		if fd.Kind() == pref.StringKind {
			v = string(o)
		} else {
			// Copy so the message does not share storage with the template.
			v = append([]byte{}, o...)
		}
	}
	return pref.ValueOf(v)
}
