proto, internal/impl: implement support for weak fields
Change-Id: I0a3ff79542a3316295fd6c58e1447e597be97ab9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189923
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go
index faa56b0..db8a3df 100644
--- a/encoding/protojson/decode.go
+++ b/encoding/protojson/decode.go
@@ -201,6 +201,8 @@
fd = nil // reset since field name is actually the message name
}
}
+ }
+ if flags.ProtoLegacy {
if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
fd = nil // reset since the weak reference is not linked in
}
diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go
index 27d23d3..8951e6b 100644
--- a/encoding/prototext/decode.go
+++ b/encoding/prototext/decode.go
@@ -138,8 +138,10 @@
} else if xtErr != nil && xtErr != protoregistry.NotFound {
return errors.New("unable to resolve: %v", xtErr)
}
- if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
- fd = nil // reset since the weak reference is not linked in
+ if flags.ProtoLegacy {
+ if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
+ fd = nil // reset since the weak reference is not linked in
+ }
}
// Handle unknown fields.
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 5414635..45b4664 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -6,10 +6,13 @@
import (
"reflect"
+ "sync"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
+ preg "google.golang.org/protobuf/reflect/protoregistry"
+ piface "google.golang.org/protobuf/runtime/protoiface"
)
type errInvalidUTF8 struct{}
@@ -17,7 +20,7 @@
func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" }
func (errInvalidUTF8) InvalidUTF8() bool { return true }
-func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFuncs {
+func makeOneofFieldCoder(fd pref.FieldDescriptor, si structInfo) pointerCoderFuncs {
ot := si.oneofWrappersByNumber[fd.Number()]
funcs := fieldCoder(fd, ot.Field(0).Type)
fs := si.oneofsByName[fd.ContainingOneof().Name()]
@@ -78,6 +81,61 @@
return pcf
}
+func makeWeakMessageFieldCoder(fd pref.FieldDescriptor) pointerCoderFuncs {
+ var once sync.Once
+ var messageType pref.MessageType
+ lazyInit := func() {
+ once.Do(func() {
+ messageName := fd.Message().FullName()
+ messageType, _ = preg.GlobalTypes.FindMessageByName(messageName)
+ })
+ }
+
+ num := int32(fd.Number())
+ return pointerCoderFuncs{
+ size: func(p pointer, tagsize int, opts marshalOptions) int {
+ fs := p.WeakFields()
+ m, ok := (*fs)[num]
+ if !ok {
+ return 0
+ }
+ return sizeMessage(m.(proto.Message), tagsize, opts)
+ },
+ marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
+ fs := p.WeakFields()
+ m, ok := (*fs)[num]
+ if !ok {
+ return b, nil
+ }
+ return appendMessage(b, m.(proto.Message), wiretag, opts)
+ },
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ fs := p.WeakFields()
+ m, ok := (*fs)[num]
+ if !ok {
+ lazyInit()
+ if messageType == nil {
+ return 0, errUnknown
+ }
+ m = messageType.New().Interface().(piface.MessageV1)
+ if *fs == nil {
+ *fs = make(WeakFields)
+ }
+ (*fs)[num] = m
+ }
+ return consumeMessage(b, m.(proto.Message), wtyp, opts)
+ },
+ isInit: func(p pointer) error {
+ fs := p.WeakFields()
+ m, ok := (*fs)[num]
+ if !ok {
+ return nil
+ }
+ return proto.IsInitialized(m.(proto.Message))
+ },
+ }
+}
+
func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
if mi := getMessageInfo(ft); mi != nil {
return pointerCoderFuncs{
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index 92f38e9..d7235a2 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -65,15 +65,22 @@
} else {
wiretag = wire.EncodeTag(fd.Number(), wire.BytesType)
}
+ var fieldOffset offset
var funcs pointerCoderFuncs
- if fd.ContainingOneof() != nil {
- funcs = makeOneofFieldCoder(si, fd)
- } else {
+ switch {
+ case fd.ContainingOneof() != nil:
+ fieldOffset = offsetOf(fs, mi.Exporter)
+ funcs = makeOneofFieldCoder(fd, si)
+ case fd.IsWeak():
+ fieldOffset = si.weakOffset
+ funcs = makeWeakMessageFieldCoder(fd)
+ default:
+ fieldOffset = offsetOf(fs, mi.Exporter)
funcs = fieldCoder(fd, ft)
}
cf := &coderFieldInfo{
num: fd.Number(),
- offset: offsetOf(fs, mi.Exporter),
+ offset: fieldOffset,
wiretag: wiretag,
tagsize: wire.SizeVarint(wiretag),
funcs: funcs,
diff --git a/internal/testprotos/test/weak1/test_weak.pb.go b/internal/testprotos/test/weak1/test_weak.pb.go
index 1b80b8b..c62ba49 100644
--- a/internal/testprotos/test/weak1/test_weak.pb.go
+++ b/internal/testprotos/test/weak1/test_weak.pb.go
@@ -19,7 +19,7 @@
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
- A *int32 `protobuf:"varint,1,opt,name=a" json:"a,omitempty"`
+ A *int32 `protobuf:"varint,1,req,name=a" json:"a,omitempty"`
}
func (x *WeakImportMessage1) Reset() {
@@ -64,7 +64,7 @@
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x74, 0x65, 0x73, 0x74,
0x2e, 0x77, 0x65, 0x61, 0x6b, 0x22, 0x22, 0x0a, 0x12, 0x57, 0x65, 0x61, 0x6b, 0x49, 0x6d, 0x70,
0x6f, 0x72, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x31, 0x12, 0x0c, 0x0a, 0x01, 0x61,
- 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
+ 0x18, 0x01, 0x20, 0x02, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c,
0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74,
diff --git a/internal/testprotos/test/weak1/test_weak.proto b/internal/testprotos/test/weak1/test_weak.proto
index 5a4814d..113ed1b 100644
--- a/internal/testprotos/test/weak1/test_weak.proto
+++ b/internal/testprotos/test/weak1/test_weak.proto
@@ -9,5 +9,5 @@
option go_package = "google.golang.org/protobuf/internal/testprotos/test/weak1";
message WeakImportMessage1 {
- optional int32 a = 1;
+ required int32 a = 1;
}
diff --git a/internal/testprotos/test/weak2/test_weak.pb.go b/internal/testprotos/test/weak2/test_weak.pb.go
index a1d2e07..e04031c 100644
--- a/internal/testprotos/test/weak2/test_weak.pb.go
+++ b/internal/testprotos/test/weak2/test_weak.pb.go
@@ -19,7 +19,7 @@
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
- A *int32 `protobuf:"varint,1,opt,name=a" json:"a,omitempty"`
+ A *int32 `protobuf:"varint,1,req,name=a" json:"a,omitempty"`
}
func (x *WeakImportMessage2) Reset() {
@@ -64,7 +64,7 @@
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x74, 0x65, 0x73, 0x74,
0x2e, 0x77, 0x65, 0x61, 0x6b, 0x22, 0x22, 0x0a, 0x12, 0x57, 0x65, 0x61, 0x6b, 0x49, 0x6d, 0x70,
0x6f, 0x72, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0x12, 0x0c, 0x0a, 0x01, 0x61,
- 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
+ 0x18, 0x01, 0x20, 0x02, 0x28, 0x05, 0x52, 0x01, 0x61, 0x42, 0x3b, 0x5a, 0x39, 0x67, 0x6f, 0x6f,
0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c,
0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74,
diff --git a/internal/testprotos/test/weak2/test_weak.proto b/internal/testprotos/test/weak2/test_weak.proto
index eb4d0ff..61a03d6 100644
--- a/internal/testprotos/test/weak2/test_weak.proto
+++ b/internal/testprotos/test/weak2/test_weak.proto
@@ -9,5 +9,5 @@
option go_package = "google.golang.org/protobuf/internal/testprotos/test/weak2";
message WeakImportMessage2 {
- optional int32 a = 1;
+ required int32 a = 1;
}
diff --git a/proto/decode.go b/proto/decode.go
index ab62ff3..f64f887 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -8,6 +8,7 @@
"google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
+ "google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/pragma"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
@@ -88,7 +89,7 @@
return wire.ParseError(tagLen)
}
- // Parse the field value.
+ // Find the field descriptor for this field number.
fd := fields.ByNumber(num)
if fd == nil && md.ExtensionRanges().Has(num) {
extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
@@ -100,10 +101,18 @@
}
}
var err error
+ if fd == nil {
+ err = errUnknown
+ } else if flags.ProtoLegacy {
+ if fd.IsWeak() && fd.Message().IsPlaceholder() {
+ err = errUnknown // weak referent is not linked in
+ }
+ }
+
+ // Parse the field value.
var valLen int
switch {
- case fd == nil:
- err = errUnknown
+ case err != nil:
case fd.IsList():
valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
case fd.IsMap():
@@ -111,14 +120,15 @@
default:
valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
}
- if err == errUnknown {
+ if err != nil {
+ if err != errUnknown {
+ return err
+ }
valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
if valLen < 0 {
return wire.ParseError(valLen)
}
m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
- } else if err != nil {
- return err
}
b = b[tagLen+valLen:]
}
diff --git a/proto/decode_test.go b/proto/decode_test.go
index c696778..b6258da 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -21,6 +21,7 @@
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5"
testpb "google.golang.org/protobuf/internal/testprotos/test"
+ weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
test3pb "google.golang.org/protobuf/internal/testprotos/test3"
"google.golang.org/protobuf/types/descriptorpb"
)
@@ -1726,6 +1727,46 @@
},
}
+func TestWeak(t *testing.T) {
+ if !flags.ProtoLegacy {
+ t.SkipNow()
+ }
+
+ m := new(testpb.TestWeak)
+ b := pack.Message{
+ pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1000),
+ }),
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(2000),
+ }),
+ }.Marshal()
+ if err := proto.Unmarshal(b, m); err != nil {
+ t.Errorf("Unmarshal error: %v", err)
+ }
+
+ mw := m.GetWeakMessage1().(*weakpb.WeakImportMessage1)
+ if mw.GetA() != 1000 {
+ t.Errorf("m.WeakMessage1.a = %d, want %d", mw.GetA(), 1000)
+ }
+
+ if len(m.ProtoReflect().GetUnknown()) == 0 {
+ t.Errorf("m has no unknown fields, expected at least something")
+ }
+
+ if n := proto.Size(m); n != len(b) {
+ t.Errorf("Size() = %d, want %d", n, len(b))
+ }
+
+ b2, err := proto.Marshal(m)
+ if err != nil {
+ t.Errorf("Marshal error: %v", err)
+ }
+ if len(b2) != len(b) {
+ t.Errorf("len(Marshal) = %d, want %d", len(b2), len(b))
+ }
+}
+
func build(m proto.Message, opts ...buildOpt) proto.Message {
for _, opt := range opts {
opt(m)
diff --git a/proto/isinit_test.go b/proto/isinit_test.go
index d935d3a..6a3a8c9 100644
--- a/proto/isinit_test.go
+++ b/proto/isinit_test.go
@@ -9,51 +9,80 @@
"strings"
"testing"
+ "google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/proto"
testpb "google.golang.org/protobuf/internal/testprotos/test"
+ weakpb "google.golang.org/protobuf/internal/testprotos/test/weak1"
)
func TestIsInitializedErrors(t *testing.T) {
- for _, test := range []struct {
+ type test struct {
m proto.Message
want string
- }{
- {
- &testpb.TestRequired{},
- `goproto.proto.test.TestRequired.required_field`,
+ skip bool
+ }
+ tests := []test{{
+ m: &testpb.TestRequired{},
+ want: `goproto.proto.test.TestRequired.required_field`,
+ }, {
+ m: &testpb.TestRequiredForeign{
+ OptionalMessage: &testpb.TestRequired{},
},
- {
- &testpb.TestRequiredForeign{
- OptionalMessage: &testpb.TestRequired{},
+ want: `goproto.proto.test.TestRequired.required_field`,
+ }, {
+ m: &testpb.TestRequiredForeign{
+ RepeatedMessage: []*testpb.TestRequired{
+ {RequiredField: proto.Int32(1)},
+ {},
},
- `goproto.proto.test.TestRequired.required_field`,
},
- {
- &testpb.TestRequiredForeign{
- RepeatedMessage: []*testpb.TestRequired{
- {RequiredField: proto.Int32(1)},
- {},
- },
+ want: `goproto.proto.test.TestRequired.required_field`,
+ }, {
+ m: &testpb.TestRequiredForeign{
+ MapMessage: map[int32]*testpb.TestRequired{
+ 1: {},
},
- `goproto.proto.test.TestRequired.required_field`,
},
- {
- &testpb.TestRequiredForeign{
- MapMessage: map[int32]*testpb.TestRequired{
- 1: {},
- },
- },
- `goproto.proto.test.TestRequired.required_field`,
- },
- } {
- err := proto.IsInitialized(test.m)
- got := "<nil>"
- if err != nil {
- got = fmt.Sprintf("%q", err)
- }
- if !strings.Contains(got, test.want) {
- t.Errorf("IsInitialized(m):\n got: %v\nwant contains: %v\nMessage:\n%v", got, test.want, marshalText(test.m))
- }
+ want: `goproto.proto.test.TestRequired.required_field`,
+ }, {
+ m: &testpb.TestWeak{},
+ want: `<nil>`,
+ skip: !flags.ProtoLegacy,
+ }, {
+ m: func() proto.Message {
+ m := &testpb.TestWeak{}
+ m.SetWeakMessage1(&weakpb.WeakImportMessage1{})
+ return m
+ }(),
+ want: `goproto.proto.test.weak.WeakImportMessage1.a`,
+ skip: !flags.ProtoLegacy,
+ }, {
+ m: func() proto.Message {
+ m := &testpb.TestWeak{}
+ m.SetWeakMessage1(&weakpb.WeakImportMessage1{
+ A: proto.Int32(1),
+ })
+ return m
+ }(),
+ want: `<nil>`,
+ skip: !flags.ProtoLegacy,
+ }}
+
+ for _, tt := range tests {
+ t.Run("", func(t *testing.T) {
+ if tt.skip {
+ t.SkipNow()
+ }
+
+ err := proto.IsInitialized(tt.m)
+ got := "<nil>"
+ if err != nil {
+ got = fmt.Sprintf("%q", err)
+ }
+ if !strings.Contains(got, tt.want) {
+ t.Errorf("IsInitialized(m):\n got: %v\nwant contains: %v\nMessage:\n%v", got, tt.want, marshalText(tt.m))
+ }
+ })
}
}