encoding/jsonpb: add support for unmarshaling extensions and messagesets

Change-Id: I7f3e0091c4f46924d2e8a08c614c7ab64917014c
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/167773
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/jsonpb/decode.go b/encoding/jsonpb/decode.go
index a43efd1..5ea6fb8 100644
--- a/encoding/jsonpb/decode.go
+++ b/encoding/jsonpb/decode.go
@@ -13,9 +13,11 @@
 
 	"github.com/golang/protobuf/v2/internal/encoding/json"
 	"github.com/golang/protobuf/v2/internal/errors"
+	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/internal/set"
 	"github.com/golang/protobuf/v2/proto"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 )
 
 // Unmarshal reads the given []byte into the given proto.Message.
@@ -24,7 +26,14 @@
 }
 
 // UnmarshalOptions is a configurable JSON format parser.
-type UnmarshalOptions struct{}
+type UnmarshalOptions struct {
+	pragma.NoUnkeyedLiterals // NOTE(review): presumably forbids unkeyed literals of this struct — confirm
+
+	// Resolver is the registry consulted to look up extension types by name while
+	// unmarshaling (and, per the original note, when processing Any). A nil
+	// Resolver makes Unmarshal fall back to protoregistry.GlobalTypes.
+	Resolver *protoregistry.Types
+}
 
 // Unmarshal reads the given []byte and populates the given proto.Message using
 // options in UnmarshalOptions object. It will clear the message first before
@@ -37,7 +46,15 @@
 	// marshaling.
 	resetMessage(mr)
 
