internal/impl: support typed nil source for Merge of aberrant messages

When merging aberrant messages with legacy Marshal and Unmarshal
methods, check for a typed nil source before calling Marshal.

Add an aberrant message with Marshal/Unmarshal methods to
internal/testprotos/nullable and use it to test the internal/impl
support for these methods.

Fixes golang/protobuf#1324

Change-Id: Ib6ce85b30b46e3392a226ca6abe411932a371f02
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/321529
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go
index 3759b01..029feee 100644
--- a/internal/impl/legacy_message.go
+++ b/internal/impl/legacy_message.go
@@ -440,6 +440,13 @@
 	if !ok {
 		return piface.MergeOutput{}
 	}
+	if !in.Source.IsValid() {
+		// Legacy Marshal methods may not function on nil messages.
+		// Check for a typed nil source only after we confirm that
+		// legacy Marshal/Unmarshal methods are present, for
+		// consistency.
+		return piface.MergeOutput{Flags: piface.MergeComplete}
+	}
 	b, err := marshaler.Marshal()
 	if err != nil {
 		return piface.MergeOutput{}
diff --git a/internal/testprotos/nullable/methods_test.go b/internal/testprotos/nullable/methods_test.go
index e272838..8e22ab2 100644
--- a/internal/testprotos/nullable/methods_test.go
+++ b/internal/testprotos/nullable/methods_test.go
@@ -2,45 +2,19 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// Only test compatibility with the Marshal/Unmarshal functionality with
+// For messages which do not provide legacy Marshal and Unmarshal methods,
+// only test compatibility with the Marshal/Unmarshal functionality with
 // pure protobuf reflection since there is no support for nullable fields
 // in the table-driven implementation.
 // +build protoreflect
 
 package nullable
 
-import (
-	"testing"
-
-	"github.com/google/go-cmp/cmp"
-	"google.golang.org/protobuf/proto"
-	"google.golang.org/protobuf/reflect/protoreflect"
-	"google.golang.org/protobuf/testing/protocmp"
-)
+import "google.golang.org/protobuf/runtime/protoimpl"
 
 func init() {
-	testMethods = func(t *testing.T, mt protoreflect.MessageType) {
-		m1 := mt.New()
-		populated := testPopulateMessage(t, m1, 2)
-		b, err := proto.Marshal(m1.Interface())
-		if err != nil {
-			t.Errorf("proto.Marshal error: %v", err)
-		}
-		if populated && len(b) == 0 {
-			t.Errorf("len(proto.Marshal) = 0, want >0")
-		}
-		m2 := mt.New()
-		if err := proto.Unmarshal(b, m2.Interface()); err != nil {
-			t.Errorf("proto.Unmarshal error: %v", err)
-		}
-		if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
-			t.Errorf("message mismatch:\n%v", diff)
-		}
-		proto.Reset(m2.Interface())
-		testEmptyMessage(t, m2, true)
-		proto.Merge(m2.Interface(), m1.Interface())
-		if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
-			t.Errorf("message mismatch:\n%v", diff)
-		}
-	}
+	methodTestProtos = append(methodTestProtos,
+		protoimpl.X.ProtoMessageV2Of((*Proto2)(nil)).ProtoReflect().Type(),
+		protoimpl.X.ProtoMessageV2Of((*Proto3)(nil)).ProtoReflect().Type(),
+	)
 }
diff --git a/internal/testprotos/nullable/nullable.go b/internal/testprotos/nullable/nullable.go
index a291b45..3466455 100644
--- a/internal/testprotos/nullable/nullable.go
+++ b/internal/testprotos/nullable/nullable.go
@@ -6,6 +6,7 @@
 
 import (
 	"google.golang.org/protobuf/encoding/prototext"
+	"google.golang.org/protobuf/encoding/protowire"
 	"google.golang.org/protobuf/runtime/protoimpl"
 	"google.golang.org/protobuf/types/descriptorpb"
 )
