proto: ensure MarshalOptions are plumbed to all Size calls
This used to not be necessary, but a subsequent change will change behavior
based on MarshalOptions.Deterministic, so it is important that we do not
accidentally drop MarshalOptions anywhere.
related to golang/protobuf#1609
Change-Id: I6b53352707d93939642a627eb41c930f8ac3157f
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/579995
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Lasse Folger <lassefolger@google.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index e8040ef..78ee47e 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -268,16 +268,17 @@
return f.mi.checkInitializedPointer(p.Elem())
}
-func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
- return protowire.SizeBytes(proto.Size(m)) + tagsize
+func sizeMessage(m proto.Message, tagsize int, opts marshalOptions) int {
+ return protowire.SizeBytes(opts.Options().Size(m)) + tagsize
}
func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
- calculatedSize := proto.Size(m)
+ mopts := opts.Options()
+ calculatedSize := mopts.Size(m)
b = protowire.AppendVarint(b, wiretag)
b = protowire.AppendVarint(b, uint64(calculatedSize))
before := len(b)
- b, err := opts.Options().MarshalAppend(b, m)
+ b, err := mopts.MarshalAppend(b, m)
if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
}
@@ -417,8 +418,8 @@
return f.mi.unmarshalPointer(b, p.Elem(), f.num, opts)
}
-func sizeGroup(m proto.Message, tagsize int, _ marshalOptions) int {
- return 2*tagsize + proto.Size(m)
+func sizeGroup(m proto.Message, tagsize int, opts marshalOptions) int {
+ return 2*tagsize + opts.Options().Size(m)
}
func appendGroup(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
@@ -536,26 +537,28 @@
return nil
}
-func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOptions) int {
+func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, opts marshalOptions) int {
+ mopts := opts.Options()
s := p.PointerSlice()
n := 0
for _, v := range s {
m := asMessage(v.AsValueOf(goType.Elem()))
- n += protowire.SizeBytes(proto.Size(m)) + tagsize
+ n += protowire.SizeBytes(mopts.Size(m)) + tagsize
}
return n
}
func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type, opts marshalOptions) ([]byte, error) {
+ mopts := opts.Options()
s := p.PointerSlice()
var err error
for _, v := range s {
m := asMessage(v.AsValueOf(goType.Elem()))
b = protowire.AppendVarint(b, wiretag)
- siz := proto.Size(m)
+ siz := mopts.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
- b, err = opts.Options().MarshalAppend(b, m)
+ b, err = mopts.MarshalAppend(b, m)
if err != nil {
return b, err
}
@@ -602,11 +605,12 @@
// Slices of messages
func sizeMessageSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
+ mopts := opts.Options()
list := listv.List()
n := 0
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
- n += protowire.SizeBytes(proto.Size(m)) + tagsize
+ n += protowire.SizeBytes(mopts.Size(m)) + tagsize
}
return n
}
@@ -617,7 +621,7 @@
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
b = protowire.AppendVarint(b, wiretag)
- siz := proto.Size(m)
+ siz := mopts.Size(m)
b = protowire.AppendVarint(b, uint64(siz))
before := len(b)
var err error
@@ -675,11 +679,12 @@
}
func sizeGroupSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
+ mopts := opts.Options()
list := listv.List()
n := 0
for i, llen := 0, list.Len(); i < llen; i++ {
m := list.Get(i).Message().Interface()
- n += 2*tagsize + proto.Size(m)
+ n += 2*tagsize + mopts.Size(m)
}
return n
}
@@ -762,12 +767,13 @@
}
}
-func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, _ marshalOptions) int {
+func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, opts marshalOptions) int {
+ mopts := opts.Options()
s := p.PointerSlice()
n := 0
for _, v := range s {
m := asMessage(v.AsValueOf(messageType.Elem()))
- n += 2*tagsize + proto.Size(m)
+ n += 2*tagsize + mopts.Size(m)
}
return n
}
diff --git a/proto/encode.go b/proto/encode.go
index c779f6d..1f847bc 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -75,6 +75,27 @@
UseCachedSize bool
}
+// flags turns the specified MarshalOptions (user-facing) into
+// protoiface.MarshalInputFlags (used internally by the marshaler).
+//
+// See impl.marshalOptions.Options for the inverse operation.
+func (o MarshalOptions) flags() protoiface.MarshalInputFlags {
+ var flags protoiface.MarshalInputFlags
+
+ // Note: o.AllowPartial is always forced to true by MarshalOptions.marshal,
+ // which is why it is not a part of MarshalInputFlags.
+
+ if o.Deterministic {
+ flags |= protoiface.MarshalDeterministic
+ }
+
+ if o.UseCachedSize {
+ flags |= protoiface.MarshalUseCachedSize
+ }
+
+ return flags
+}
+
// Marshal returns the wire-format encoding of m.
//
// This is the most common entry point for encoding a Protobuf message.
@@ -157,12 +178,7 @@
in := protoiface.MarshalInput{
Message: m,
Buf: b,
- }
- if o.Deterministic {
- in.Flags |= protoiface.MarshalDeterministic
- }
- if o.UseCachedSize {
- in.Flags |= protoiface.MarshalUseCachedSize
+ Flags: o.flags(),
}
if methods.Size != nil {
sout := methods.Size(protoiface.SizeInput{
diff --git a/proto/size.go b/proto/size.go
index f1692b4..052fb5a 100644
--- a/proto/size.go
+++ b/proto/size.go
@@ -34,6 +34,7 @@
if methods != nil && methods.Size != nil {
out := methods.Size(protoiface.SizeInput{
Message: m,
+ Flags: o.flags(),
})
return out.Size
}
@@ -42,6 +43,7 @@
// This case is mainly used for legacy types with a Marshal method.
out, _ := methods.Marshal(protoiface.MarshalInput{
Message: m,
+ Flags: o.flags(),
})
return len(out.Buf)
}