all: fast-path method refactoring

Move all fast-path inputs and outputs into the Input/Output structs.
Collapse all booleans into bitfields.

Change-Id: I79ebfbac9cd1d8ef5ec17c4f955311db007391ca
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/219505
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/benchmarks/micro/micro_test.go b/internal/benchmarks/micro/micro_test.go
index e4ff349..dd96744 100644
--- a/internal/benchmarks/micro/micro_test.go
+++ b/internal/benchmarks/micro/micro_test.go
@@ -13,7 +13,6 @@
 
 	"google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/proto"
-	"google.golang.org/protobuf/reflect/protoregistry"
 	"google.golang.org/protobuf/runtime/protoiface"
 	"google.golang.org/protobuf/types/known/emptypb"
 
@@ -51,11 +50,8 @@
 	b.Run("Wire/Validate", func(b *testing.B) {
 		b.RunParallel(func(pb *testing.PB) {
 			mt := (&emptypb.Empty{}).ProtoReflect().Type()
-			opts := protoiface.UnmarshalOptions{
-				Resolver: protoregistry.GlobalTypes,
-			}
 			for pb.Next() {
-				_, got := impl.Validate([]byte{}, mt, opts)
+				_, got := impl.Validate(mt, protoiface.UnmarshalInput{})
 				want := impl.ValidationValid
 				if got != want {
 					b.Fatalf("Validate = %v, want %v", got, want)
@@ -112,11 +108,10 @@
 	b.Run("Wire/Validate", func(b *testing.B) {
 		b.RunParallel(func(pb *testing.PB) {
 			mt := (&testpb.TestAllTypes{}).ProtoReflect().Type()
-			opts := protoiface.UnmarshalOptions{
-				Resolver: protoregistry.GlobalTypes,
-			}
 			for pb.Next() {
-				_, got := impl.Validate(w, mt, opts)
+				_, got := impl.Validate(mt, protoiface.UnmarshalInput{
+					Buf: w,
+				})
 				want := impl.ValidationValid
 				if got != want {
 					b.Fatalf("Validate = %v, want %v", got, want)
@@ -182,11 +177,10 @@
 	b.Run("Wire/Validate", func(b *testing.B) {
 		b.RunParallel(func(pb *testing.PB) {
 			mt := (&micropb.SixteenRequired{}).ProtoReflect().Type()
-			opts := protoiface.UnmarshalOptions{
-				Resolver: protoregistry.GlobalTypes,
-			}
 			for pb.Next() {
-				_, got := impl.Validate(w, mt, opts)
+				_, got := impl.Validate(mt, protoiface.UnmarshalInput{
+					Buf: w,
+				})
 				want := impl.ValidationValid
 				if got != want {
 					b.Fatalf("Validate = %v, want %v", got, want)
diff --git a/internal/fuzz/wirefuzz/fuzz.go b/internal/fuzz/wirefuzz/fuzz.go
index 28aed57..01463f8 100644
--- a/internal/fuzz/wirefuzz/fuzz.go
+++ b/internal/fuzz/wirefuzz/fuzz.go
@@ -10,7 +10,6 @@
 
 	"google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/proto"
-	"google.golang.org/protobuf/reflect/protoregistry"
 	piface "google.golang.org/protobuf/runtime/protoiface"
 
 	fuzzpb "google.golang.org/protobuf/internal/testprotos/fuzz"
@@ -19,9 +18,10 @@
 // Fuzz is a fuzzer for proto.Marshal and proto.Unmarshal.
 func Fuzz(data []byte) (score int) {
 	m1 := &fuzzpb.Fuzz{}
-	vout, valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
-		Resolver: protoregistry.GlobalTypes,
+	vout, valid := impl.Validate(m1.ProtoReflect().Type(), piface.UnmarshalInput{
+		Buf: data,
 	})
+	vinit := vout.Flags&piface.UnmarshalInitialized != 0
 	if err := (proto.UnmarshalOptions{
 		AllowPartial: true,
 	}).Unmarshal(data, m1); err != nil {
@@ -39,7 +39,7 @@
 	default:
 		panic("unmarshal ok with validation status: " + valid.String())
 	}
-	if proto.IsInitialized(m1) != nil && vout.Initialized {
+	if proto.IsInitialized(m1) != nil && vinit {
 		panic("validation reports partial message is initialized")
 	}
 	data1, err := proto.MarshalOptions{
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 98992e7..bc998d3 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -285,7 +285,7 @@
 		return out, err
 	}
 	out.n = n
-	out.initialized = o.Initialized
+	out.initialized = o.Flags&piface.UnmarshalInitialized != 0
 	return out, nil
 }
 
@@ -428,7 +428,7 @@
 		return out, err
 	}
 	out.n = n
-	out.initialized = o.Initialized
+	out.initialized = o.Flags&piface.UnmarshalInitialized != 0
 	return out, nil
 }
 
@@ -559,7 +559,7 @@
 	}
 	p.AppendPointerSlice(pointerOfValue(mp))
 	out.n = n
-	out.initialized = o.Initialized
+	out.initialized = o.Flags&piface.UnmarshalInitialized != 0
 	return out, nil
 }
 
@@ -621,7 +621,7 @@
 	}
 	list.Append(m)
 	out.n = n
-	out.initialized = o.Initialized
+	out.initialized = o.Flags&piface.UnmarshalInitialized != 0
 	return listv, out, nil
 }
 
@@ -688,7 +688,7 @@
 	}
 	list.Append(m)
 	out.n = n
-	out.initialized = o.Initialized
+	out.initialized = o.Flags&piface.UnmarshalInitialized != 0
 	return listv, out, nil
 }
 
@@ -773,7 +773,7 @@
 	}
 	p.AppendPointerSlice(pointerOfValue(mp))
 	out.n = n
-	out.initialized = o.Initialized
+	out.initialized = o.Flags&piface.UnmarshalInitialized != 0
 	return out, nil
 }
 
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index f3dcdf4..5a19f23 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -11,30 +11,37 @@
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/proto"
-	pref "google.golang.org/protobuf/reflect/protoreflect"
+	"google.golang.org/protobuf/reflect/protoreflect"
 	preg "google.golang.org/protobuf/reflect/protoregistry"
+	"google.golang.org/protobuf/runtime/protoiface"
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
-type unmarshalOptions piface.UnmarshalOptions
+type unmarshalOptions struct {
+	flags    protoiface.UnmarshalInputFlags
+	resolver interface {
+		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
+		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
+	}
+}
 
 func (o unmarshalOptions) Options() proto.UnmarshalOptions {
 	return proto.UnmarshalOptions{
 		Merge:          true,
 		AllowPartial:   true,
 		DiscardUnknown: o.DiscardUnknown(),
-		Resolver:       o.Resolver,
+		Resolver:       o.resolver,
 	}
 }
 
-func (o unmarshalOptions) DiscardUnknown() bool { return o.Flags&piface.UnmarshalDiscardUnknown != 0 }
+func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 }
 
 func (o unmarshalOptions) IsDefault() bool {
-	return o.Flags == 0 && o.Resolver == preg.GlobalTypes
+	return o.flags == 0 && o.resolver == preg.GlobalTypes
 }
 
 var lazyUnmarshalOptions = unmarshalOptions{
-	Resolver: preg.GlobalTypes,
+	resolver: preg.GlobalTypes,
 }
 
 type unmarshalOutput struct {
@@ -43,16 +50,23 @@
 }
 
 // unmarshal is protoreflect.Methods.Unmarshal.
-func (mi *MessageInfo) unmarshal(m pref.Message, in piface.UnmarshalInput, opts piface.UnmarshalOptions) (piface.UnmarshalOutput, error) {
+func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
 	var p pointer
-	if ms, ok := m.(*messageState); ok {
+	if ms, ok := in.Message.(*messageState); ok {
 		p = ms.pointer()
 	} else {
-		p = m.(*messageReflectWrapper).pointer()
+		p = in.Message.(*messageReflectWrapper).pointer()
 	}
-	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions(opts))
+	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
+		flags:    in.Flags,
+		resolver: in.Resolver,
+	})
+	var flags piface.UnmarshalOutputFlags
+	if out.initialized {
+		flags |= piface.UnmarshalInitialized
+	}
 	return piface.UnmarshalOutput{
-		Initialized: out.initialized,
+		Flags: flags,
 	}, err
 }
 
@@ -184,7 +198,7 @@
 	xt := x.Type()
 	if xt == nil {
 		var err error
-		xt, err = opts.Resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
+		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
 		if err != nil {
 			if err == preg.NotFound {
 				return out, errUnknown
diff --git a/internal/impl/encode.go b/internal/impl/encode.go
index 8923774..ec08ed6 100644
--- a/internal/impl/encode.go
+++ b/internal/impl/encode.go
@@ -10,11 +10,12 @@
 
 	"google.golang.org/protobuf/internal/flags"
 	proto "google.golang.org/protobuf/proto"
-	pref "google.golang.org/protobuf/reflect/protoreflect"
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
-type marshalOptions piface.MarshalOptions
+type marshalOptions struct {
+	flags piface.MarshalInputFlags
+}
 
 func (o marshalOptions) Options() proto.MarshalOptions {
 	return proto.MarshalOptions{
@@ -24,18 +25,21 @@
 	}
 }
 
-func (o marshalOptions) Deterministic() bool { return o.Flags&piface.MarshalDeterministic != 0 }
-func (o marshalOptions) UseCachedSize() bool { return o.Flags&piface.MarshalUseCachedSize != 0 }
+func (o marshalOptions) Deterministic() bool { return o.flags&piface.MarshalDeterministic != 0 }
+func (o marshalOptions) UseCachedSize() bool { return o.flags&piface.MarshalUseCachedSize != 0 }
 
 // size is protoreflect.Methods.Size.
-func (mi *MessageInfo) size(m pref.Message, opts piface.MarshalOptions) (size int) {
+func (mi *MessageInfo) size(in piface.SizeInput) piface.SizeOutput {
 	var p pointer
-	if ms, ok := m.(*messageState); ok {
+	if ms, ok := in.Message.(*messageState); ok {
 		p = ms.pointer()
 	} else {
-		p = m.(*messageReflectWrapper).pointer()
+		p = in.Message.(*messageReflectWrapper).pointer()
 	}
-	return mi.sizePointer(p, marshalOptions(opts))
+	size := mi.sizePointer(p, marshalOptions{
+		flags: in.Flags,
+	})
+	return piface.SizeOutput{Size: size}
 }
 
 func (mi *MessageInfo) sizePointer(p pointer, opts marshalOptions) (size int) {
@@ -82,14 +86,16 @@
 }
 
 // marshal is protoreflect.Methods.Marshal.
-func (mi *MessageInfo) marshal(m pref.Message, in piface.MarshalInput, opts piface.MarshalOptions) (piface.MarshalOutput, error) {
+func (mi *MessageInfo) marshal(in piface.MarshalInput) (out piface.MarshalOutput, err error) {
 	var p pointer
-	if ms, ok := m.(*messageState); ok {
+	if ms, ok := in.Message.(*messageState); ok {
 		p = ms.pointer()
 	} else {
-		p = m.(*messageReflectWrapper).pointer()
+		p = in.Message.(*messageReflectWrapper).pointer()
 	}
-	b, err := mi.marshalAppendPointer(in.Buf, p, marshalOptions(opts))
+	b, err := mi.marshalAppendPointer(in.Buf, p, marshalOptions{
+		flags: in.Flags,
+	})
 	return piface.MarshalOutput{Buf: b}, err
 }
 
diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go
index 63d1fa5..4bd978f 100644
--- a/internal/impl/isinit.go
+++ b/internal/impl/isinit.go
@@ -9,16 +9,17 @@
 
 	"google.golang.org/protobuf/internal/errors"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
+	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
-func (mi *MessageInfo) isInitialized(m pref.Message) error {
+func (mi *MessageInfo) isInitialized(in piface.IsInitializedInput) (piface.IsInitializedOutput, error) {
 	var p pointer
-	if ms, ok := m.(*messageState); ok {
+	if ms, ok := in.Message.(*messageState); ok {
 		p = ms.pointer()
 	} else {
-		p = m.(*messageReflectWrapper).pointer()
+		p = in.Message.(*messageReflectWrapper).pointer()
 	}
-	return mi.isInitializedPointer(p)
+	return piface.IsInitializedOutput{}, mi.isInitializedPointer(p)
 }
 
 func (mi *MessageInfo) isInitializedPointer(p pointer) error {
diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go
index ebfe18a..43b8be2 100644
--- a/internal/impl/legacy_message.go
+++ b/internal/impl/legacy_message.go
@@ -383,8 +383,8 @@
 	Flags: piface.SupportMarshalDeterministic,
 }
 
-func legacyMarshal(m protoreflect.Message, in piface.MarshalInput, opts piface.MarshalOptions) (piface.MarshalOutput, error) {
-	v := m.(unwrapper).protoUnwrap()
+func legacyMarshal(in piface.MarshalInput) (piface.MarshalOutput, error) {
+	v := in.Message.(unwrapper).protoUnwrap()
 	marshaler, ok := v.(legacyMarshaler)
 	if !ok {
 		return piface.MarshalOutput{}, errors.New("%T does not implement Marshal", v)
@@ -398,8 +398,8 @@
 	}, err
 }
 
-func legacyUnmarshal(m protoreflect.Message, in piface.UnmarshalInput, opts piface.UnmarshalOptions) (piface.UnmarshalOutput, error) {
-	v := m.(unwrapper).protoUnwrap()
+func legacyUnmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
+	v := in.Message.(unwrapper).protoUnwrap()
 	unmarshaler, ok := v.(legacyUnmarshaler)
 	if !ok {
 		return piface.UnmarshalOutput{}, errors.New("%T does not implement Marshal", v)
@@ -407,14 +407,14 @@
 	return piface.UnmarshalOutput{}, unmarshaler.Unmarshal(in.Buf)
 }
 
-func legacyMerge(dst, src pref.Message, in piface.MergeInput, opts piface.MergeOptions) piface.MergeOutput {
-	dstv := dst.(unwrapper).protoUnwrap()
+func legacyMerge(in piface.MergeInput) piface.MergeOutput {
+	dstv := in.Destination.(unwrapper).protoUnwrap()
 	merger, ok := dstv.(legacyMerger)
 	if !ok {
 		return piface.MergeOutput{}
 	}
-	merger.Merge(Export{}.ProtoMessageV1Of(src))
-	return piface.MergeOutput{Merged: true}
+	merger.Merge(Export{}.ProtoMessageV1Of(in.Source))
+	return piface.MergeOutput{Flags: piface.MergeComplete}
 }
 
 // aberrantMessageType implements MessageType for all types other than pointer-to-struct.
diff --git a/internal/impl/merge.go b/internal/impl/merge.go
index 20d9dfd..cdc4267 100644
--- a/internal/impl/merge.go
+++ b/internal/impl/merge.go
@@ -13,24 +13,24 @@
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
-type mergeOptions piface.MergeOptions
+type mergeOptions struct{}
 
 func (o mergeOptions) Merge(dst, src proto.Message) {
 	proto.Merge(dst, src)
 }
 
 // merge is protoreflect.Methods.Merge.
-func (mi *MessageInfo) merge(dst, src pref.Message, in piface.MergeInput, opts piface.MergeOptions) piface.MergeOutput {
-	dp, ok := mi.getPointer(dst)
+func (mi *MessageInfo) merge(in piface.MergeInput) piface.MergeOutput {
+	dp, ok := mi.getPointer(in.Destination)
 	if !ok {
-		return piface.MergeOutput{Merged: false}
+		return piface.MergeOutput{}
 	}
-	sp, ok := mi.getPointer(src)
+	sp, ok := mi.getPointer(in.Source)
 	if !ok {
-		return piface.MergeOutput{Merged: false}
+		return piface.MergeOutput{}
 	}
-	mi.mergePointer(dp, sp, opts)
-	return piface.MergeOutput{Merged: true}
+	mi.mergePointer(dp, sp, mergeOptions{})
+	return piface.MergeOutput{Flags: piface.MergeComplete}
 }
 
 func (mi *MessageInfo) mergePointer(dst, src pointer, opts mergeOptions) {
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index bb00cd0..449331b 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -55,13 +55,21 @@
 // of the message type.
 //
 // This function is exposed for testing.
-func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) (out piface.UnmarshalOutput, _ ValidationStatus) {
+func Validate(mt pref.MessageType, in piface.UnmarshalInput) (out piface.UnmarshalOutput, _ ValidationStatus) {
 	mi, ok := mt.(*MessageInfo)
 	if !ok {
 		return out, ValidationUnknown
 	}
-	o, st := mi.validate(b, 0, unmarshalOptions(opts))
-	out.Initialized = o.initialized
+	if in.Resolver == nil {
+		in.Resolver = preg.GlobalTypes
+	}
+	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
+		flags:    in.Flags,
+		resolver: in.Resolver,
+	})
+	if o.initialized {
+		out.Flags |= piface.UnmarshalInitialized
+	}
 	return out, st
 }
 
@@ -325,7 +333,7 @@
 				// In this case, a type added to the resolver in the future could cause
 				// unmarshaling to begin failing. Supporting this requires some way to
 				// determine if the resolver is frozen.
-				xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
+				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
 				if err != nil && err != preg.NotFound {
 					return out, ValidationUnknown
 				}
@@ -502,7 +510,7 @@
 					if err != nil {
 						return out, ValidationInvalid
 					}
-					xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
+					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
 					switch {
 					case err == preg.NotFound:
 						b = b[n:]
diff --git a/proto/decode.go b/proto/decode.go
index b712786..717c979 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -77,22 +77,22 @@
 	methods := protoMethods(m)
 	if methods != nil && methods.Unmarshal != nil &&
 		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
-		opts := protoiface.UnmarshalOptions{
+		in := protoiface.UnmarshalInput{
+			Message:  m,
+			Buf:      b,
 			Resolver: o.Resolver,
 		}
 		if o.DiscardUnknown {
-			opts.Flags |= protoiface.UnmarshalDiscardUnknown
+			in.Flags |= protoiface.UnmarshalDiscardUnknown
 		}
-		out, err = methods.Unmarshal(m, protoiface.UnmarshalInput{
-			Buf: b,
-		}, opts)
+		out, err = methods.Unmarshal(in)
 	} else {
 		err = o.unmarshalMessageSlow(b, m)
 	}
 	if err != nil {
 		return out, err
 	}
-	if allowPartial || out.Initialized {
+	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
 		return out, nil
 	}
 	return out, isInitialized(m)
diff --git a/proto/encode.go b/proto/encode.go
index 3625091..950556f 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -105,25 +105,28 @@
 	m := message.ProtoReflect()
 	if methods := protoMethods(m); methods != nil && methods.Marshal != nil &&
 		!(o.Deterministic && methods.Flags&protoiface.SupportMarshalDeterministic == 0) {
-		opts := protoiface.MarshalOptions{}
+		in := protoiface.MarshalInput{
+			Message: m,
+			Buf:     b,
+		}
 		if o.Deterministic {
-			opts.Flags |= protoiface.MarshalDeterministic
+			in.Flags |= protoiface.MarshalDeterministic
 		}
 		if o.UseCachedSize {
-			opts.Flags |= protoiface.MarshalUseCachedSize
+			in.Flags |= protoiface.MarshalUseCachedSize
 		}
 		if methods.Size != nil {
-			sz := methods.Size(m, opts)
-			if cap(b) < len(b)+sz {
-				x := make([]byte, len(b), growcap(cap(b), len(b)+sz))
-				copy(x, b)
-				b = x
+			sout := methods.Size(protoiface.SizeInput{
+				Message: m,
+				Flags:   in.Flags,
+			})
+			if cap(b) < len(b)+sout.Size {
+				in.Buf = make([]byte, len(b), growcap(cap(b), len(b)+sout.Size))
+				copy(in.Buf, b)
 			}
-			opts.Flags |= protoiface.MarshalUseCachedSize
+			in.Flags |= protoiface.MarshalUseCachedSize
 		}
-		out, err = methods.Marshal(m, protoiface.MarshalInput{
-			Buf: b,
-		}, opts)
+		out, err = methods.Marshal(in)
 	} else {
 		out.Buf, err = o.marshalMessageSlow(b, m)
 	}
diff --git a/proto/isinit.go b/proto/isinit.go
index df7cb2c..98494f1 100644
--- a/proto/isinit.go
+++ b/proto/isinit.go
@@ -7,6 +7,7 @@
 import (
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/reflect/protoreflect"
+	"google.golang.org/protobuf/runtime/protoiface"
 )
 
 // IsInitialized returns an error if any required fields in m are not set.
@@ -17,7 +18,10 @@
 // IsInitialized returns an error if any required fields in m are not set.
 func isInitialized(m protoreflect.Message) error {
 	if methods := protoMethods(m); methods != nil && methods.IsInitialized != nil {
-		return methods.IsInitialized(m)
+		_, err := methods.IsInitialized(protoiface.IsInitializedInput{
+			Message: m,
+		})
+		return err
 	}
 	return isInitializedSlow(m)
 }
diff --git a/proto/merge.go b/proto/merge.go
index 701e4c0..05cdafc 100644
--- a/proto/merge.go
+++ b/proto/merge.go
@@ -55,10 +55,12 @@
 func (o mergeOptions) mergeMessage(dst, src protoreflect.Message) {
 	methods := protoMethods(dst)
 	if methods != nil && methods.Merge != nil {
-		var in protoiface.MergeInput
-		var opts protoiface.MergeOptions
-		out := methods.Merge(dst, src, in, opts)
-		if out.Merged {
+		in := protoiface.MergeInput{
+			Destination: dst,
+			Source:      src,
+		}
+		out := methods.Merge(in)
+		if out.Flags&protoiface.MergeComplete != 0 {
 			return
 		}
 	}
diff --git a/proto/methods_test.go b/proto/methods_test.go
index dcee6de..e231e29 100644
--- a/proto/methods_test.go
+++ b/proto/methods_test.go
@@ -148,7 +148,7 @@
 				if err != nil {
 					t.Fatalf("Unmarshal error: %v", err)
 				}
-				if got, want := out.Initialized, !test.partial; got != want {
+				if got, want := (out.Flags&protoiface.UnmarshalInitialized != 0), !test.partial; got != want {
 					t.Errorf("out.Initialized = %v, want %v", got, want)
 				}
 			})
diff --git a/proto/size.go b/proto/size.go
index 5f26693..32620a9 100644
--- a/proto/size.go
+++ b/proto/size.go
@@ -24,12 +24,17 @@
 func sizeMessage(m protoreflect.Message) (size int) {
 	methods := protoMethods(m)
 	if methods != nil && methods.Size != nil {
-		return methods.Size(m, protoiface.MarshalOptions{})
+		out := methods.Size(protoiface.SizeInput{
+			Message: m,
+		})
+		return out.Size
 	}
 	if methods != nil && methods.Marshal != nil {
 		// This is not efficient, but we don't have any choice.
 		// This case is mainly used for legacy types with a Marshal method.
-		out, _ := methods.Marshal(m, protoiface.MarshalInput{}, protoiface.MarshalOptions{})
+		out, _ := methods.Marshal(protoiface.MarshalInput{
+			Message: m,
+		})
 		return len(out.Buf)
 	}
 	return sizeMessageSlow(m)
diff --git a/proto/validate_test.go b/proto/validate_test.go
index 490115a..30e1fe5 100644
--- a/proto/validate_test.go
+++ b/proto/validate_test.go
@@ -9,7 +9,6 @@
 	"testing"
 
 	"google.golang.org/protobuf/internal/impl"
-	"google.golang.org/protobuf/reflect/protoregistry"
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
 
@@ -27,13 +26,13 @@
 				if test.validationStatus != 0 {
 					want = test.validationStatus
 				}
-				var opts piface.UnmarshalOptions
-				opts.Resolver = protoregistry.GlobalTypes
-				out, status := impl.Validate(test.wire, mt, opts)
+				out, status := impl.Validate(mt, piface.UnmarshalInput{
+					Buf: test.wire,
+				})
 				if status != want {
 					t.Errorf("Validate(%x) = %v, want %v", test.wire, status, want)
 				}
-				if got, want := out.Initialized, !test.partial; got != want && !test.nocheckValidInit && status == impl.ValidationValid {
+				if got, want := (out.Flags&piface.UnmarshalInitialized != 0), !test.partial; got != want && !test.nocheckValidInit && status == impl.ValidationValid {
 					t.Errorf("Validate(%x): initialized = %v, want %v", test.wire, got, want)
 				}
 			})
@@ -46,9 +45,9 @@
 		for _, m := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
 				mt := m.ProtoReflect().Type()
-				var opts piface.UnmarshalOptions
-				opts.Resolver = protoregistry.GlobalTypes
-				_, got := impl.Validate(test.wire, mt, opts)
+				_, got := impl.Validate(mt, piface.UnmarshalInput{
+					Buf: test.wire,
+				})
 				want := impl.ValidationInvalid
 				if got != want {
 					t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
diff --git a/reflect/protoreflect/methods.go b/reflect/protoreflect/methods.go
index 9778b88..d897ace 100644
--- a/reflect/protoreflect/methods.go
+++ b/reflect/protoreflect/methods.go
@@ -18,49 +18,61 @@
 	methods = struct {
 		pragma.NoUnkeyedLiterals
 		Flags         supportFlags
-		Size          func(Message, marshalOptions) int
-		Marshal       func(Message, marshalInput, marshalOptions) (marshalOutput, error)
-		Unmarshal     func(Message, unmarshalInput, unmarshalOptions) (unmarshalOutput, error)
-		IsInitialized func(Message) error
-		Merge         func(Message, Message, mergeInput, mergeOptions) mergeOutput
+		Size          func(sizeInput) sizeOutput
+		Marshal       func(marshalInput) (marshalOutput, error)
+		Unmarshal     func(unmarshalInput) (unmarshalOutput, error)
+		IsInitialized func(isInitializedInput) (isInitializedOutput, error)
+		Merge         func(mergeInput) mergeOutput
 	}
 	supportFlags = uint64
+	sizeInput    = struct {
+		pragma.NoUnkeyedLiterals
+		Message Message
+		Flags   uint8
+	}
+	sizeOutput = struct {
+		pragma.NoUnkeyedLiterals
+		Size int
+	}
 	marshalInput = struct {
 		pragma.NoUnkeyedLiterals
-		Buf []byte
+		Message Message
+		Buf     []byte
+		Flags   uint8
 	}
 	marshalOutput = struct {
 		pragma.NoUnkeyedLiterals
 		Buf []byte
 	}
-	marshalOptions = struct {
-		pragma.NoUnkeyedLiterals
-		Flags uint8
-	}
 	unmarshalInput = struct {
 		pragma.NoUnkeyedLiterals
-		Buf []byte
-	}
-	unmarshalOutput = struct {
-		pragma.NoUnkeyedLiterals
-		Initialized bool
-	}
-	unmarshalOptions = struct {
-		pragma.NoUnkeyedLiterals
+		Message  Message
+		Buf      []byte
 		Flags    uint8
 		Resolver interface {
 			FindExtensionByName(field FullName) (ExtensionType, error)
 			FindExtensionByNumber(message FullName, field FieldNumber) (ExtensionType, error)
 		}
 	}
+	unmarshalOutput = struct {
+		pragma.NoUnkeyedLiterals
+		Flags uint8
+	}
+	isInitializedInput = struct {
+		pragma.NoUnkeyedLiterals
+		Message Message
+	}
+	isInitializedOutput = struct {
+		pragma.NoUnkeyedLiterals
+		Flags uint8
+	}
 	mergeInput = struct {
 		pragma.NoUnkeyedLiterals
+		Source      Message
+		Destination Message
 	}
 	mergeOutput = struct {
 		pragma.NoUnkeyedLiterals
-		Merged bool
-	}
-	mergeOptions = struct {
-		pragma.NoUnkeyedLiterals
+		Flags uint8
 	}
 )
diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go
index df2f252..f8efd86 100644
--- a/runtime/protoiface/methods.go
+++ b/runtime/protoiface/methods.go
@@ -22,25 +22,26 @@
 	Flags SupportFlags
 
 	// Size returns the size in bytes of the wire-format encoding of m.
-	// MarshalAppend must be provided if a custom Size is provided.
-	Size func(m protoreflect.Message, opts MarshalOptions) int
+	// Marshal must be provided if a custom Size is provided.
+	Size func(SizeInput) SizeOutput
 
 	// Marshal writes the wire-format encoding of m to the provided buffer.
-	// Size should be provided if a custom MarshalAppend is provided.
+	// Size should be provided if a custom Marshal is provided.
 	// It must not return an error for a partial message.
-	Marshal func(m protoreflect.Message, in MarshalInput, opts MarshalOptions) (MarshalOutput, error)
+	Marshal func(MarshalInput) (MarshalOutput, error)
 
 	// Unmarshal parses the wire-format encoding of a message and merges the result to m.
 	// It must not reset the target message or return an error for a partial message.
-	Unmarshal func(m protoreflect.Message, in UnmarshalInput, opts UnmarshalOptions) (UnmarshalOutput, error)
+	Unmarshal func(UnmarshalInput) (UnmarshalOutput, error)
 
 	// IsInitialized returns an error if any required fields in m are not set.
-	IsInitialized func(m protoreflect.Message) error
+	IsInitialized func(IsInitializedInput) (IsInitializedOutput, error)
 
 	// Merge merges src into dst.
-	Merge func(dst, src protoreflect.Message, in MergeInput, opts MergeOptions) MergeOutput
+	Merge func(MergeInput) MergeOutput
 }
 
+// SupportFlags indicate support for optional features.
 type SupportFlags = uint64
 
 const (
@@ -51,87 +52,126 @@
 	SupportUnmarshalDiscardUnknown
 )
 
-// MarshalInput is input to the marshaler.
+// SizeInput is input to the Size method.
+type SizeInput = struct {
+	pragma.NoUnkeyedLiterals
+
+	Message protoreflect.Message
+	Flags   MarshalInputFlags
+}
+
+// SizeOutput is output from the Size method.
+type SizeOutput = struct {
+	pragma.NoUnkeyedLiterals
+
+	Size int
+}
+
+// MarshalInput is input to the Marshal method.
 type MarshalInput = struct {
 	pragma.NoUnkeyedLiterals
 
-	Buf []byte // output is appended to this buffer
+	Message protoreflect.Message
+	Buf     []byte // output is appended to this buffer
+	Flags   MarshalInputFlags
 }
 
-// MarshalOutput is output from the marshaler.
+// MarshalOutput is output from the Marshal method.
 type MarshalOutput = struct {
 	pragma.NoUnkeyedLiterals
 
 	Buf []byte // contains marshaled message
 }
 
-// MarshalOptions configure the marshaler.
-type MarshalOptions = struct {
-	pragma.NoUnkeyedLiterals
-
-	Flags MarshalFlags
-}
-
-// MarshalFlags are configure the marshaler.
+// MarshalInputFlags configure the marshaler.
 // Most flags correspond to fields in proto.MarshalOptions.
-type MarshalFlags = uint8
+type MarshalInputFlags = uint8
 
 const (
-	MarshalDeterministic MarshalFlags = 1 << iota
+	MarshalDeterministic MarshalInputFlags = 1 << iota
 	MarshalUseCachedSize
 )
 
-// UnmarshalInput is input to the unmarshaler.
+// UnmarshalInput is input to the Unmarshal method.
 type UnmarshalInput = struct {
 	pragma.NoUnkeyedLiterals
 
-	Buf []byte // input buffer
-}
-
-// UnmarshalOutput is output from the unmarshaler.
-type UnmarshalOutput = struct {
-	pragma.NoUnkeyedLiterals
-
-	// Initialized may be set on return if all required fields are known to be set.
-	// A value of false does not indicate that the message is uninitialized, only
-	// that its status could not be confirmed.
-	Initialized bool
-}
-
-// UnmarshalOptions configures the unmarshaler.
-type UnmarshalOptions = struct {
-	pragma.NoUnkeyedLiterals
-
-	Flags    UnmarshalFlags
+	Message  protoreflect.Message
+	Buf      []byte // input buffer
+	Flags    UnmarshalInputFlags
 	Resolver interface {
 		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
 		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
 	}
 }
 
-// UnmarshalFlags configure the unmarshaler.
-// Most flags correspond to fields in proto.UnmarshalOptions.
-type UnmarshalFlags = uint8
-
-const (
-	UnmarshalDiscardUnknown UnmarshalFlags = 1 << iota
-)
-
-// MergeInput is input to the merger.
-type MergeInput = struct {
+// UnmarshalOutput is output from the Unmarshal method.
+type UnmarshalOutput = struct {
 	pragma.NoUnkeyedLiterals
+
+	Flags UnmarshalOutputFlags
 }
 
-// MergeOutput is output from the merger.
+// UnmarshalInputFlags configure the unmarshaler.
+// Most flags correspond to fields in proto.UnmarshalOptions.
+type UnmarshalInputFlags = uint8
+
+const (
+	UnmarshalDiscardUnknown UnmarshalInputFlags = 1 << iota
+)
+
+// UnmarshalOutputFlags are output from the Unmarshal method.
+type UnmarshalOutputFlags = uint8
+
+const (
+	// UnmarshalInitialized may be set on return if all required fields are known to be set.
+	// A value of false does not indicate that the message is uninitialized, only
+	// that its status could not be confirmed.
+	UnmarshalInitialized UnmarshalOutputFlags = 1 << iota
+)
+
+// IsInitializedInput is input to the IsInitialized method.
+type IsInitializedInput = struct {
+	pragma.NoUnkeyedLiterals
+
+	Message protoreflect.Message
+}
+
+// IsInitializedOutput is output from the IsInitialized method.
+type IsInitializedOutput = struct {
+	pragma.NoUnkeyedLiterals
+
+	Flags IsInitializedOutputFlags
+}
+
+// IsInitializedOutputFlags are output from the IsInitialized method.
+type IsInitializedOutputFlags = uint8
+
+const (
+	// IsInitialized reports whether the message is initialized.
+	IsInitialized IsInitializedOutputFlags = 1 << iota
+)
+
+// MergeInput is input to the Merge method.
+type MergeInput = struct {
+	pragma.NoUnkeyedLiterals
+
+	Source      protoreflect.Message
+	Destination protoreflect.Message
+}
+
+// MergeOutput is output from the Merge method.
 type MergeOutput = struct {
 	pragma.NoUnkeyedLiterals
 
-	// Merged is true if the merge was performed, false otherwise.
-	// If false, the merger must have made no changes to the destination.
-	Merged bool
+	Flags MergeOutputFlags
 }
 
-// MergeOptions configure the merger.
-type MergeOptions = struct {
-	pragma.NoUnkeyedLiterals
-}
+// MergeOutputFlags are output from the Merge method.
+type MergeOutputFlags = uint8
+
+const (
+	// MergeComplete reports whether the merge was performed.
+	// If unset, the merger must have made no changes to the destination.
+	MergeComplete MergeOutputFlags = 1 << iota
+)