internal/impl: fix tag decoding when field num doesn't fit in int32

Discoverd by OSS-Fuzz.

Change-Id: Ie2feefacee4ae632802fa920ac9694b525690eb2
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/216619
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 3cd7f5a..cbc21b3 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -82,12 +82,13 @@
 			}
 			b = b[n:]
 		}
-		num := wire.Number(tag >> 3)
-		wtyp := wire.Type(tag & 7)
-
-		if num < wire.MinValidNumber || num > wire.MaxValidNumber {
+		var num wire.Number
+		if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
 			return out, errors.New("invalid field number")
+		} else {
+			num = wire.Number(n)
 		}
+		wtyp := wire.Type(tag & 7)
 
 		if wtyp == wire.EndGroupType {
 			if num != groupTag {
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index 3a878e9..9994804 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -1786,6 +1786,16 @@
 		}.Marshal(),
 	},
 	{
+		desc: "invalid field number wraps int32",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Varint(2234993595104), pack.Varint(0),
+		}.Marshal(),
+	},
+	{
 		desc:     "invalid field number in map",
 		decodeTo: []proto.Message{(*testpb.TestAllTypes)(nil)},
 		wire: pack.Message{