proto: fix signature of UnmarshalState and MarshalState

The pseudo-internal MarshalState and UnmarshalState method should
not have a seperate Message argument since it is passed in through
the extensible MarshalInput and UnmarshalInput values.

Change-Id: I838aadaee30e91cdf888ab024e65348c73c1cd7e
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/222678
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 6204e65..f9b6836 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -278,8 +278,9 @@
 	if n < 0 {
 		return out, wire.ParseError(n)
 	}
-	o, err := opts.Options().UnmarshalState(m, piface.UnmarshalInput{
-		Buf: v,
+	o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
+		Buf:     v,
+		Message: m.ProtoReflect(),
 	})
 	if err != nil {
 		return out, err
@@ -421,8 +422,9 @@
 	if n < 0 {
 		return out, wire.ParseError(n)
 	}
-	o, err := opts.Options().UnmarshalState(m, piface.UnmarshalInput{
-		Buf: b,
+	o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
+		Buf:     b,
+		Message: m.ProtoReflect(),
 	})
 	if err != nil {
 		return out, err
@@ -551,8 +553,9 @@
 		return out, wire.ParseError(n)
 	}
 	mp := reflect.New(goType.Elem())
-	o, err := opts.Options().UnmarshalState(asMessage(mp), piface.UnmarshalInput{
-		Buf: v,
+	o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
+		Buf:     v,
+		Message: asMessage(mp).ProtoReflect(),
 	})
 	if err != nil {
 		return out, err
@@ -613,8 +616,9 @@
 		return pref.Value{}, out, wire.ParseError(n)
 	}
 	m := list.NewElement()
-	o, err := opts.Options().UnmarshalState(m.Message().Interface(), piface.UnmarshalInput{
-		Buf: v,
+	o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
+		Buf:     v,
+		Message: m.Message(),
 	})
 	if err != nil {
 		return pref.Value{}, out, err
@@ -680,8 +684,9 @@
 		return pref.Value{}, out, wire.ParseError(n)
 	}
 	m := list.NewElement()
-	o, err := opts.Options().UnmarshalState(m.Message().Interface(), piface.UnmarshalInput{
-		Buf: b,
+	o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
+		Buf:     b,
+		Message: m.Message(),
 	})
 	if err != nil {
 		return pref.Value{}, out, err
@@ -765,8 +770,9 @@
 		return out, wire.ParseError(n)
 	}
 	mp := reflect.New(goType.Elem())
-	o, err := opts.Options().UnmarshalState(asMessage(mp), piface.UnmarshalInput{
-		Buf: b,
+	o, err := opts.Options().UnmarshalState(piface.UnmarshalInput{
+		Buf:     b,
+		Message: asMessage(mp).ProtoReflect(),
 	})
 	if err != nil {
 		return out, err
diff --git a/proto/decode.go b/proto/decode.go
index 30fd529..536491b 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -45,13 +45,13 @@
 
 // Unmarshal parses the wire-format message in b and places the result in m.
 func Unmarshal(b []byte, m Message) error {
-	_, err := UnmarshalOptions{}.unmarshal(b, m)
+	_, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
 	return err
 }
 
 // Unmarshal parses the wire-format message in b and places the result in m.
 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
-	_, err := o.unmarshal(b, m)
+	_, err := o.unmarshal(b, m.ProtoReflect())
 	return err
 }
 
@@ -59,21 +59,20 @@
 //
 // This method permits fine-grained control over the unmarshaler.
 // Most users should use Unmarshal instead.
-func (o UnmarshalOptions) UnmarshalState(m Message, in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
-	return o.unmarshal(in.Buf, m)
+func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
+	return o.unmarshal(in.Buf, in.Message)
 }
 
-func (o UnmarshalOptions) unmarshal(b []byte, message Message) (out protoiface.UnmarshalOutput, err error) {
+func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
 	if o.Resolver == nil {
 		o.Resolver = protoregistry.GlobalTypes
 	}
 	if !o.Merge {
-		Reset(message)
+		Reset(m.Interface()) // TODO
 	}
 	allowPartial := o.AllowPartial
 	o.Merge = true
 	o.AllowPartial = true
-	m := message.ProtoReflect()
 	methods := protoMethods(m)
 	if methods != nil && methods.Unmarshal != nil &&
 		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
@@ -99,7 +98,7 @@
 }
 
 func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
-	_, err := o.unmarshal(b, m.Interface())
+	_, err := o.unmarshal(b, m)
 	return err
 }
 
diff --git a/proto/encode.go b/proto/encode.go
index 2a8e895..999aa7b 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -74,35 +74,34 @@
 
 // Marshal returns the wire-format encoding of m.
 func Marshal(m Message) ([]byte, error) {
-	out, err := MarshalOptions{}.marshal(nil, m)
+	out, err := MarshalOptions{}.marshal(nil, m.ProtoReflect())
 	return out.Buf, err
 }
 
 // Marshal returns the wire-format encoding of m.
 func (o MarshalOptions) Marshal(m Message) ([]byte, error) {
-	out, err := o.marshal(nil, m)
+	out, err := o.marshal(nil, m.ProtoReflect())
 	return out.Buf, err
 }
 
 // MarshalAppend appends the wire-format encoding of m to b,
 // returning the result.
 func (o MarshalOptions) MarshalAppend(b []byte, m Message) ([]byte, error) {
-	out, err := o.marshal(b, m)
+	out, err := o.marshal(b, m.ProtoReflect())
 	return out.Buf, err
 }
 
-// MarshalState returns the wire-format encoding of m.
+// MarshalState returns the wire-format encoding of a message.
 //
 // This method permits fine-grained control over the marshaler.
 // Most users should use Marshal instead.
-func (o MarshalOptions) MarshalState(m Message, in protoiface.MarshalInput) (protoiface.MarshalOutput, error) {
-	return o.marshal(in.Buf, m)
+func (o MarshalOptions) MarshalState(in protoiface.MarshalInput) (protoiface.MarshalOutput, error) {
+	return o.marshal(in.Buf, in.Message)
 }
 
-func (o MarshalOptions) marshal(b []byte, message Message) (out protoiface.MarshalOutput, err error) {
+func (o MarshalOptions) marshal(b []byte, m protoreflect.Message) (out protoiface.MarshalOutput, err error) {
 	allowPartial := o.AllowPartial
 	o.AllowPartial = true
-	m := message.ProtoReflect()
 	if methods := protoMethods(m); methods != nil && methods.Marshal != nil &&
 		!(o.Deterministic && methods.Flags&protoiface.SupportMarshalDeterministic == 0) {
 		in := protoiface.MarshalInput{
@@ -140,7 +139,7 @@
 }
 
 func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte, error) {
-	out, err := o.marshal(b, m.Interface())
+	out, err := o.marshal(b, m)
 	return out.Buf, err
 }
 
diff --git a/proto/methods_test.go b/proto/methods_test.go
index 6809f04..b1dcce3 100644
--- a/proto/methods_test.go
+++ b/proto/methods_test.go
@@ -142,8 +142,9 @@
 				opts := proto.UnmarshalOptions{
 					AllowPartial: true,
 				}
-				out, err := opts.UnmarshalState(m.Interface(), protoiface.UnmarshalInput{
-					Buf: test.wire,
+				out, err := opts.UnmarshalState(protoiface.UnmarshalInput{
+					Buf:     test.wire,
+					Message: m,
 				})
 				if err != nil {
 					t.Fatalf("Unmarshal error: %v", err)