| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package impl |
| |
| import ( |
| "fmt" |
| "math/bits" |
| "os" |
| "reflect" |
| "sort" |
| "sync/atomic" |
| |
| "google.golang.org/protobuf/encoding/protowire" |
| "google.golang.org/protobuf/internal/errors" |
| "google.golang.org/protobuf/internal/protolazy" |
| "google.golang.org/protobuf/reflect/protoreflect" |
| preg "google.golang.org/protobuf/reflect/protoregistry" |
| piface "google.golang.org/protobuf/runtime/protoiface" |
| ) |
| |
// enableLazy is the package-wide switch for lazy unmarshaling: 1 means
// enabled, 0 means disabled. It starts out enabled unless the GOPROTODEBUG
// environment variable is exactly "nolazy".
var enableLazy int32 = func() int32 {
	if debug := os.Getenv("GOPROTODEBUG"); debug != "nolazy" {
		return 1
	}
	return 0
}()
| |
| // EnableLazyUnmarshal enables lazy unmarshaling. |
| func EnableLazyUnmarshal(enable bool) { |
| if enable { |
| atomic.StoreInt32(&enableLazy, 1) |
| return |
| } |
| atomic.StoreInt32(&enableLazy, 0) |
| } |
| |
| // LazyEnabled reports whether lazy unmarshalling is currently enabled. |
| func LazyEnabled() bool { |
| return atomic.LoadInt32(&enableLazy) != 0 |
| } |
| |
| // UnmarshalField unmarshals a field in a message. |
| func UnmarshalField(m interface{}, num protowire.Number) { |
| switch m := m.(type) { |
| case *messageState: |
| m.messageInfo().lazyUnmarshal(m.pointer(), num) |
| case *messageReflectWrapper: |
| m.messageInfo().lazyUnmarshal(m.pointer(), num) |
| default: |
| panic(fmt.Sprintf("unsupported wrapper type %T", m)) |
| } |
| } |
| |
| func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) { |
| var f *coderFieldInfo |
| if int(num) < len(mi.denseCoderFields) { |
| f = mi.denseCoderFields[num] |
| } else { |
| f = mi.coderFields[num] |
| } |
| if f == nil { |
| panic(fmt.Sprintf("lazyUnmarshal: field info for %v.%v", mi.Desc.FullName(), num)) |
| } |
| lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr() |
| start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num)) |
| if !found && multipleEntries == nil { |
| panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num)) |
| } |
| // The actual pointer in the message can not be set until the whole struct is filled in, otherwise we will have races. |
| // Create another pointer and set it atomically, if we won the race and the pointer in the original message is still nil. |
| fp := pointerOfValue(reflect.New(f.ft)) |
| if multipleEntries != nil { |
| for _, entry := range multipleEntries { |
| mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags()) |
| } |
| } else { |
| mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags()) |
| } |
| p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem()) |
| } |
| |
| func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error { |
| opts := lazyUnmarshalOptions |
| opts.flags |= flags |
| for len(b) > 0 { |
| // Parse the tag (field number and wire type). |
| var tag uint64 |
| if b[0] < 0x80 { |
| tag = uint64(b[0]) |
| b = b[1:] |
| } else if len(b) >= 2 && b[1] < 128 { |
| tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 |
| b = b[2:] |
| } else { |
| var n int |
| tag, n = protowire.ConsumeVarint(b) |
| if n < 0 { |
| return errors.New("invalid wire data") |
| } |
| b = b[n:] |
| } |
| var num protowire.Number |
| if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { |
| return errors.New("invalid wire data") |
| } else { |
| num = protowire.Number(n) |
| } |
| wtyp := protowire.Type(tag & 7) |
| if num == f.num { |
| o, err := f.funcs.unmarshal(b, p, wtyp, f, opts) |
| if err == nil { |
| b = b[o.n:] |
| continue |
| } |
| if err != errUnknown { |
| return err |
| } |
| } |
| n := protowire.ConsumeFieldValue(num, wtyp, b) |
| if n < 0 { |
| return errors.New("invalid wire data") |
| } |
| b = b[n:] |
| } |
| return nil |
| } |
| |
| func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) { |
| fmi := f.validation.mi |
| if fmi == nil { |
| fd := mi.Desc.Fields().ByNumber(f.num) |
| if fd == nil { |
| return out, ValidationUnknown |
| } |
| messageName := fd.Message().FullName() |
| messageType, err := preg.GlobalTypes.FindMessageByName(messageName) |
| if err != nil { |
| return out, ValidationUnknown |
| } |
| var ok bool |
| fmi, ok = messageType.(*MessageInfo) |
| if !ok { |
| return out, ValidationUnknown |
| } |
| } |
| fmi.init() |
| switch f.validation.typ { |
| case validationTypeMessage: |
| if wtyp != protowire.BytesType { |
| return out, ValidationWrongWireType |
| } |
| v, n := protowire.ConsumeBytes(b) |
| if n < 0 { |
| return out, ValidationInvalid |
| } |
| out, st := fmi.validate(v, 0, opts) |
| out.n = n |
| return out, st |
| case validationTypeGroup: |
| if wtyp != protowire.StartGroupType { |
| return out, ValidationWrongWireType |
| } |
| out, st := fmi.validate(b, f.num, opts) |
| return out, st |
| default: |
| return out, ValidationUnknown |
| } |
| } |
| |
| // unmarshalPointerLazy is similar to unmarshalPointerEager, but it |
| // specifically handles lazy unmarshalling. it expects lazyOffset and |
| // presenceOffset to both be valid. |
| func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) { |
| initialized := true |
| var requiredMask uint64 |
| var lazy **protolazy.XXX_lazyUnmarshalInfo |
| var presence presence |
| var lazyIndex []protolazy.IndexEntry |
| var lastNum protowire.Number |
| outOfOrder := false |
| lazyDecode := false |
| presence = p.Apply(mi.presenceOffset).PresenceInfo() |
| lazy = p.Apply(mi.lazyOffset).LazyInfoPtr() |
| if !presence.AnyPresent(mi.presenceSize) { |
| if opts.CanBeLazy() { |
| // If the message contains existing data, we need to merge into it. |
| // Lazy unmarshaling doesn't merge, so only enable it when the |
| // message is empty (has no presence bitmap). |
| lazyDecode = true |
| if *lazy == nil { |
| *lazy = &protolazy.XXX_lazyUnmarshalInfo{} |
| } |
| (*lazy).SetUnmarshalFlags(opts.flags) |
| if !opts.AliasBuffer() { |
| // Make a copy of the buffer for lazy unmarshaling. |
| // Set the AliasBuffer flag so recursive unmarshal |
| // operations reuse the copy. |
| b = append([]byte{}, b...) |
| opts.flags |= piface.UnmarshalAliasBuffer |
| } |
| (*lazy).SetBuffer(b) |
| } |
| } |
| // Track special handling of lazy fields. |
| // |
| // In the common case, all fields are lazyValidateOnly (and lazyFields remains nil). |
| // In the event that validation for a field fails, this map tracks handling of the field. |
| type lazyAction uint8 |
| const ( |
| lazyValidateOnly lazyAction = iota // validate the field only |
| lazyUnmarshalNow // eagerly unmarshal the field |
| lazyUnmarshalLater // unmarshal the field after the message is fully processed |
| ) |
| var lazyFields map[*coderFieldInfo]lazyAction |
| var exts *map[int32]ExtensionField |
| start := len(b) |
| pos := 0 |
| for len(b) > 0 { |
| // Parse the tag (field number and wire type). |
| var tag uint64 |
| if b[0] < 0x80 { |
| tag = uint64(b[0]) |
| b = b[1:] |
| } else if len(b) >= 2 && b[1] < 128 { |
| tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 |
| b = b[2:] |
| } else { |
| var n int |
| tag, n = protowire.ConsumeVarint(b) |
| if n < 0 { |
| return out, errDecode |
| } |
| b = b[n:] |
| } |
| var num protowire.Number |
| if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { |
| return out, errors.New("invalid field number") |
| } else { |
| num = protowire.Number(n) |
| } |
| wtyp := protowire.Type(tag & 7) |
| |
| if wtyp == protowire.EndGroupType { |
| if num != groupTag { |
| return out, errors.New("mismatching end group marker") |
| } |
| groupTag = 0 |
| break |
| } |
| |
| var f *coderFieldInfo |
| if int(num) < len(mi.denseCoderFields) { |
| f = mi.denseCoderFields[num] |
| } else { |
| f = mi.coderFields[num] |
| } |
| var n int |
| err := errUnknown |
| discardUnknown := false |
| Field: |
| switch { |
| case f != nil: |
| if f.funcs.unmarshal == nil { |
| break |
| } |
| if f.isLazy && lazyDecode { |
| switch { |
| case lazyFields == nil || lazyFields[f] == lazyValidateOnly: |
| // Attempt to validate this field and leave it for later lazy unmarshaling. |
| o, valid := mi.skipField(b, f, wtyp, opts) |
| switch valid { |
| case ValidationValid: |
| // Skip over the valid field and continue. |
| err = nil |
| presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) |
| requiredMask |= f.validation.requiredBit |
| if !o.initialized { |
| initialized = false |
| } |
| n = o.n |
| break Field |
| case ValidationInvalid: |
| return out, errors.New("invalid proto wire format") |
| case ValidationWrongWireType: |
| break Field |
| case ValidationUnknown: |
| if lazyFields == nil { |
| lazyFields = make(map[*coderFieldInfo]lazyAction) |
| } |
| if presence.Present(f.presenceIndex) { |
| // We were unable to determine if the field is valid or not, |
| // and we've already skipped over at least one instance of this |
| // field. Clear the presence bit (so if we stop decoding early, |
| // we don't leave a partially-initialized field around) and flag |
| // the field for unmarshaling before we return. |
| presence.ClearPresent(f.presenceIndex) |
| lazyFields[f] = lazyUnmarshalLater |
| discardUnknown = true |
| break Field |
| } else { |
| // We were unable to determine if the field is valid or not, |
| // but this is the first time we've seen it. Flag it as needing |
| // eager unmarshaling and fall through to the eager unmarshal case below. |
| lazyFields[f] = lazyUnmarshalNow |
| } |
| } |
| case lazyFields[f] == lazyUnmarshalLater: |
| // This field will be unmarshaled in a separate pass below. |
| // Skip over it here. |
| discardUnknown = true |
| break Field |
| default: |
| // Eagerly unmarshal the field. |
| } |
| } |
| if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) { |
| if p.Apply(f.offset).AtomicGetPointer().IsNil() { |
| mi.lazyUnmarshal(p, f.num) |
| } |
| } |
| var o unmarshalOutput |
| o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts) |
| n = o.n |
| if err != nil { |
| break |
| } |
| requiredMask |= f.validation.requiredBit |
| if f.funcs.isInit != nil && !o.initialized { |
| initialized = false |
| } |
| if f.presenceIndex != noPresence { |
| presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) |
| } |
| default: |
| // Possible extension. |
| if exts == nil && mi.extensionOffset.IsValid() { |
| exts = p.Apply(mi.extensionOffset).Extensions() |
| if *exts == nil { |
| *exts = make(map[int32]ExtensionField) |
| } |
| } |
| if exts == nil { |
| break |
| } |
| var o unmarshalOutput |
| o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts) |
| if err != nil { |
| break |
| } |
| n = o.n |
| if !o.initialized { |
| initialized = false |
| } |
| } |
| if err != nil { |
| if err != errUnknown { |
| return out, err |
| } |
| n = protowire.ConsumeFieldValue(num, wtyp, b) |
| if n < 0 { |
| return out, errDecode |
| } |
| if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() { |
| u := mi.mutableUnknownBytes(p) |
| *u = protowire.AppendTag(*u, num, wtyp) |
| *u = append(*u, b[:n]...) |
| } |
| } |
| b = b[n:] |
| end := start - len(b) |
| if lazyDecode && f != nil && f.isLazy { |
| if num != lastNum { |
| lazyIndex = append(lazyIndex, protolazy.IndexEntry{ |
| FieldNum: uint32(num), |
| Start: uint32(pos), |
| End: uint32(end), |
| }) |
| } else { |
| i := len(lazyIndex) - 1 |
| lazyIndex[i].End = uint32(end) |
| lazyIndex[i].MultipleContiguous = true |
| } |
| } |
| if num < lastNum { |
| outOfOrder = true |
| } |
| pos = end |
| lastNum = num |
| } |
| if groupTag != 0 { |
| return out, errors.New("missing end group marker") |
| } |
| if lazyFields != nil { |
| // Some fields failed validation, and now need to be unmarshaled. |
| for f, action := range lazyFields { |
| if action != lazyUnmarshalLater { |
| continue |
| } |
| initialized = false |
| if *lazy == nil { |
| *lazy = &protolazy.XXX_lazyUnmarshalInfo{} |
| } |
| if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil { |
| return out, err |
| } |
| presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize) |
| } |
| } |
| if lazyDecode { |
| if outOfOrder { |
| sort.Slice(lazyIndex, func(i, j int) bool { |
| return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum || |
| (lazyIndex[i].FieldNum == lazyIndex[j].FieldNum && |
| lazyIndex[i].Start < lazyIndex[j].Start) |
| }) |
| } |
| if *lazy == nil { |
| *lazy = &protolazy.XXX_lazyUnmarshalInfo{} |
| } |
| |
| (*lazy).SetIndex(lazyIndex) |
| } |
| if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) { |
| initialized = false |
| } |
| if initialized { |
| out.initialized = true |
| } |
| out.n = start - len(b) |
| return out, nil |
| } |