proto: ensure MarshalOptions are plumbed to all Size calls

This used to not be necessary, but a subsequent change will change behavior
based on MarshalOptions.Deterministic, so it is important that we do not
accidentally drop MarshalOptions anywhere.

related to golang/protobuf#1609

Change-Id: I6b53352707d93939642a627eb41c930f8ac3157f
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/579995
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Lasse Folger <lassefolger@google.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index e8040ef..78ee47e 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -268,16 +268,17 @@
 	return f.mi.checkInitializedPointer(p.Elem())
 }
 
-func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
-	return protowire.SizeBytes(proto.Size(m)) + tagsize
+func sizeMessage(m proto.Message, tagsize int, opts marshalOptions) int {
+	return protowire.SizeBytes(opts.Options().Size(m)) + tagsize
 }
 
 func appendMessage(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
-	calculatedSize := proto.Size(m)
+	mopts := opts.Options()
+	calculatedSize := mopts.Size(m)
 	b = protowire.AppendVarint(b, wiretag)
 	b = protowire.AppendVarint(b, uint64(calculatedSize))
 	before := len(b)
-	b, err := opts.Options().MarshalAppend(b, m)
+	b, err := mopts.MarshalAppend(b, m)
 	if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
 		return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
 	}
@@ -417,8 +418,8 @@
 	return f.mi.unmarshalPointer(b, p.Elem(), f.num, opts)
 }
 
-func sizeGroup(m proto.Message, tagsize int, _ marshalOptions) int {
-	return 2*tagsize + proto.Size(m)
+func sizeGroup(m proto.Message, tagsize int, opts marshalOptions) int {
+	return 2*tagsize + opts.Options().Size(m)
 }
 
 func appendGroup(b []byte, m proto.Message, wiretag uint64, opts marshalOptions) ([]byte, error) {
@@ -536,26 +537,28 @@
 	return nil
 }
 
-func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOptions) int {
+func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, opts marshalOptions) int {
+	mopts := opts.Options()
 	s := p.PointerSlice()
 	n := 0
 	for _, v := range s {
 		m := asMessage(v.AsValueOf(goType.Elem()))
-		n += protowire.SizeBytes(proto.Size(m)) + tagsize
+		n += protowire.SizeBytes(mopts.Size(m)) + tagsize
 	}
 	return n
 }
 
 func appendMessageSlice(b []byte, p pointer, wiretag uint64, goType reflect.Type, opts marshalOptions) ([]byte, error) {
+	mopts := opts.Options()
 	s := p.PointerSlice()
 	var err error
 	for _, v := range s {
 		m := asMessage(v.AsValueOf(goType.Elem()))
 		b = protowire.AppendVarint(b, wiretag)
-		siz := proto.Size(m)
+		siz := mopts.Size(m)
 		b = protowire.AppendVarint(b, uint64(siz))
 		before := len(b)
-		b, err = opts.Options().MarshalAppend(b, m)
+		b, err = mopts.MarshalAppend(b, m)
 		if err != nil {
 			return b, err
 		}
@@ -602,11 +605,12 @@
 // Slices of messages
 
 func sizeMessageSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
+	mopts := opts.Options()
 	list := listv.List()
 	n := 0
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		m := list.Get(i).Message().Interface()
-		n += protowire.SizeBytes(proto.Size(m)) + tagsize
+		n += protowire.SizeBytes(mopts.Size(m)) + tagsize
 	}
 	return n
 }
@@ -617,7 +621,7 @@
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		m := list.Get(i).Message().Interface()
 		b = protowire.AppendVarint(b, wiretag)
-		siz := proto.Size(m)
+		siz := mopts.Size(m)
 		b = protowire.AppendVarint(b, uint64(siz))
 		before := len(b)
 		var err error
@@ -675,11 +679,12 @@
 }
 
 func sizeGroupSliceValue(listv protoreflect.Value, tagsize int, opts marshalOptions) int {
+	mopts := opts.Options()
 	list := listv.List()
 	n := 0
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		m := list.Get(i).Message().Interface()
-		n += 2*tagsize + proto.Size(m)
+		n += 2*tagsize + mopts.Size(m)
 	}
 	return n
 }
@@ -762,12 +767,13 @@
 	}
 }
 
-func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, _ marshalOptions) int {
+func sizeGroupSlice(p pointer, messageType reflect.Type, tagsize int, opts marshalOptions) int {
+	mopts := opts.Options()
 	s := p.PointerSlice()
 	n := 0
 	for _, v := range s {
 		m := asMessage(v.AsValueOf(messageType.Elem()))
-		n += 2*tagsize + proto.Size(m)
+		n += 2*tagsize + mopts.Size(m)
 	}
 	return n
 }
diff --git a/proto/encode.go b/proto/encode.go
index c779f6d..1f847bc 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -75,6 +75,27 @@
 	UseCachedSize bool
 }
 
+// flags turns the specified MarshalOptions (user-facing) into
+// protoiface.MarshalInputFlags (used internally by the marshaler).
+//
+// See impl.marshalOptions.Options for the inverse operation.
+func (o MarshalOptions) flags() protoiface.MarshalInputFlags {
+	var flags protoiface.MarshalInputFlags
+
+	// Note: o.AllowPartial is always forced to true by MarshalOptions.marshal,
+	// which is why it is not a part of MarshalInputFlags.
+
+	if o.Deterministic {
+		flags |= protoiface.MarshalDeterministic
+	}
+
+	if o.UseCachedSize {
+		flags |= protoiface.MarshalUseCachedSize
+	}
+
+	return flags
+}
+
 // Marshal returns the wire-format encoding of m.
 //
 // This is the most common entry point for encoding a Protobuf message.
@@ -157,12 +178,7 @@
 		in := protoiface.MarshalInput{
 			Message: m,
 			Buf:     b,
-		}
-		if o.Deterministic {
-			in.Flags |= protoiface.MarshalDeterministic
-		}
-		if o.UseCachedSize {
-			in.Flags |= protoiface.MarshalUseCachedSize
+			Flags:   o.flags(),
 		}
 		if methods.Size != nil {
 			sout := methods.Size(protoiface.SizeInput{
diff --git a/proto/size.go b/proto/size.go
index f1692b4..052fb5a 100644
--- a/proto/size.go
+++ b/proto/size.go
@@ -34,6 +34,7 @@
 	if methods != nil && methods.Size != nil {
 		out := methods.Size(protoiface.SizeInput{
 			Message: m,
+			Flags:   o.flags(),
 		})
 		return out.Size
 	}
@@ -42,6 +43,7 @@
 		// This case is mainly used for legacy types with a Marshal method.
 		out, _ := methods.Marshal(protoiface.MarshalInput{
 			Message: m,
+			Flags:   o.flags(),
 		})
 		return len(out.Buf)
 	}