blob: 8b85c69fbdef931b5b62806c7850a5d4b03dcffe [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impl
import (
"reflect"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/mapsort"
pref "google.golang.org/protobuf/reflect/protoreflect"
)
type mapInfo struct {
goType reflect.Type
keyWiretag uint64
valWiretag uint64
keyFuncs valueCoderFuncs
valFuncs valueCoderFuncs
keyZero pref.Value
keyKind pref.Kind
}
func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
// TODO: Consider generating specialized map coders.
keyField := fd.MapKey()
valField := fd.MapValue()
keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
keyFuncs := encoderFuncsForValue(keyField)
valFuncs := encoderFuncsForValue(valField)
conv := NewConverter(ft, fd)
mapi := &mapInfo{
goType: ft,
keyWiretag: keyWiretag,
valWiretag: valWiretag,
keyFuncs: keyFuncs,
valFuncs: valFuncs,
keyZero: keyField.Default(),
keyKind: keyField.Kind(),
}
funcs = pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
return sizeMap(mapv, tagsize, mapi, opts)
},
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
return appendMap(b, mapv, wiretag, mapi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
mp := p.AsValueOf(ft)
if mp.Elem().IsNil() {
mp.Elem().Set(reflect.MakeMap(mapi.goType))
}
mapv := conv.PBValueOf(mp.Elem()).Map()
return consumeMap(b, mapv, wtyp, mapi, opts)
},
}
if valFuncs.isInit != nil {
funcs.isInit = func(p pointer) error {
mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
return isInitMap(mapv, mapi)
}
}
return funcs
}
const (
mapKeyTagSize = 1 // field 1, tag size 1.
mapValTagSize = 1 // field 2, tag size 2.
)
func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int {
if mapv.Len() == 0 {
return 0
}
n := 0
mapv.Range(func(key pref.MapKey, value pref.Value) bool {
n += tagsize + wire.SizeBytes(
mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)+
mapi.valFuncs.size(value, mapValTagSize, opts))
return true
})
return n
}
func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
b, n := wire.ConsumeBytes(b)
if n < 0 {
return 0, wire.ParseError(n)
}
var (
key = mapi.keyZero
val = mapv.NewValue()
)
for len(b) > 0 {
num, wtyp, n := wire.ConsumeTag(b)
if n < 0 {
return 0, wire.ParseError(n)
}
b = b[n:]
err := errUnknown
switch num {
case 1:
var v pref.Value
v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
if err != nil {
break
}
key = v
case 2:
var v pref.Value
v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
if err != nil {
break
}
val = v
}
if err == errUnknown {
n = wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
return 0, wire.ParseError(n)
}
} else if err != nil {
return 0, err
}
b = b[n:]
}
mapv.Set(key.MapKey(), val)
return n, nil
}
func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
if mapv.Len() == 0 {
return b, nil
}
var err error
fn := func(key pref.MapKey, value pref.Value) bool {
b = wire.AppendVarint(b, wiretag)
size := 0
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
size += mapi.valFuncs.size(value, mapValTagSize, opts)
b = wire.AppendVarint(b, uint64(size))
b, err = mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
if err != nil {
return false
}
b, err = mapi.valFuncs.marshal(b, value, mapi.valWiretag, opts)
if err != nil {
return false
}
return true
}
if opts.Deterministic() {
mapsort.Range(mapv, mapi.keyKind, fn)
} else {
mapv.Range(fn)
}
return b, err
}
func isInitMap(mapv pref.Map, mapi *mapInfo) error {
var err error
mapv.Range(func(_ pref.MapKey, value pref.Value) bool {
err = mapi.valFuncs.isInit(value)
return err == nil
})
return err
}