encoding/jsonpb: add support for marshaling of extensions and messagesets

Change-Id: I839660146760a66c5cbf25d24f81f0ba5096d9e1
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/167395
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/jsonpb/encode.go b/encoding/jsonpb/encode.go
index 4152a74..3830f31 100644
--- a/encoding/jsonpb/encode.go
+++ b/encoding/jsonpb/encode.go
@@ -13,6 +13,8 @@
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/proto"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+
+	descpb "github.com/golang/protobuf/v2/types/descriptor"
 )
 
 // Marshal writes the given proto.Message in JSON format using default options.
@@ -70,6 +72,7 @@
 	fieldDescs := m.Type().Fields()
 	knownFields := m.KnownFields()
 
+	// Marshal out known fields.
 	for i := 0; i < fieldDescs.Len(); i++ {
 		fd := fieldDescs.Get(i)
 		num := fd.Number()
@@ -92,6 +95,11 @@
 			return err
 		}
 	}
+
+	// Marshal out extensions.
+	if err := e.marshalExtensions(knownFields); !nerr.Merge(err) {
+		return err
+	}
 	return nerr.E
 }
 
@@ -222,7 +230,6 @@
 		if err := e.WriteName(entry.key.String()); !nerr.Merge(err) {
 			return err
 		}
-
 		if err := e.marshalSingular(entry.value, valType); !nerr.Merge(err) {
 			return err
 		}
@@ -230,22 +237,94 @@
 	return nerr.E
 }
 
-// sortMap orders list based on value of key field for deterministic output.
+// sortMap orders list based on value of key field for deterministic ordering.
 func sortMap(keyKind pref.Kind, values []mapEntry) {
-	less := func(i, j int) bool {
-		return values[i].key.String() < values[j].key.String()
-	}
-	switch keyKind {
-	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
-		pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
-		less = func(i, j int) bool {
+	sort.Slice(values, func(i, j int) bool {
+		switch keyKind {
+		case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
+			pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
 			return values[i].key.Int() < values[j].key.Int()
-		}
-	case pref.Uint32Kind, pref.Fixed32Kind,
-		pref.Uint64Kind, pref.Fixed64Kind:
-		less = func(i, j int) bool {
+
+		case pref.Uint32Kind, pref.Fixed32Kind,
+			pref.Uint64Kind, pref.Fixed64Kind:
 			return values[i].key.Uint() < values[j].key.Uint()
 		}
+		return values[i].key.String() < values[j].key.String()
+	})
+}
+
+// marshalExtensions marshals extension fields.
+func (e encoder) marshalExtensions(knownFields pref.KnownFields) error {
+	type xtEntry struct {
+		key    string
+		value  pref.Value
+		xtType pref.ExtensionType
 	}
-	sort.Slice(values, less)
+
+	xtTypes := knownFields.ExtensionTypes()
+
+	// Get a sorted list based on field key first.
+	entries := make([]xtEntry, 0, xtTypes.Len())
+	xtTypes.Range(func(xt pref.ExtensionType) bool {
+		name := xt.FullName()
+		// If extended type is a MessageSet, set field name to be the message type name.
+		if isMessageSetExtension(xt) {
+			name = xt.MessageType().FullName()
+		}
+
+		num := xt.Number()
+		if knownFields.Has(num) {
+			// Use [name] format for JSON field name.
+			pval := knownFields.Get(num)
+			entries = append(entries, xtEntry{
+				key:    string(name),
+				value:  pval,
+				xtType: xt,
+			})
+		}
+		return true
+	})
+
+	// Sort extensions lexicographically.
+	sort.Slice(entries, func(i, j int) bool {
+		return entries[i].key < entries[j].key
+	})
+
+	// Write out sorted list.
+	var nerr errors.NonFatal
+	for _, entry := range entries {
+		// JSON field name is the proto field name enclosed in [], similar to
+		// textproto. This is consistent with Go v1 lib. C++ lib v3.7.0 does not
+		// marshal out extension fields.
+		if err := e.WriteName("[" + entry.key + "]"); !nerr.Merge(err) {
+			return err
+		}
+		if err := e.marshalValue(entry.value, entry.xtType); !nerr.Merge(err) {
+			return err
+		}
+	}
+	return nerr.E
+}
+
+// isMessageSetExtension reports whether extension extends a message set.
+func isMessageSetExtension(xt pref.ExtensionType) bool {
+	if xt.Name() != "message_set_extension" {
+		return false
+	}
+	mt := xt.MessageType()
+	if mt == nil {
+		return false
+	}
+	if xt.FullName().Parent() != mt.FullName() {
+		return false
+	}
+	xmt := xt.ExtendedType()
+	if xmt.Fields().Len() != 0 {
+		return false
+	}
+	opt := xmt.Options().(*descpb.MessageOptions)
+	if opt == nil {
+		return false
+	}
+	return opt.GetMessageSetWireFormat()
 }
diff --git a/encoding/jsonpb/encode_test.go b/encoding/jsonpb/encode_test.go
index 43e7529..280e8db 100644
--- a/encoding/jsonpb/encode_test.go
+++ b/encoding/jsonpb/encode_test.go
@@ -9,13 +9,18 @@
 	"strings"
 	"testing"
 
+	"github.com/golang/protobuf/protoapi"
 	"github.com/golang/protobuf/v2/encoding/jsonpb"
 	"github.com/golang/protobuf/v2/internal/encoding/pack"
+	"github.com/golang/protobuf/v2/internal/encoding/wire"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 
+	// This legacy package is still needed when importing legacy message.
+	_ "github.com/golang/protobuf/v2/internal/legacy"
+
 	"github.com/golang/protobuf/v2/encoding/testprotos/pb2"
 	"github.com/golang/protobuf/v2/encoding/testprotos/pb3"
 )
@@ -37,6 +42,17 @@
 	return p
 }
 