-	dec := decoder{json.NewDecoder(b)}
+	resolver := o.Resolver
+	if resolver == nil {
+		resolver = protoregistry.GlobalTypes
+	}
+
+	dec := decoder{
+		Decoder:  json.NewDecoder(b),
+		resolver: resolver,
+	}
 	var nerr errors.NonFatal
 	if err := dec.unmarshalMessage(mr); !nerr.Merge(err) {
 		return err
@@ -108,6 +125,7 @@
 // decoder decodes JSON into protoreflect values.
 type decoder struct {
 	*json.Decoder
+	resolver *protoregistry.Types // registry queried by findExtension to look up extensions by name
 }
 
 // unmarshalMessage unmarshals a message into the given protoreflect.Message.
@@ -119,6 +137,7 @@
 	msgType := m.Type()
 	knownFields := m.KnownFields()
 	fieldDescs := msgType.Fields()
+	xtTypes := knownFields.ExtensionTypes()
 
 	jval, err := d.Read()
 	if !nerr.Merge(err) {
@@ -149,11 +168,28 @@
 			return err
 		}
 
-		// Get the FieldDescriptor based on the field name. The name can either
-		// be the JSON name for the field or the proto field name.
-		fd := fieldDescs.ByJSONName(name)
-		if fd == nil {
-			fd = fieldDescs.ByName(pref.Name(name))
+		// Get the FieldDescriptor.
+		var fd pref.FieldDescriptor
+		if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
+			// Only extension names are in [name] format.
+			xtName := pref.FullName(name[1 : len(name)-1])
+			xt := xtTypes.ByName(xtName)
+			if xt == nil {
+				xt, err = d.findExtension(xtName)
+				if err != nil && err != protoregistry.NotFound {
+					return errors.New("unable to resolve [%v]: %v", xtName, err)
+				}
+				if xt != nil {
+					xtTypes.Register(xt)
+				}
+			}
+			fd = xt
+		} else {
+			// The name can either be the JSON name or the proto field name.
+			fd = fieldDescs.ByJSONName(name)
+			if fd == nil {
+				fd = fieldDescs.ByName(pref.Name(name))
+			}
 		}
 
 		if fd == nil {
@@ -204,6 +240,21 @@
 	return nerr.E
 }
 
+// findExtension looks up xtName in d.resolver, first under the name as given and
+// then as a MessageSet extension; every miss is reported as protoregistry.NotFound.
+func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
+	if xt, err := d.resolver.FindExtensionByName(xtName); err == nil {
+		return xt, nil
+	}
+	// MessageSet form: retry with the ".message_set_extension" suffix appended to xtName.
+	msetName := xtName + ".message_set_extension"
+	xt, err := d.resolver.FindExtensionByName(msetName)
+	if err != nil || !isMessageSetExtension(xt) {
+		return nil, protoregistry.NotFound
+	}
+	return xt, nil
+}
+
 // unmarshalSingular unmarshals to the non-repeated field specified by the given
 // FieldDescriptor.
 func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, knownFields pref.KnownFields) error {
@@ -294,8 +345,8 @@
 		return getInt(jval, bitSize)
 
 	case json.String:
-		// Use another decoder to decode number from string.
-		dec := decoder{json.NewDecoder([]byte(jval.String()))}
+		// Decode number from string.
+		dec := json.NewDecoder([]byte(jval.String()))
 		var nerr errors.NonFatal
 		jval, err := dec.Read()
 		if !nerr.Merge(err) {
@@ -323,8 +374,8 @@
 		return getUint(jval, bitSize)
 
 	case json.String:
-		// Use another decoder to decode number from string.
-		dec := decoder{json.NewDecoder([]byte(jval.String()))}
+		// Decode number from string.
+		dec := json.NewDecoder([]byte(jval.String()))
 		var nerr errors.NonFatal
 		jval, err := dec.Read()
 		if !nerr.Merge(err) {
@@ -370,8 +421,8 @@
 			}
 			return pref.ValueOf(math.Inf(-1)), nil
 		}
-		// Use another decoder to decode number from string.
-		dec := decoder{json.NewDecoder([]byte(s))}
+		// Decode number from string.
+		dec := json.NewDecoder([]byte(s))
 		var nerr errors.NonFatal
 		jval, err := dec.Read()
 		if !nerr.Merge(err) {
diff --git a/encoding/jsonpb/decode_test.go b/encoding/jsonpb/decode_test.go
index 97a0619..2ef81cd 100644
--- a/encoding/jsonpb/decode_test.go
+++ b/encoding/jsonpb/decode_test.go
@@ -9,13 +9,43 @@
 	"testing"
 
 	protoV1 "github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/protoapi"
 	"github.com/golang/protobuf/v2/encoding/jsonpb"
 	"github.com/golang/protobuf/v2/encoding/testprotos/pb2"
 	"github.com/golang/protobuf/v2/encoding/testprotos/pb3"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
+	preg "github.com/golang/protobuf/v2/reflect/protoregistry"
 )
 
+func init() {
+	// TODO: drop these calls once generated code registers with the V2 global registry.
+	registerExtension(pb2.E_OptExtBool)
+	registerExtension(pb2.E_OptExtString)
+	registerExtension(pb2.E_OptExtEnum)
+	registerExtension(pb2.E_OptExtNested)
+	registerExtension(pb2.E_RptExtFixed32)
+	registerExtension(pb2.E_RptExtEnum)
+	registerExtension(pb2.E_RptExtNested)
+	registerExtension(pb2.E_ExtensionsContainer_OptExtBool)
+	registerExtension(pb2.E_ExtensionsContainer_OptExtString)
+	registerExtension(pb2.E_ExtensionsContainer_OptExtEnum)
+	registerExtension(pb2.E_ExtensionsContainer_OptExtNested)
+	registerExtension(pb2.E_ExtensionsContainer_RptExtString)
+	registerExtension(pb2.E_ExtensionsContainer_RptExtEnum)
+	registerExtension(pb2.E_ExtensionsContainer_RptExtNested)
+	registerExtension(pb2.E_MessageSetExtension)
+	registerExtension(pb2.E_MessageSetExtension_MessageSetExtension)
+	registerExtension(pb2.E_MessageSetExtension_NotMessageSetExtension)
+	registerExtension(pb2.E_MessageSetExtension_ExtNested)
+	registerExtension(pb2.E_FakeMessageSetExtension_MessageSetExtension)
+}
+
+// registerExtension registers xd.Type with preg.GlobalTypes; any result of Register is ignored.
+func registerExtension(xd *protoapi.ExtensionDesc) {
+	preg.GlobalTypes.Register(xd.Type)
+}
+
 func TestUnmarshal(t *testing.T) {
 	tests := []struct {
 		desc         string
@@ -907,6 +937,215 @@
   }
 }`,
 		wantErr: true,
+	}, {
+		desc:         "extensions of non-repeated fields",
+		inputMessage: &pb2.Extensions{},
+		inputText: `{
+  "optString": "non-extension field",
+  "optBool": true,
+  "optInt32": 42,
+  "[pb2.opt_ext_bool]": true,
+  "[pb2.opt_ext_nested]": {
+    "optString": "nested in an extension",
+    "opt_nested": {
+      "opt_string": "another nested in an extension"
+    }
+  },
+  "[pb2.opt_ext_string]": "extension field",
+  "[pb2.opt_ext_enum]": "TEN"
+}`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_OptExtBool, true)
+			setExtension(m, pb2.E_OptExtString, "extension field")
+			setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
+			setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+	}, {
+		desc:         "extensions of repeated fields",
+		inputMessage: &pb2.Extensions{},
+		inputText: `{
+  "[pb2.rpt_ext_enum]": ["TEN", 101, "ONE"],
+  "[pb2.rpt_ext_fixed32]": [42, 47],
+  "[pb2.rpt_ext_nested]": [
+    {"optString": "one"},
+	{"optString": "two"},
+	{"optString": "three"}
+  ]
+}`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+			setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+	}, {
+		desc:         "extensions of non-repeated fields in another message",
+		inputMessage: &pb2.Extensions{},
+		inputText: `{
+  "[pb2.ExtensionsContainer.opt_ext_bool]": true,
+  "[pb2.ExtensionsContainer.opt_ext_enum]": "TEN",
+  "[pb2.ExtensionsContainer.opt_ext_nested]": {
+    "optString": "nested in an extension",
+    "optNested": {
+      "optString": "another nested in an extension"
+    }
+  },
+  "[pb2.ExtensionsContainer.opt_ext_string]": "extension field"
+}`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+	}, {
+		desc:         "extensions of repeated fields in another message",
+		inputMessage: &pb2.Extensions{},
+		inputText: `{
+  "optString": "non-extension field",
+  "optBool": true,
+  "optInt32": 42,
+  "[pb2.ExtensionsContainer.rpt_ext_nested]": [
+    {"optString": "one"},
+    {"optString": "two"},
+    {"optString": "three"}
+  ],
+  "[pb2.ExtensionsContainer.rpt_ext_enum]": ["TEN", 101, "ONE"],
+  "[pb2.ExtensionsContainer.rpt_ext_string]": ["hello", "world"]
+}`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+	}, {
+		desc:         "invalid extension field name",
+		inputMessage: &pb2.Extensions{},
+		inputText:    `{ "[pb2.invalid_message_field]": true }`,
+		wantErr:      true,
+	}, {
+		desc:         "MessageSet",
+		inputMessage: &pb2.MessageSet{},
+		inputText: `{
+  "[pb2.MessageSetExtension]": {
+    "optString": "a messageset extension"
+  },
+  "[pb2.MessageSetExtension.ext_nested]": {
+    "optString": "just a regular extension"
+  },
+  "[pb2.MessageSetExtension.not_message_set_extension]": {
+    "optString": "not a messageset extension"
+  }
+}`,
+		wantMessage: func() proto.Message {
+			m := &pb2.MessageSet{}
+			setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
+				OptString: scalar.String("a messageset extension"),
+			})
+			setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
+				OptString: scalar.String("not a messageset extension"),
+			})
+			setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
+				OptString: scalar.String("just a regular extension"),
+			})
+			return m
+		}(),
+	}, {
+		desc:         "extension field set to null",
+		inputMessage: &pb2.Extensions{},
+		inputText: `{
+  "[pb2.ExtensionsContainer.opt_ext_bool]": null,
+  "[pb2.ExtensionsContainer.opt_ext_nested]": null
+}`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, nil)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, nil)
+			return m
+		}(),
+	}, {
+		desc:         "extensions of repeated field contains null",
+		inputMessage: &pb2.Extensions{},
+		inputText: `{
+  "[pb2.ExtensionsContainer.rpt_ext_nested]": [
+    {"optString": "one"},
+	null,
+    {"optString": "three"}
+  ]
+}`,
+		wantErr: true,
+	}, {
+		desc:         "not real MessageSet 1",
+		inputMessage: &pb2.FakeMessageSet{},
+		inputText: `{
+  "[pb2.FakeMessageSetExtension.message_set_extension]": {
+    "optString": "not a messageset extension"
+  }
+}`,
+		wantMessage: func() proto.Message {
+			m := &pb2.FakeMessageSet{}
+			setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
+				OptString: scalar.String("not a messageset extension"),
+			})
+			return m
+		}(),
+	}, {
+		desc:         "not real MessageSet 2",
+		inputMessage: &pb2.FakeMessageSet{},
+		inputText: `{
+  "[pb2.FakeMessageSetExtension]": {
+    "optString": "not a messageset extension"
+  }
+}`,
+		wantErr: true,
+	}, {
+		desc:         "not real MessageSet 3",
+		inputMessage: &pb2.MessageSet{},
+		inputText: `{
+  "[pb2.message_set_extension]": {
+    "optString": "another not a messageset extension"
+  }
+}`,
+		wantMessage: func() proto.Message {
+			m := &pb2.MessageSet{}
+			setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
+				OptString: scalar.String("another not a messageset extension"),
+			})
+			return m
+		}(),
 	}}
 
 	for _, tt := range tests {