| // Copyright 2019 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package impl |
| |
| import ( |
| "reflect" |
| "sort" |
| |
| "google.golang.org/protobuf/encoding/protowire" |
| "google.golang.org/protobuf/internal/genid" |
| "google.golang.org/protobuf/reflect/protoreflect" |
| ) |
| |
// mapInfo holds the precomputed per-field state used to size, marshal,
// unmarshal, and initialization-check a single map field.
type mapInfo struct {
	goType     reflect.Type       // Go map type of the field; used to allocate the map on unmarshal
	keyWiretag uint64             // encoded tag for field 1 (the key) with the key's wire type
	valWiretag uint64             // encoded tag for field 2 (the value) with the value's wire type
	keyFuncs   valueCoderFuncs    // value coders for the map key
	valFuncs   valueCoderFuncs    // value coders for the map value
	keyZero    protoreflect.Value // key used when a decoded entry omits field 1
	keyKind    protoreflect.Kind  // protobuf kind of the map key
	conv       *mapConverter      // converts keys/values between Go reflect values and protoreflect values
}
| |
| func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) { |
| // TODO: Consider generating specialized map coders. |
| keyField := fd.MapKey() |
| valField := fd.MapValue() |
| keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()]) |
| valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()]) |
| keyFuncs := encoderFuncsForValue(keyField) |
| valFuncs := encoderFuncsForValue(valField) |
| conv := newMapConverter(ft, fd) |
| |
| mapi := &mapInfo{ |
| goType: ft, |
| keyWiretag: keyWiretag, |
| valWiretag: valWiretag, |
| keyFuncs: keyFuncs, |
| valFuncs: valFuncs, |
| keyZero: keyField.Default(), |
| keyKind: keyField.Kind(), |
| conv: conv, |
| } |
| if valField.Kind() == protoreflect.MessageKind { |
| valueMessage = getMessageInfo(ft.Elem()) |
| } |
| |
| funcs = pointerCoderFuncs{ |
| size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int { |
| return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts) |
| }, |
| marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
| return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts) |
| }, |
| unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) { |
| mp := p.AsValueOf(ft) |
| if mp.Elem().IsNil() { |
| mp.Elem().Set(reflect.MakeMap(mapi.goType)) |
| } |
| if f.mi == nil { |
| return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts) |
| } else { |
| return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts) |
| } |
| }, |
| } |
| switch valField.Kind() { |
| case protoreflect.MessageKind: |
| funcs.merge = mergeMapOfMessage |
| case protoreflect.BytesKind: |
| funcs.merge = mergeMapOfBytes |
| default: |
| funcs.merge = mergeMap |
| } |
| if valFuncs.isInit != nil { |
| funcs.isInit = func(p pointer, f *coderFieldInfo) error { |
| return isInitMap(p.AsValueOf(ft).Elem(), mapi, f) |
| } |
| } |
| return valueMessage, funcs |
| } |
| |
// Encoded sizes, in bytes, of the tags that introduce the key (field 1) and
// value (field 2) of a map entry. For both field numbers the tag value
// (field number << 3 | wire type) is below 128, so it encodes as a single
// varint byte regardless of wire type.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
| |
| func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int { |
| if mapv.Len() == 0 { |
| return 0 |
| } |
| n := 0 |
| iter := mapRange(mapv) |
| for iter.Next() { |
| key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey() |
| keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) |
| var valSize int |
| value := mapi.conv.valConv.PBValueOf(iter.Value()) |
| if f.mi == nil { |
| valSize = mapi.valFuncs.size(value, mapValTagSize, opts) |
| } else { |
| p := pointerOfValue(iter.Value()) |
| valSize += mapValTagSize |
| valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts)) |
| } |
| n += f.tagsize + protowire.SizeBytes(keySize+valSize) |
| } |
| return n |
| } |
| |
// consumeMap decodes a single map entry from b and stores it in mapv.
//
// The entry is a length-delimited submessage whose field 1 is the key and
// field 2 is the value. If the entry omits the key or value, mapi.keyZero or
// a fresh mapi.conv.valConv.New() value is stored in its place. Fields with
// other numbers, or whose coder returns errUnknown, are skipped.
//
// It returns errUnknown when wtyp is not BytesType and errDecode on malformed
// input. On success, out.n is the number of bytes consumed for the entry.
func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
	if wtyp != protowire.BytesType {
		return out, errUnknown
	}
	// This n is the size of the whole entry and is what out.n reports at the
	// end; the loop below declares its own n that shadows it.
	b, n := protowire.ConsumeBytes(b)
	if n < 0 {
		return out, errDecode
	}
	var (
		key = mapi.keyZero
		val = mapi.conv.valConv.New()
	)
	for len(b) > 0 {
		// Per-field n (shadows the entry size above).
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return out, errDecode
		}
		if num > protowire.MaxValidNumber {
			return out, errDecode
		}
		b = b[n:]
		// Starts as errUnknown so that unmatched field numbers fall through
		// to the skip logic below.
		err := errUnknown
		switch num {
		case genid.MapEntry_Key_field_number:
			var v protoreflect.Value
			var o unmarshalOutput
			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
			if err != nil {
				break // leaves the switch; err is handled below
			}
			key = v
			n = o.n
		case genid.MapEntry_Value_field_number:
			var v protoreflect.Value
			var o unmarshalOutput
			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
			if err != nil {
				break // leaves the switch; err is handled below
			}
			val = v
			n = o.n
		}
		if err == errUnknown {
			// Skip the field's value without interpreting it.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
		} else if err != nil {
			return out, err
		}
		b = b[n:]
	}
	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
	out.n = n
	return out, nil
}
| |
| func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { |
| if wtyp != protowire.BytesType { |
| return out, errUnknown |
| } |
| b, n := protowire.ConsumeBytes(b) |
| if n < 0 { |
| return out, errDecode |
| } |
| var ( |
| key = mapi.keyZero |
| val = reflect.New(f.mi.GoReflectType.Elem()) |
| ) |
| for len(b) > 0 { |
| num, wtyp, n := protowire.ConsumeTag(b) |
| if n < 0 { |
| return out, errDecode |
| } |
| if num > protowire.MaxValidNumber { |
| return out, errDecode |
| } |
| b = b[n:] |
| err := errUnknown |
| switch num { |
| case 1: |
| var v protoreflect.Value |
| var o unmarshalOutput |
| v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) |
| if err != nil { |
| break |
| } |
| key = v |
| n = o.n |
| case 2: |
| if wtyp != protowire.BytesType { |
| break |
| } |
| var v []byte |
| v, n = protowire.ConsumeBytes(b) |
| if n < 0 { |
| return out, errDecode |
| } |
| var o unmarshalOutput |
| o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts) |
| if o.initialized { |
| // Consider this map item initialized so long as we see |
| // an initialized value. |
| out.initialized = true |
| } |
| } |
| if err == errUnknown { |
| n = protowire.ConsumeFieldValue(num, wtyp, b) |
| if n < 0 { |
| return out, errDecode |
| } |
| } else if err != nil { |
| return out, err |
| } |
| b = b[n:] |
| } |
| mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val) |
| out.n = n |
| return out, nil |
| } |
| |
| func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
| if f.mi == nil { |
| key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() |
| val := mapi.conv.valConv.PBValueOf(valrv) |
| size := 0 |
| size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) |
| size += mapi.valFuncs.size(val, mapValTagSize, opts) |
| b = protowire.AppendVarint(b, uint64(size)) |
| b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) |
| if err != nil { |
| return nil, err |
| } |
| return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts) |
| } else { |
| key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() |
| val := pointerOfValue(valrv) |
| valSize := f.mi.sizePointer(val, opts) |
| size := 0 |
| size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) |
| size += mapValTagSize + protowire.SizeBytes(valSize) |
| b = protowire.AppendVarint(b, uint64(size)) |
| b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) |
| if err != nil { |
| return nil, err |
| } |
| b = protowire.AppendVarint(b, mapi.valWiretag) |
| b = protowire.AppendVarint(b, uint64(valSize)) |
| return f.mi.marshalAppendPointer(b, val, opts) |
| } |
| } |
| |
| func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
| if mapv.Len() == 0 { |
| return b, nil |
| } |
| if opts.Deterministic() { |
| return appendMapDeterministic(b, mapv, mapi, f, opts) |
| } |
| iter := mapRange(mapv) |
| for iter.Next() { |
| var err error |
| b = protowire.AppendVarint(b, f.wiretag) |
| b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts) |
| if err != nil { |
| return b, err |
| } |
| } |
| return b, nil |
| } |
| |
| func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { |
| keys := mapv.MapKeys() |
| sort.Slice(keys, func(i, j int) bool { |
| switch keys[i].Kind() { |
| case reflect.Bool: |
| return !keys[i].Bool() && keys[j].Bool() |
| case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
| return keys[i].Int() < keys[j].Int() |
| case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: |
| return keys[i].Uint() < keys[j].Uint() |
| case reflect.Float32, reflect.Float64: |
| return keys[i].Float() < keys[j].Float() |
| case reflect.String: |
| return keys[i].String() < keys[j].String() |
| default: |
| panic("invalid kind: " + keys[i].Kind().String()) |
| } |
| }) |
| for _, key := range keys { |
| var err error |
| b = protowire.AppendVarint(b, f.wiretag) |
| b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts) |
| if err != nil { |
| return b, err |
| } |
| } |
| return b, nil |
| } |
| |
| func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error { |
| if mi := f.mi; mi != nil { |
| mi.init() |
| if !mi.needsInitCheck { |
| return nil |
| } |
| iter := mapRange(mapv) |
| for iter.Next() { |
| val := pointerOfValue(iter.Value()) |
| if err := mi.checkInitializedPointer(val); err != nil { |
| return err |
| } |
| } |
| } else { |
| iter := mapRange(mapv) |
| for iter.Next() { |
| val := mapi.conv.valConv.PBValueOf(iter.Value()) |
| if err := mapi.valFuncs.isInit(val); err != nil { |
| return err |
| } |
| } |
| } |
| return nil |
| } |
| |
| func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { |
| dstm := dst.AsValueOf(f.ft).Elem() |
| srcm := src.AsValueOf(f.ft).Elem() |
| if srcm.Len() == 0 { |
| return |
| } |
| if dstm.IsNil() { |
| dstm.Set(reflect.MakeMap(f.ft)) |
| } |
| iter := mapRange(srcm) |
| for iter.Next() { |
| dstm.SetMapIndex(iter.Key(), iter.Value()) |
| } |
| } |
| |
| func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { |
| dstm := dst.AsValueOf(f.ft).Elem() |
| srcm := src.AsValueOf(f.ft).Elem() |
| if srcm.Len() == 0 { |
| return |
| } |
| if dstm.IsNil() { |
| dstm.Set(reflect.MakeMap(f.ft)) |
| } |
| iter := mapRange(srcm) |
| for iter.Next() { |
| dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...))) |
| } |
| } |
| |
| func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { |
| dstm := dst.AsValueOf(f.ft).Elem() |
| srcm := src.AsValueOf(f.ft).Elem() |
| if srcm.Len() == 0 { |
| return |
| } |
| if dstm.IsNil() { |
| dstm.Set(reflect.MakeMap(f.ft)) |
| } |
| iter := mapRange(srcm) |
| for iter.Next() { |
| val := reflect.New(f.ft.Elem().Elem()) |
| if f.mi != nil { |
| f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts) |
| } else { |
| opts.Merge(asMessage(val), asMessage(iter.Value())) |
| } |
| dstm.SetMapIndex(iter.Key(), val) |
| } |
| } |