proto: eagerly unmarshal extensions

CL/172399 switches the v1 code to eagerly unmarshal extensions.
This CL does the equivalent for v2.

For the test, we simply switch from protoV1.Equal to protoV2.Equal,
since the v2 equal does not magically unmarshal raw extensions.

Change-Id: I6f64455b0a75bbc9a9a82108558641a29bd2b982
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175838
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/proto/decode.go b/proto/decode.go
index 5a867a2..fa6d443 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -9,6 +9,7 @@
 	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 	"github.com/golang/protobuf/v2/runtime/protoiface"
 )
 
@@ -25,6 +26,10 @@
 	// If DiscardUnknown is set, unknown fields are ignored.
 	DiscardUnknown bool
 
+	// Resolver is used for looking up types when unmarshaling extension fields.
+	// If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver *protoregistry.Types
+
 	pragma.NoUnkeyedLiterals
 }
 
@@ -37,6 +42,10 @@
 
 // Unmarshal parses the wire-format message in b and places the result in m.
 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
+	if o.Resolver == nil {
+		o.Resolver = protoregistry.GlobalTypes
+	}
+
 	// TODO: Reset m?
 	err := o.unmarshalMessageFast(b, m)
 	if err == errInternalNoFast {
@@ -77,6 +86,16 @@
 		fieldType := fieldTypes.ByNumber(num)
 		if fieldType == nil {
 			fieldType = knownFields.ExtensionTypes().ByNumber(num)
+			if fieldType == nil && messageType.ExtensionRanges().Has(num) {
+				extType, err := o.Resolver.FindExtensionByNumber(messageType.FullName(), num)
+				if err != nil && err != protoregistry.NotFound {
+					return err
+				}
+				if extType != nil {
+					knownFields.ExtensionTypes().Register(extType)
+					fieldType = extType
+				}
+			}
 		}
 		var err error
 		var valLen int
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 084014f..4eb2598 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -54,7 +54,7 @@
 					// Equal doesn't work on messages containing invalid extension data.
 					return
 				}
-				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
+				if !proto.Equal(got, want) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
 			})
diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go
index 42832de..fe17ca7 100644
--- a/runtime/protoiface/methods.go
+++ b/runtime/protoiface/methods.go
@@ -7,6 +7,7 @@
 import (
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 )
 
 // Methoder is an optional interface implemented by generated messages to
@@ -62,6 +63,7 @@
 type UnmarshalOptions struct {
 	AllowPartial   bool
 	DiscardUnknown bool
+	Resolver       *protoregistry.Types
 
 	pragma.NoUnkeyedLiterals
 }