runtime/protoiface: use more efficient options representation

Change the representation of option flags in protoiface from bools to a
bitfield. This brings the representation of options in protoiface in
sync with that in internal/impl.

This change has several benefits:

1. We will probably find that we need to add more option flags over time.
Converting to the more efficient representation of these flags as high
in the call stack as possible minimizes the performance implication of
the struct growing.

2. On a similar note, this avoids the need to convert from the compact
representation to the larger one when passing from internal/impl to
proto, since the {Marshal,Unmarshal}State methods take the compact form.

3. This removes unused options from protoiface. Instead of documenting
that AllowPartial is always set, we can just not include an AllowPartial
flag in the protoiface options.

4. Conversely, this provides a way to add option flags to protoiface
that we don't want to expose in the proto package.

name                             old time/op    new time/op    delta
EmptyMessage/Wire/Marshal-12       11.1ns ± 7%    10.1ns ± 1%   -9.35%  (p=0.000 n=8+8)
EmptyMessage/Wire/Unmarshal-12     7.07ns ± 0%    6.74ns ± 1%   -4.58%  (p=0.000 n=8+8)
EmptyMessage/Wire/Validate-12      4.30ns ± 1%    3.80ns ± 8%  -11.45%  (p=0.000 n=7+8)
RepeatedInt32/Wire/Marshal-12      1.17µs ± 1%    1.21µs ± 7%   +4.09%  (p=0.000 n=8+8)
RepeatedInt32/Wire/Unmarshal-12     938ns ± 0%     942ns ± 3%     ~     (p=0.178 n=7+8)
RepeatedInt32/Wire/Validate-12      521ns ± 4%     543ns ± 7%     ~     (p=0.157 n=7+8)
Required/Wire/Marshal-12           97.2ns ± 1%    95.3ns ± 1%   -1.98%  (p=0.001 n=7+7)
Required/Wire/Unmarshal-12         41.0ns ± 9%    38.6ns ± 3%   -5.73%  (p=0.048 n=8+8)
Required/Wire/Validate-12          25.4ns ±11%    21.4ns ± 3%  -15.62%  (p=0.000 n=8+7)

Change-Id: I3ac1b00ab36cfdf61316ec087a5dd20d9248e4f6
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/216760
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 74fd821..3cd7f5a 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -16,48 +16,18 @@
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
-// unmarshalOptions is a more efficient representation of UnmarshalOptions.
-//
-// We don't preserve the AllowPartial flag, because fast-path (un)marshal
-// operations always allow partial messages.
-type unmarshalOptions struct {
-	flags unmarshalOptionFlags
-
-	// Keep this field's type identical to (proto.UnmarshalOptions).Resolver
-	// to avoid a type conversion on assignment.
-	resolver interface {
-		FindExtensionByName(field pref.FullName) (pref.ExtensionType, error)
-		FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error)
-	}
-}
-
-type unmarshalOptionFlags uint8
-
-const (
-	unmarshalDiscardUnknown unmarshalOptionFlags = 1 << iota
-)
-
-func newUnmarshalOptions(opts piface.UnmarshalOptions) unmarshalOptions {
-	o := unmarshalOptions{
-		resolver: opts.Resolver,
-	}
-	if opts.DiscardUnknown {
-		o.flags |= unmarshalDiscardUnknown
-	}
-	return o
-}
+type unmarshalOptions piface.UnmarshalOptions
 
 func (o unmarshalOptions) Options() proto.UnmarshalOptions {
 	return proto.UnmarshalOptions{
 		Merge:          true,
 		AllowPartial:   true,
 		DiscardUnknown: o.DiscardUnknown(),
-		Resolver:       o.Resolver(),
+		Resolver:       o.Resolver,
 	}
 }
 
-func (o unmarshalOptions) DiscardUnknown() bool                 { return o.flags&unmarshalDiscardUnknown != 0 }
-func (o unmarshalOptions) Resolver() preg.ExtensionTypeResolver { return o.resolver }
+func (o unmarshalOptions) DiscardUnknown() bool { return o.Flags&piface.UnmarshalDiscardUnknown != 0 }
 
 type unmarshalOutput struct {
 	n           int // number of bytes consumed
@@ -72,7 +42,7 @@
 	} else {
 		p = m.(*messageReflectWrapper).pointer()
 	}
-	out, err := mi.unmarshalPointer(in.Buf, p, 0, newUnmarshalOptions(opts))
+	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions(opts))
 	return piface.UnmarshalOutput{
 		Initialized: out.initialized,
 	}, err