+func setExtension(m proto.Message, xd *protoapi.ExtensionDesc, val interface{}) {
+	knownFields := m.ProtoReflect().KnownFields()
+	extTypes := knownFields.ExtensionTypes()
+	extTypes.Register(xd.Type)
+	if val == nil {
+		return
+	}
+	pval := xd.Type.ValueOf(val)
+	knownFields.Set(wire.Number(xd.Field), pval)
+}
+
 func TestMarshal(t *testing.T) {
 	tests := []struct {
 		desc  string
@@ -701,12 +717,209 @@
 		want: `{
   "foo_bar": "json_name"
 }`,
+	}, {
+		desc: "extensions of non-repeated fields",
+		input: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_OptExtBool, true)
+			setExtension(m, pb2.E_OptExtString, "extension field")
+			setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
+			setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+		want: `{
+  "optString": "non-extension field",
+  "optBool": true,
+  "optInt32": 42,
+  "[pb2.opt_ext_bool]": true,
+  "[pb2.opt_ext_enum]": "TEN",
+  "[pb2.opt_ext_nested]": {
+    "optString": "nested in an extension",
+    "optNested": {
+      "optString": "another nested in an extension"
+    }
+  },
+  "[pb2.opt_ext_string]": "extension field"
+}`,
+	}, {
+		desc: "extension message field set to nil",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_OptExtNested, nil)
+			return m
+		}(),
+		want: "{}",
+	}, {
+		desc: "extensions of repeated fields",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+			setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.rpt_ext_enum]": [
+    "TEN",
+    101,
+    "ONE"
+  ],
+  "[pb2.rpt_ext_fixed32]": [
+    42,
+    47
+  ],
+  "[pb2.rpt_ext_nested]": [
+    {
+      "optString": "one"
+    },
+    {
+      "optString": "two"
+    },
+    {
+      "optString": "three"
+    }
+  ]
+}`,
+	}, {
+		desc: "extensions of non-repeated fields in another message",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.ExtensionsContainer.opt_ext_bool]": true,
+  "[pb2.ExtensionsContainer.opt_ext_enum]": "TEN",
+  "[pb2.ExtensionsContainer.opt_ext_nested]": {
+    "optString": "nested in an extension",
+    "optNested": {
+      "optString": "another nested in an extension"
+    }
+  },
+  "[pb2.ExtensionsContainer.opt_ext_string]": "extension field"
+}`,
+	}, {
+		desc: "extensions of repeated fields in another message",
+		input: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+		want: `{
+  "optString": "non-extension field",
+  "optBool": true,
+  "optInt32": 42,
+  "[pb2.ExtensionsContainer.rpt_ext_enum]": [
+    "TEN",
+    101,
+    "ONE"
+  ],
+  "[pb2.ExtensionsContainer.rpt_ext_nested]": [
+    {
+      "optString": "one"
+    },
+    {
+      "optString": "two"
+    },
+    {
+      "optString": "three"
+    }
+  ],
+  "[pb2.ExtensionsContainer.rpt_ext_string]": [
+    "hello",
+    "world"
+  ]
+}`,
+	}, {
+		desc: "MessageSet",
+		input: func() proto.Message {
+			m := &pb2.MessageSet{}
+			setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
+				OptString: scalar.String("a messageset extension"),
+			})
+			setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
+				OptString: scalar.String("not a messageset extension"),
+			})
+			setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
+				OptString: scalar.String("just a regular extension"),
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.MessageSetExtension]": {
+    "optString": "a messageset extension"
+  },
+  "[pb2.MessageSetExtension.ext_nested]": {
+    "optString": "just a regular extension"
+  },
+  "[pb2.MessageSetExtension.not_message_set_extension]": {
+    "optString": "not a messageset extension"
+  }
+}`,
+	}, {
+		desc: "not real MessageSet 1",
+		input: func() proto.Message {
+			m := &pb2.FakeMessageSet{}
+			setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
+				OptString: scalar.String("not a messageset extension"),
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.FakeMessageSetExtension.message_set_extension]": {
+    "optString": "not a messageset extension"
+  }
+}`,
+	}, {
+		desc: "not real MessageSet 2",
+		input: func() proto.Message {
+			m := &pb2.MessageSet{}
+			setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
+				OptString: scalar.String("another not a messageset extension"),
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.message_set_extension]": {
+    "optString": "another not a messageset extension"
+  }
+}`,
 	}}
 
 	for _, tt := range tests {
 		tt := tt
 		t.Run(tt.desc, func(t *testing.T) {
-			t.Parallel()
 			b, err := tt.mo.Marshal(tt.input)
 			if err != nil {
 				t.Errorf("Marshal() returned error: %v\n", err)