@@ -223,3 +224,43 @@
 func (*Proto3_OneofBytes) isProto3_OneofUnion()   {}
 func (*Proto3_OneofEnum) isProto3_OneofUnion()    {}
 func (*Proto3_OneofMessage) isProto3_OneofUnion() {}
+
+type Methods struct {
+	OptionalInt32 int32 `protobuf:"varint,101,opt,name=optional_int32"`
+}
+
+func (x *Methods) ProtoMessage()  {}
+func (x *Methods) Reset()         { *x = Methods{} }
+func (x *Methods) String() string { return prototext.Format(protoimpl.X.ProtoMessageV2Of(x)) }
+
+func (x *Methods) Marshal() ([]byte, error) {
+	var b []byte
+	b = protowire.AppendTag(b, 101, protowire.VarintType)
+	b = protowire.AppendVarint(b, uint64(x.OptionalInt32))
+	return b, nil
+}
+
+func (x *Methods) Unmarshal(b []byte) error {
+	for len(b) > 0 {
+		num, typ, n := protowire.ConsumeTag(b)
+		if n < 0 {
+			return protowire.ParseError(n)
+		}
+		b = b[n:]
+		if num != 101 || typ != protowire.VarintType {
+			n = protowire.ConsumeFieldValue(num, typ, b)
+			if n < 0 {
+				return protowire.ParseError(n)
+			}
+			b = b[n:]
+			continue
+		}
+		v, n := protowire.ConsumeVarint(b)
+		if n < 0 {
+			return protowire.ParseError(n)
+		}
+		b = b[n:]
+		x.OptionalInt32 = int32(v)
+	}
+	return nil
+}
diff --git a/internal/testprotos/nullable/nullable_test.go b/internal/testprotos/nullable/nullable_test.go
index 6994e34..9c288b9 100644
--- a/internal/testprotos/nullable/nullable_test.go
+++ b/internal/testprotos/nullable/nullable_test.go
@@ -8,8 +8,11 @@
 	"reflect"
 	"testing"
 
+	"github.com/google/go-cmp/cmp"
+	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/runtime/protoimpl"
+	"google.golang.org/protobuf/testing/protocmp"
 )
 
 func Test(t *testing.T) {
@@ -20,12 +23,48 @@
 		t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
 			testEmptyMessage(t, mt.Zero(), false)
 			testEmptyMessage(t, mt.New(), true)
+			//testMethods(t, mt)
+		})
+	}
+}
+
+var methodTestProtos = []protoreflect.MessageType{
+	protoimpl.X.ProtoMessageV2Of((*Methods)(nil)).ProtoReflect().Type(),
+}
+
+func TestMethods(t *testing.T) {
+	for _, mt := range methodTestProtos {
+		t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
 			testMethods(t, mt)
 		})
 	}
 }
 
-var testMethods = func(*testing.T, protoreflect.MessageType) {}
+func testMethods(t *testing.T, mt protoreflect.MessageType) {
+	m1 := mt.New()
+	populated := testPopulateMessage(t, m1, 2)
+	b, err := proto.Marshal(m1.Interface())
+	if err != nil {
+		t.Errorf("proto.Marshal error: %v", err)
+	}
+	if populated && len(b) == 0 {
+		t.Errorf("len(proto.Marshal) = 0, want >0")
+	}
+	m2 := mt.New()
+	if err := proto.Unmarshal(b, m2.Interface()); err != nil {
+		t.Errorf("proto.Unmarshal error: %v", err)
+	}
+	if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
+		t.Errorf("message mismatch:\n%v", diff)
+	}
+	proto.Reset(m2.Interface())
+	testEmptyMessage(t, m2, true)
+	proto.Merge(m2.Interface(), m1.Interface())
+	if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
+		t.Errorf("message mismatch:\n%v", diff)
+	}
+	proto.Merge(mt.New().Interface(), mt.Zero().Interface())
+}
 
 func testEmptyMessage(t *testing.T, m protoreflect.Message, wantValid bool) {
 	numFields := func(m protoreflect.Message) (n int) {