@@ -202,7 +172,7 @@
 	xt := x.Type()
 	if xt == nil {
 		var err error
-		xt, err = opts.Resolver().FindExtensionByNumber(mi.Desc.FullName(), num)
+		xt, err = opts.Resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
 		if err != nil {
 			if err == preg.NotFound {
 				return out, errUnknown
diff --git a/internal/impl/encode.go b/internal/impl/encode.go
index 94f5b54..608e57f 100644
--- a/internal/impl/encode.go
+++ b/internal/impl/encode.go
@@ -14,27 +14,7 @@
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
-// marshalOptions is a more efficient representation of MarshalOptions.
-//
-// We don't preserve the AllowPartial flag, because fast-path (un)marshal
-// operations always allow partial messages.
-type marshalOptions uint
-
-const (
-	marshalDeterministic marshalOptions = 1 << iota
-	marshalUseCachedSize
-)
-
-func newMarshalOptions(opts piface.MarshalOptions) marshalOptions {
-	var o marshalOptions
-	if opts.Deterministic {
-		o |= marshalDeterministic
-	}
-	if opts.UseCachedSize {
-		o |= marshalUseCachedSize
-	}
-	return o
-}
+type marshalOptions piface.MarshalOptions
 
 func (o marshalOptions) Options() proto.MarshalOptions {
 	return proto.MarshalOptions{
@@ -44,8 +24,8 @@
 	}
 }
 
-func (o marshalOptions) Deterministic() bool { return o&marshalDeterministic != 0 }
-func (o marshalOptions) UseCachedSize() bool { return o&marshalUseCachedSize != 0 }
+func (o marshalOptions) Deterministic() bool { return o.Flags&piface.MarshalDeterministic != 0 }
+func (o marshalOptions) UseCachedSize() bool { return o.Flags&piface.MarshalUseCachedSize != 0 }
 
 // size is protoreflect.Methods.Size.
 func (mi *MessageInfo) size(m pref.Message, opts piface.MarshalOptions) (size int) {
@@ -55,7 +35,7 @@
 	} else {
 		p = m.(*messageReflectWrapper).pointer()
 	}
-	return mi.sizePointer(p, newMarshalOptions(opts))
+	return mi.sizePointer(p, marshalOptions(opts))
 }
 
 func (mi *MessageInfo) sizePointer(p pointer, opts marshalOptions) (size int) {
@@ -109,7 +89,7 @@
 	} else {
 		p = m.(*messageReflectWrapper).pointer()
 	}
-	b, err := mi.marshalAppendPointer(in.Buf, p, newMarshalOptions(opts))
+	b, err := mi.marshalAppendPointer(in.Buf, p, marshalOptions(opts))
 	return piface.MarshalOutput{Buf: b}, err
 }
 
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index 4ae1d03..9e33979 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -69,7 +69,7 @@
 	if !ok {
 		return ValidationUnknown
 	}
-	return mi.validate(b, 0, newUnmarshalOptions(opts))
+	return mi.validate(b, 0, unmarshalOptions(opts))
 }
 
 type validationInfo struct {
@@ -330,7 +330,7 @@
 				// In this case, a type added to the resolver in the future could cause
 				// unmarshaling to begin failing. Supporting this requires some way to
 				// determine if the resolver is frozen.
-				xt, err := opts.Resolver().FindExtensionByNumber(st.mi.Desc.FullName(), num)
+				xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
 				if err != nil && err != preg.NotFound {
 					return ValidationUnknown
 				}
diff --git a/proto/decode.go b/proto/decode.go
index 9a6b2f7..f3cd997 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -43,8 +43,6 @@
 	}
 }
 
-var _ = protoiface.UnmarshalOptions(UnmarshalOptions{})
-
 // Unmarshal parses the wire-format message in b and places the result in m.
 func Unmarshal(b []byte, m Message) error {
 	_, err := UnmarshalOptions{}.unmarshal(b, m)
@@ -79,9 +77,15 @@
 	methods := protoMethods(m)
 	if methods != nil && methods.Unmarshal != nil &&
 		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
+		opts := protoiface.UnmarshalOptions{
+			Resolver: o.Resolver,
+		}
+		if o.DiscardUnknown {
+			opts.Flags |= protoiface.UnmarshalDiscardUnknown
+		}
 		out, err = methods.Unmarshal(m, protoiface.UnmarshalInput{
 			Buf: b,
-		}, protoiface.UnmarshalOptions(o))
+		}, opts)
 	} else {
 		err = o.unmarshalMessageSlow(b, m)
 	}
