encoding/jsonpb: add support for marshaling of extensions and messagesets
Change-Id: I839660146760a66c5cbf25d24f81f0ba5096d9e1
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/167395
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/jsonpb/encode.go b/encoding/jsonpb/encode.go
index 4152a74..3830f31 100644
--- a/encoding/jsonpb/encode.go
+++ b/encoding/jsonpb/encode.go
@@ -13,6 +13,8 @@
"github.com/golang/protobuf/v2/internal/pragma"
"github.com/golang/protobuf/v2/proto"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+
+ descpb "github.com/golang/protobuf/v2/types/descriptor"
)
// Marshal writes the given proto.Message in JSON format using default options.
@@ -70,6 +72,7 @@
fieldDescs := m.Type().Fields()
knownFields := m.KnownFields()
+ // Marshal out known fields.
for i := 0; i < fieldDescs.Len(); i++ {
fd := fieldDescs.Get(i)
num := fd.Number()
@@ -92,6 +95,11 @@
return err
}
}
+
+ // Marshal out extensions.
+ if err := e.marshalExtensions(knownFields); !nerr.Merge(err) {
+ return err
+ }
return nerr.E
}
@@ -222,7 +230,6 @@
if err := e.WriteName(entry.key.String()); !nerr.Merge(err) {
return err
}
-
if err := e.marshalSingular(entry.value, valType); !nerr.Merge(err) {
return err
}
@@ -230,22 +237,94 @@
return nerr.E
}
-// sortMap orders list based on value of key field for deterministic output.
+// sortMap orders list based on value of key field for deterministic ordering.
func sortMap(keyKind pref.Kind, values []mapEntry) {
- less := func(i, j int) bool {
- return values[i].key.String() < values[j].key.String()
- }
- switch keyKind {
- case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
- pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
- less = func(i, j int) bool {
+ sort.Slice(values, func(i, j int) bool {
+ switch keyKind {
+ case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
+ pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
return values[i].key.Int() < values[j].key.Int()
- }
- case pref.Uint32Kind, pref.Fixed32Kind,
- pref.Uint64Kind, pref.Fixed64Kind:
- less = func(i, j int) bool {
+
+ case pref.Uint32Kind, pref.Fixed32Kind,
+ pref.Uint64Kind, pref.Fixed64Kind:
return values[i].key.Uint() < values[j].key.Uint()
}
+ return values[i].key.String() < values[j].key.String()
+ })
+}
+
+// marshalExtensions marshals extension fields.
+func (e encoder) marshalExtensions(knownFields pref.KnownFields) error {
+ type xtEntry struct {
+ key string
+ value pref.Value
+ xtType pref.ExtensionType
}
- sort.Slice(values, less)
+
+ xtTypes := knownFields.ExtensionTypes()
+
+ // Get a sorted list based on field key first.
+ entries := make([]xtEntry, 0, xtTypes.Len())
+ xtTypes.Range(func(xt pref.ExtensionType) bool {
+ name := xt.FullName()
+ // If extended type is a MessageSet, set field name to be the message type name.
+ if isMessageSetExtension(xt) {
+ name = xt.MessageType().FullName()
+ }
+
+ num := xt.Number()
+ if knownFields.Has(num) {
+ // Use [name] format for JSON field name.
+ pval := knownFields.Get(num)
+ entries = append(entries, xtEntry{
+ key: string(name),
+ value: pval,
+ xtType: xt,
+ })
+ }
+ return true
+ })
+
+ // Sort extensions lexicographically.
+ sort.Slice(entries, func(i, j int) bool {
+ return entries[i].key < entries[j].key
+ })
+
+ // Write out sorted list.
+ var nerr errors.NonFatal
+ for _, entry := range entries {
+ // JSON field name is the proto field name enclosed in [], similar to
+ // textproto. This is consistent with Go v1 lib. C++ lib v3.7.0 does not
+ // marshal out extension fields.
+ if err := e.WriteName("[" + entry.key + "]"); !nerr.Merge(err) {
+ return err
+ }
+ if err := e.marshalValue(entry.value, entry.xtType); !nerr.Merge(err) {
+ return err
+ }
+ }
+ return nerr.E
+}
+
+// isMessageSetExtension reports whether extension extends a message set.
+func isMessageSetExtension(xt pref.ExtensionType) bool {
+ if xt.Name() != "message_set_extension" {
+ return false
+ }
+ mt := xt.MessageType()
+ if mt == nil {
+ return false
+ }
+ if xt.FullName().Parent() != mt.FullName() {
+ return false
+ }
+ xmt := xt.ExtendedType()
+ if xmt.Fields().Len() != 0 {
+ return false
+ }
+ opt := xmt.Options().(*descpb.MessageOptions)
+ if opt == nil {
+ return false
+ }
+ return opt.GetMessageSetWireFormat()
}
diff --git a/encoding/jsonpb/encode_test.go b/encoding/jsonpb/encode_test.go
index 43e7529..280e8db 100644
--- a/encoding/jsonpb/encode_test.go
+++ b/encoding/jsonpb/encode_test.go
@@ -9,13 +9,18 @@
"strings"
"testing"
+ "github.com/golang/protobuf/protoapi"
"github.com/golang/protobuf/v2/encoding/jsonpb"
"github.com/golang/protobuf/v2/internal/encoding/pack"
+ "github.com/golang/protobuf/v2/internal/encoding/wire"
"github.com/golang/protobuf/v2/internal/scalar"
"github.com/golang/protobuf/v2/proto"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
+ // This legacy package is still needed when importing legacy message.
+ _ "github.com/golang/protobuf/v2/internal/legacy"
+
"github.com/golang/protobuf/v2/encoding/testprotos/pb2"
"github.com/golang/protobuf/v2/encoding/testprotos/pb3"
)
@@ -37,6 +42,17 @@
return p
}
+func setExtension(m proto.Message, xd *protoapi.ExtensionDesc, val interface{}) {
+ knownFields := m.ProtoReflect().KnownFields()
+ extTypes := knownFields.ExtensionTypes()
+ extTypes.Register(xd.Type)
+ if val == nil {
+ return
+ }
+ pval := xd.Type.ValueOf(val)
+ knownFields.Set(wire.Number(xd.Field), pval)
+}
+
func TestMarshal(t *testing.T) {
tests := []struct {
desc string
@@ -701,12 +717,209 @@
want: `{
"foo_bar": "json_name"
}`,
+ }, {
+ desc: "extensions of non-repeated fields",
+ input: func() proto.Message {
+ m := &pb2.Extensions{
+ OptString: scalar.String("non-extension field"),
+ OptBool: scalar.Bool(true),
+ OptInt32: scalar.Int32(42),
+ }
+ setExtension(m, pb2.E_OptExtBool, true)
+ setExtension(m, pb2.E_OptExtString, "extension field")
+ setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
+ setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+ OptString: scalar.String("nested in an extension"),
+ OptNested: &pb2.Nested{
+ OptString: scalar.String("another nested in an extension"),
+ },
+ })
+ return m
+ }(),
+ want: `{
+ "optString": "non-extension field",
+ "optBool": true,
+ "optInt32": 42,
+ "[pb2.opt_ext_bool]": true,
+ "[pb2.opt_ext_enum]": "TEN",
+ "[pb2.opt_ext_nested]": {
+ "optString": "nested in an extension",
+ "optNested": {
+ "optString": "another nested in an extension"
+ }
+ },
+ "[pb2.opt_ext_string]": "extension field"
+}`,
+ }, {
+ desc: "extension message field set to nil",
+ input: func() proto.Message {
+ m := &pb2.Extensions{}
+ setExtension(m, pb2.E_OptExtNested, nil)
+ return m
+ }(),
+ want: "{}",
+ }, {
+ desc: "extensions of repeated fields",
+ input: func() proto.Message {
+ m := &pb2.Extensions{}
+ setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+ setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+ &pb2.Nested{OptString: scalar.String("one")},
+ &pb2.Nested{OptString: scalar.String("two")},
+ &pb2.Nested{OptString: scalar.String("three")},
+ })
+ return m
+ }(),
+ want: `{
+ "[pb2.rpt_ext_enum]": [
+ "TEN",
+ 101,
+ "ONE"
+ ],
+ "[pb2.rpt_ext_fixed32]": [
+ 42,
+ 47
+ ],
+ "[pb2.rpt_ext_nested]": [
+ {
+ "optString": "one"
+ },
+ {
+ "optString": "two"
+ },
+ {
+ "optString": "three"
+ }
+ ]
+}`,
+ }, {
+ desc: "extensions of non-repeated fields in another message",
+ input: func() proto.Message {
+ m := &pb2.Extensions{}
+ setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+ setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+ setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
+ setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+ OptString: scalar.String("nested in an extension"),
+ OptNested: &pb2.Nested{
+ OptString: scalar.String("another nested in an extension"),
+ },
+ })
+ return m
+ }(),
+ want: `{
+ "[pb2.ExtensionsContainer.opt_ext_bool]": true,
+ "[pb2.ExtensionsContainer.opt_ext_enum]": "TEN",
+ "[pb2.ExtensionsContainer.opt_ext_nested]": {
+ "optString": "nested in an extension",
+ "optNested": {
+ "optString": "another nested in an extension"
+ }
+ },
+ "[pb2.ExtensionsContainer.opt_ext_string]": "extension field"
+}`,
+ }, {
+ desc: "extensions of repeated fields in another message",
+ input: func() proto.Message {
+ m := &pb2.Extensions{
+ OptString: scalar.String("non-extension field"),
+ OptBool: scalar.Bool(true),
+ OptInt32: scalar.Int32(42),
+ }
+ setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+ setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+ setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+ &pb2.Nested{OptString: scalar.String("one")},
+ &pb2.Nested{OptString: scalar.String("two")},
+ &pb2.Nested{OptString: scalar.String("three")},
+ })
+ return m
+ }(),
+ want: `{
+ "optString": "non-extension field",
+ "optBool": true,
+ "optInt32": 42,
+ "[pb2.ExtensionsContainer.rpt_ext_enum]": [
+ "TEN",
+ 101,
+ "ONE"
+ ],
+ "[pb2.ExtensionsContainer.rpt_ext_nested]": [
+ {
+ "optString": "one"
+ },
+ {
+ "optString": "two"
+ },
+ {
+ "optString": "three"
+ }
+ ],
+ "[pb2.ExtensionsContainer.rpt_ext_string]": [
+ "hello",
+ "world"
+ ]
+}`,
+ }, {
+ desc: "MessageSet",
+ input: func() proto.Message {
+ m := &pb2.MessageSet{}
+ setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
+ OptString: scalar.String("a messageset extension"),
+ })
+ setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
+ OptString: scalar.String("not a messageset extension"),
+ })
+ setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
+ OptString: scalar.String("just a regular extension"),
+ })
+ return m
+ }(),
+ want: `{
+ "[pb2.MessageSetExtension]": {
+ "optString": "a messageset extension"
+ },
+ "[pb2.MessageSetExtension.ext_nested]": {
+ "optString": "just a regular extension"
+ },
+ "[pb2.MessageSetExtension.not_message_set_extension]": {
+ "optString": "not a messageset extension"
+ }
+}`,
+ }, {
+ desc: "not real MessageSet 1",
+ input: func() proto.Message {
+ m := &pb2.FakeMessageSet{}
+ setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ OptString: scalar.String("not a messageset extension"),
+ })
+ return m
+ }(),
+ want: `{
+ "[pb2.FakeMessageSetExtension.message_set_extension]": {
+ "optString": "not a messageset extension"
+ }
+}`,
+ }, {
+ desc: "not real MessageSet 2",
+ input: func() proto.Message {
+ m := &pb2.MessageSet{}
+ setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
+ OptString: scalar.String("another not a messageset extension"),
+ })
+ return m
+ }(),
+ want: `{
+ "[pb2.message_set_extension]": {
+ "optString": "another not a messageset extension"
+ }
+}`,
}}
for _, tt := range tests {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
- t.Parallel()
b, err := tt.mo.Marshal(tt.input)
if err != nil {
t.Errorf("Marshal() returned error: %v\n", err)