internal/impl: support typed nil source for Merge of aberrant messages
When merging aberrant messages with legacy Marshal and Unmarshal
methods, check for a typed nil source before calling Marshal.
Add an aberrant message with Marshal/Unmarshal methods to
internal/testprotos/nullable and use it to test the internal/impl
support for these methods.
Fixes golang/protobuf#1324
Change-Id: Ib6ce85b30b46e3392a226ca6abe411932a371f02
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/321529
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go
index 3759b01..029feee 100644
--- a/internal/impl/legacy_message.go
+++ b/internal/impl/legacy_message.go
@@ -440,6 +440,13 @@
if !ok {
return piface.MergeOutput{}
}
+ if !in.Source.IsValid() {
+ // Legacy Marshal methods may not function on nil messages.
+ // Check for a typed nil source only after we confirm that
+ // legacy Marshal/Unmarshal methods are present, for
+ // consistency.
+ return piface.MergeOutput{Flags: piface.MergeComplete}
+ }
b, err := marshaler.Marshal()
if err != nil {
return piface.MergeOutput{}
diff --git a/internal/testprotos/nullable/methods_test.go b/internal/testprotos/nullable/methods_test.go
index e272838..8e22ab2 100644
--- a/internal/testprotos/nullable/methods_test.go
+++ b/internal/testprotos/nullable/methods_test.go
@@ -2,45 +2,19 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Only test compatibility with the Marshal/Unmarshal functionality with
+// For messages which do not provide legacy Marshal and Unmarshal methods,
+// only test compatibility with the Marshal/Unmarshal functionality with
// pure protobuf reflection since there is no support for nullable fields
// in the table-driven implementation.
// +build protoreflect
package nullable
-import (
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "google.golang.org/protobuf/proto"
- "google.golang.org/protobuf/reflect/protoreflect"
- "google.golang.org/protobuf/testing/protocmp"
-)
+import "google.golang.org/protobuf/runtime/protoimpl"
func init() {
- testMethods = func(t *testing.T, mt protoreflect.MessageType) {
- m1 := mt.New()
- populated := testPopulateMessage(t, m1, 2)
- b, err := proto.Marshal(m1.Interface())
- if err != nil {
- t.Errorf("proto.Marshal error: %v", err)
- }
- if populated && len(b) == 0 {
- t.Errorf("len(proto.Marshal) = 0, want >0")
- }
- m2 := mt.New()
- if err := proto.Unmarshal(b, m2.Interface()); err != nil {
- t.Errorf("proto.Unmarshal error: %v", err)
- }
- if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
- t.Errorf("message mismatch:\n%v", diff)
- }
- proto.Reset(m2.Interface())
- testEmptyMessage(t, m2, true)
- proto.Merge(m2.Interface(), m1.Interface())
- if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
- t.Errorf("message mismatch:\n%v", diff)
- }
- }
+ methodTestProtos = append(methodTestProtos,
+ protoimpl.X.ProtoMessageV2Of((*Proto2)(nil)).ProtoReflect().Type(),
+ protoimpl.X.ProtoMessageV2Of((*Proto3)(nil)).ProtoReflect().Type(),
+ )
}
diff --git a/internal/testprotos/nullable/nullable.go b/internal/testprotos/nullable/nullable.go
index a291b45..3466455 100644
--- a/internal/testprotos/nullable/nullable.go
+++ b/internal/testprotos/nullable/nullable.go
@@ -6,6 +6,7 @@
import (
"google.golang.org/protobuf/encoding/prototext"
+ "google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/runtime/protoimpl"
"google.golang.org/protobuf/types/descriptorpb"
)
@@ -223,3 +224,43 @@
func (*Proto3_OneofBytes) isProto3_OneofUnion() {}
func (*Proto3_OneofEnum) isProto3_OneofUnion() {}
func (*Proto3_OneofMessage) isProto3_OneofUnion() {}
+
+type Methods struct {
+ OptionalInt32 int32 `protobuf:"varint,101,opt,name=optional_int32"`
+}
+
+func (x *Methods) ProtoMessage() {}
+func (x *Methods) Reset() { *x = Methods{} }
+func (x *Methods) String() string { return prototext.Format(protoimpl.X.ProtoMessageV2Of(x)) }
+
+func (x *Methods) Marshal() ([]byte, error) {
+ var b []byte
+ b = protowire.AppendTag(b, 101, protowire.VarintType)
+ b = protowire.AppendVarint(b, uint64(x.OptionalInt32))
+ return b, nil
+}
+
+func (x *Methods) Unmarshal(b []byte) error {
+ for len(b) > 0 {
+ num, typ, n := protowire.ConsumeTag(b)
+ if n < 0 {
+ return protowire.ParseError(n)
+ }
+ b = b[n:]
+ if num != 101 || typ != protowire.VarintType {
+ n = protowire.ConsumeFieldValue(num, typ, b)
+ if n < 0 {
+ return protowire.ParseError(n)
+ }
+ b = b[n:]
+ continue
+ }
+ v, n := protowire.ConsumeVarint(b)
+ if n < 0 {
+ return protowire.ParseError(n)
+ }
+ b = b[n:]
+ x.OptionalInt32 = int32(v)
+ }
+ return nil
+}
diff --git a/internal/testprotos/nullable/nullable_test.go b/internal/testprotos/nullable/nullable_test.go
index 6994e34..9c288b9 100644
--- a/internal/testprotos/nullable/nullable_test.go
+++ b/internal/testprotos/nullable/nullable_test.go
@@ -8,8 +8,11 @@
"reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
+ "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoimpl"
+ "google.golang.org/protobuf/testing/protocmp"
)
func Test(t *testing.T) {
@@ -20,12 +23,48 @@
t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
testEmptyMessage(t, mt.Zero(), false)
testEmptyMessage(t, mt.New(), true)
+ //testMethods(t, mt)
+ })
+ }
+}
+
+var methodTestProtos = []protoreflect.MessageType{
+ protoimpl.X.ProtoMessageV2Of((*Methods)(nil)).ProtoReflect().Type(),
+}
+
+func TestMethods(t *testing.T) {
+ for _, mt := range methodTestProtos {
+ t.Run(string(mt.Descriptor().FullName()), func(t *testing.T) {
testMethods(t, mt)
})
}
}
-var testMethods = func(*testing.T, protoreflect.MessageType) {}
+func testMethods(t *testing.T, mt protoreflect.MessageType) {
+ m1 := mt.New()
+ populated := testPopulateMessage(t, m1, 2)
+ b, err := proto.Marshal(m1.Interface())
+ if err != nil {
+ t.Errorf("proto.Marshal error: %v", err)
+ }
+ if populated && len(b) == 0 {
+ t.Errorf("len(proto.Marshal) = 0, want >0")
+ }
+ m2 := mt.New()
+ if err := proto.Unmarshal(b, m2.Interface()); err != nil {
+ t.Errorf("proto.Unmarshal error: %v", err)
+ }
+ if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
+ t.Errorf("message mismatch:\n%v", diff)
+ }
+ proto.Reset(m2.Interface())
+ testEmptyMessage(t, m2, true)
+ proto.Merge(m2.Interface(), m1.Interface())
+ if diff := cmp.Diff(m1.Interface(), m2.Interface(), protocmp.Transform()); diff != "" {
+ t.Errorf("message mismatch:\n%v", diff)
+ }
+ proto.Merge(mt.New().Interface(), mt.Zero().Interface())
+}
func testEmptyMessage(t *testing.T, m protoreflect.Message, wantValid bool) {
numFields := func(m protoreflect.Message) (n int) {