diff --git a/proto/encode.go b/proto/encode.go
index 3afa331..65e951e 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -72,8 +72,6 @@
 	UseCachedSize bool
 }
 
-var _ = protoiface.MarshalOptions(MarshalOptions{})
-
 // Marshal returns the wire-format encoding of m.
 func Marshal(m Message) ([]byte, error) {
 	out, err := MarshalOptions{}.marshal(nil, m)
@@ -107,18 +105,25 @@
 	m := message.ProtoReflect()
 	if methods := protoMethods(m); methods != nil && methods.Marshal != nil &&
 		!(o.Deterministic && methods.Flags&protoiface.SupportMarshalDeterministic == 0) {
+		opts := protoiface.MarshalOptions{}
+		if o.Deterministic {
+			opts.Flags |= protoiface.MarshalDeterministic
+		}
+		if o.UseCachedSize {
+			opts.Flags |= protoiface.MarshalUseCachedSize
+		}
 		if methods.Size != nil {
-			sz := methods.Size(m, protoiface.MarshalOptions(o))
+			sz := methods.Size(m, opts)
 			if cap(b) < len(b)+sz {
 				x := make([]byte, len(b), growcap(cap(b), len(b)+sz))
 				copy(x, b)
 				b = x
 			}
-			o.UseCachedSize = true
+			opts.Flags |= protoiface.MarshalUseCachedSize
 		}
 		out, err = methods.Marshal(m, protoiface.MarshalInput{
 			Buf: b,
-		}, protoiface.MarshalOptions(o))
+		}, opts)
 	} else {
 		out.Buf, err = o.marshalMessageSlow(b, m)
 	}
diff --git a/reflect/protoreflect/methods.go b/reflect/protoreflect/methods.go
index fd4e07b..5a29d2c 100644
--- a/reflect/protoreflect/methods.go
+++ b/reflect/protoreflect/methods.go
@@ -34,9 +34,7 @@
 	}
 	marshalOptions = struct {
 		pragma.NoUnkeyedLiterals
-		AllowPartial  bool
-		Deterministic bool
-		UseCachedSize bool
+		Flags uint8
 	}
 	unmarshalInput = struct {
 		pragma.NoUnkeyedLiterals
@@ -48,10 +46,8 @@
 	}
 	unmarshalOptions = struct {
 		pragma.NoUnkeyedLiterals
-		Merge          bool
-		AllowPartial   bool
-		DiscardUnknown bool
-		Resolver       interface {
+		Flags    uint8
+		Resolver interface {
 			FindExtensionByName(field FullName) (ExtensionType, error)
 			FindExtensionByNumber(message FullName, field FieldNumber) (ExtensionType, error)
 		}
diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go
index d5a7677..54e7fb3 100644
--- a/runtime/protoiface/methods.go
+++ b/runtime/protoiface/methods.go
@@ -63,16 +63,21 @@
 }
 
 // MarshalOptions configure the marshaler.
-//
-// This type is identical to the one in package proto.
 type MarshalOptions = struct {
 	pragma.NoUnkeyedLiterals
 
-	AllowPartial  bool // may be treated as true by method implementations
-	Deterministic bool
-	UseCachedSize bool
+	Flags MarshalFlags
 }
 
+// MarshalFlags are configure the marshaler.
+// Most flags correspond to fields in proto.MarshalOptions.
+type MarshalFlags = uint8
+
+const (
+	MarshalDeterministic MarshalFlags = 1 << iota
+	MarshalUseCachedSize
+)
+
 // UnmarshalInput is input to the unmarshaler.
 type UnmarshalInput = struct {
 	pragma.NoUnkeyedLiterals
@@ -91,16 +96,20 @@
 }
 
 // UnmarshalOptions configures the unmarshaler.
-//
-// This type is identical to the one in package proto.
 type UnmarshalOptions = struct {
 	pragma.NoUnkeyedLiterals
 
-	Merge          bool // may be treated as true by method implementations
-	AllowPartial   bool // may be treated as true by method implementations
-	DiscardUnknown bool
-	Resolver       interface {
+	Flags    UnmarshalFlags
+	Resolver interface {
 		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
 		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
 	}
 }
+
+// UnmarshalFlags configure the unmarshaler.
+// Most flags correspond to fields in proto.UnmarshalOptions.
+type UnmarshalFlags = uint8
+
+const (
+	UnmarshalDiscardUnknown UnmarshalFlags = 1 << iota
+)