proto: ensure MarshalOptions are plumbed to all Size calls This used to not be necessary, but a subsequent change will change behavior based on MarshalOptions.Deterministic, so it is important that we do not accidentally drop MarshalOptions anywhere. related to golang/protobuf#1609 Change-Id: I6b53352707d93939642a627eb41c930f8ac3157f Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/579995 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Lasse Folger <lassefolger@google.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go index e8040ef..78ee47e 100644 --- a/internal/impl/codec_field.go +++ b/internal/impl/codec_field.go
@@ -268,16 +268,17 @@ return f.mi.checkInitializedPointer(p.Elem()) } -func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int { - return protowire.SizeBytes(proto.Size(m)) + tagsize +func sizeMessage(m proto.Message, tagsize int, opts marshalOptions) int { + return protowire.SizeBytes(opts.Options().Size(m)) + tagsize } func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) { - calculatedSize := proto.Size(m) + mopts := opts.Options() + calculatedSize := mopts.Size(m) b = protowire.AppendVarint(b, wiretag) b = protowire.AppendVarint(b, uint64(calculatedSize)) before := len(b) - b, err := opts.Options().MarshalAppend(b, m) + b, err := mopts.MarshalAppend(b, m) if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil { return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize) } @@ -417,8 +418,8 @@ return f.mi.unmarshalPointer(b, p.Elem(), f.num, opts) } -func sizeGroup(m proto.Message, tagsize int, _ marshalOptions) int { - return 2*tagsize + proto.Size(m) +func sizeGroup(m proto.Message, tagsize int, opts marshalOptions) int { + return 2*tagsize + opts.Options().Size(m) } func appendGroup(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) { @@ -536,26 +537,28 @@ return nil } -func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOptions) int { +func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, opts marshalOptions) int { + mopts := opts.Options() s := p.PointerSlice() n := 0 for _, v := range s { m := asMessage(v.AsValueOf(goType.Elem())) - n += protowire.SizeBytes(proto.Size(m)) + tagsize + n += protowire.SizeBytes(mopts.Size(m)) + tagsize } return n } func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type, opts marshalOptions) ([]byte, error) { + mopts := opts.Options() s := p.PointerSlice() var err error for _, v := range s { m := asMessage(v.AsValueOf(goType.Elem())) b = protowire.AppendVarint(b, wiretag) - siz := proto.Size(m) + siz := mopts.Size(m) b = protowire.AppendVarint(b, uint64(siz)) before := len(b) - b, err = opts.Options().MarshalAppend(b, m) + b, err = mopts.MarshalAppend(b, m) if err != nil { return b, err } @@ -602,11 +605,12 @@ // Slices of messages func sizeMessageSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int { + mopts := opts.Options() list := listv.List() n := 0 for i, llen := 0, list.Len(); i < llen; i++ { m := list.Get(i).Message().Interface() - n += protowire.SizeBytes(proto.Size(m)) + tagsize + n += protowire.SizeBytes(mopts.Size(m)) + tagsize } return n } @@ -617,7 +621,7 @@ for i, llen := 0, list.Len(); i < llen; i++ { m := list.Get(i).Message().Interface() b = protowire.AppendVarint(b, wiretag) - siz := proto.Size(m) + siz := mopts.Size(m) b = protowire.AppendVarint(b, uint64(siz)) before := len(b) var err error @@ -675,11 +679,12 @@ } func sizeGroupSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int { + mopts := opts.Options() list := listv.List() n := 0 for i, llen := 0, list.Len(); i < llen; i++ { m := list.Get(i).Message().Interface() - n += 2*tagsize + proto.Size(m) + n += 2*tagsize + mopts.Size(m) } return n } @@ -762,12 +767,13 @@ } } -func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, _ marshalOptions) int { +func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, opts marshalOptions) int { + mopts := opts.Options() s := p.PointerSlice() n := 0 for _, v := range s { m := asMessage(v.AsValueOf(messageType.Elem())) - n += 2*tagsize + proto.Size(m) + n += 2*tagsize + mopts.Size(m) } return n }
diff --git a/proto/encode.go b/proto/encode.go index c779f6d..1f847bc 100644 --- a/proto/encode.go +++ b/proto/encode.go
@@ -75,6 +75,27 @@ UseCachedSize bool } +// flags turns the specified MarshalOptions (user-facing) into +// protoiface.MarshalInputFlags (used internally by the marshaler). +// +// See impl.marshalOptions.Options for the inverse operation. +func (o MarshalOptions) flags() protoiface.MarshalInputFlags { + var flags protoiface.MarshalInputFlags + + // Note: o.AllowPartial is always forced to true by MarshalOptions.marshal, + // which is why it is not a part of MarshalInputFlags. + + if o.Deterministic { + flags |= protoiface.MarshalDeterministic + } + + if o.UseCachedSize { + flags |= protoiface.MarshalUseCachedSize + } + + return flags +} + // Marshal returns the wire-format encoding of m. // // This is the most common entry point for encoding a Protobuf message. @@ -157,12 +178,7 @@ in := protoiface.MarshalInput{ Message: m, Buf: b, - } - if o.Deterministic { - in.Flags |= protoiface.MarshalDeterministic - } - if o.UseCachedSize { - in.Flags |= protoiface.MarshalUseCachedSize + Flags: o.flags(), } if methods.Size != nil { sout := methods.Size(protoiface.SizeInput{
diff --git a/proto/size.go b/proto/size.go index f1692b4..052fb5a 100644 --- a/proto/size.go +++ b/proto/size.go
@@ -34,6 +34,7 @@ if methods != nil && methods.Size != nil { out := methods.Size(protoiface.SizeInput{ Message: m, + Flags: o.flags(), }) return out.Size } @@ -42,6 +43,7 @@ // This case is mainly used for legacy types with a Marshal method. out, _ := methods.Marshal(protoiface.MarshalInput{ Message: m, + Flags: o.flags(), }) return len(out.Buf) }