all: don't allow invalid field numbers when legacy support is on

The deprecated messageset format permits extension fields with numbers
greater than the usual maximum (1<<29-1). To support this, the
internal/encoding/wire package has disabled field number validation when
legacy support is enabled.

We shouldn't skip field number validation entirely just because we
support larger numbers in messagesets.

This change drops range validation from the wire package (other than
checking that numbers fit in an int32) and adds it to the wire
unmarshalers instead. This gives us validation where we care
about it (when unmarshaling a wire-format message) and allows for
best-effort handling of out-of-range numbers everywhere else.

Fixes golang/protobuf#996

Change-Id: I4e11b8a8aa177dd60e89723570af074a317c2451
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/210290
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/encoding/wire/wire.go b/internal/encoding/wire/wire.go
index 6c4fd5a..32622b1 100644
--- a/internal/encoding/wire/wire.go
+++ b/internal/encoding/wire/wire.go
@@ -13,7 +13,6 @@
 	"math/bits"
 
 	"google.golang.org/protobuf/internal/errors"
-	"google.golang.org/protobuf/internal/flags"
 )
 
 // Number represents the field number.
@@ -490,18 +489,12 @@
 }
 
 // DecodeTag decodes the field Number and wire Type from its unified form.
-// The Number is -1 if the decoded field number overflows.
+// The Number is -1 if the decoded field number overflows int32.
 // Other than overflow, this does not check for field number validity.
 func DecodeTag(x uint64) (Number, Type) {
 	// NOTE: MessageSet allows for larger field numbers than normal.
-	if flags.ProtoLegacy {
-		if x>>3 > uint64(math.MaxInt32) {
-			return -1, 0
-		}
-	} else {
-		if x>>3 > uint64(MaxValidNumber) {
-			return -1, 0
-		}
+	if x>>3 > uint64(math.MaxInt32) {
+		return -1, 0
 	}
 	return Number(x >> 3), Type(x & 7)
 }
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 85cc6b0..9ccfa7f 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -85,6 +85,9 @@
 		if n < 0 {
 			return 0, wire.ParseError(n)
 		}
+		if num > wire.MaxValidNumber {
+			return 0, errors.New("invalid field number")
+		}
 		b = b[n:]
 
 		var f *coderFieldInfo
diff --git a/internal/testprotos/messageset/msetextpb/msetextpb.pb.go b/internal/testprotos/messageset/msetextpb/msetextpb.pb.go
index 9a49c69..411e24c 100644
--- a/internal/testprotos/messageset/msetextpb/msetextpb.pb.go
+++ b/internal/testprotos/messageset/msetextpb/msetextpb.pb.go
@@ -117,6 +117,44 @@
 	return 0
 }
 
+type ExtLargeNumber struct {
+	state         protoimpl.MessageState
+	sizeCache     protoimpl.SizeCache
+	unknownFields protoimpl.UnknownFields
+}
+
+func (x *ExtLargeNumber) Reset() {
+	*x = ExtLargeNumber{}
+	if protoimpl.UnsafeEnabled {
+		mi := &file_messageset_msetextpb_msetextpb_proto_msgTypes[2]
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		ms.StoreMessageInfo(mi)
+	}
+}
+
+func (x *ExtLargeNumber) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*ExtLargeNumber) ProtoMessage() {}
+
+func (x *ExtLargeNumber) ProtoReflect() protoreflect.Message {
+	mi := &file_messageset_msetextpb_msetextpb_proto_msgTypes[2]
+	if protoimpl.UnsafeEnabled && x != nil {
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		if ms.LoadMessageInfo() == nil {
+			ms.StoreMessageInfo(mi)
+		}
+		return ms
+	}
+	return mi.MessageOf(x)
+}
+
+// Deprecated: Use ExtLargeNumber.ProtoReflect.Descriptor instead.
+func (*ExtLargeNumber) Descriptor() ([]byte, []int) {
+	return file_messageset_msetextpb_msetextpb_proto_rawDescGZIP(), []int{2}
+}
+
 var file_messageset_msetextpb_msetextpb_proto_extTypes = []protoimpl.ExtensionInfo{
 	{
 		ExtendedType:  (*messagesetpb.MessageSet)(nil),
@@ -134,6 +172,14 @@
 		Tag:           "bytes,1001,opt,name=message_set_extension",
 		Filename:      "messageset/msetextpb/msetextpb.proto",
 	},
+	{
+		ExtendedType:  (*messagesetpb.MessageSet)(nil),
+		ExtensionType: (*ExtLargeNumber)(nil),
+		Field:         536870912,
+		Name:          "goproto.proto.messageset.ExtLargeNumber",
+		Tag:           "bytes,536870912,opt,name=message_set_extension",
+		Filename:      "messageset/msetextpb/msetextpb.proto",
+	},
 }
 
 // Extension fields to messagesetpb.MessageSet.
@@ -142,6 +188,8 @@
 	E_Ext1_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extTypes[0]
 	// optional goproto.proto.messageset.Ext2 message_set_extension = 1001;
 	E_Ext2_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extTypes[1]
+	// optional goproto.proto.messageset.ExtLargeNumber message_set_extension = 536870912;
+	E_ExtLargeNumber_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extTypes[2] // 1<<29
 )
 
 var File_messageset_msetextpb_msetextpb_proto protoreflect.FileDescriptor
@@ -176,11 +224,21 @@
 	0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
 	0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x45, 0x78, 0x74,
 	0x32, 0x52, 0x13, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x45, 0x78, 0x74,
-	0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x45, 0x5a, 0x43, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
-	0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74,
-	0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65,
-	0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
-	0x73, 0x65, 0x74, 0x2f, 0x6d, 0x73, 0x65, 0x74, 0x65, 0x78, 0x74, 0x70, 0x62,
+	0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x99, 0x01, 0x0a, 0x0e, 0x45, 0x78, 0x74, 0x4c, 0x61,
+	0x72, 0x67, 0x65, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x32, 0x86, 0x01, 0x0a, 0x15, 0x6d, 0x65,
+	0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x74, 0x5f, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73,
+	0x69, 0x6f, 0x6e, 0x12, 0x24, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72,
+	0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x4d,
+	0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x18, 0x80, 0x80, 0x80, 0x80, 0x02, 0x20,
+	0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72,
+	0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x45,
+	0x78, 0x74, 0x4c, 0x61, 0x72, 0x67, 0x65, 0x4e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x52, 0x13, 0x6d,
+	0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69,
+	0x6f, 0x6e, 0x42, 0x45, 0x5a, 0x43, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c,
+	0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
+	0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72,
+	0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f,
+	0x6d, 0x73, 0x65, 0x74, 0x65, 0x78, 0x74, 0x70, 0x62,
 }
 
 var (
@@ -195,21 +253,24 @@
 	return file_messageset_msetextpb_msetextpb_proto_rawDescData
 }
 
-var file_messageset_msetextpb_msetextpb_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
+var file_messageset_msetextpb_msetextpb_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
 var file_messageset_msetextpb_msetextpb_proto_goTypes = []interface{}{
 	(*Ext1)(nil),                    // 0: goproto.proto.messageset.Ext1
 	(*Ext2)(nil),                    // 1: goproto.proto.messageset.Ext2
-	(*messagesetpb.MessageSet)(nil), // 2: goproto.proto.messageset.MessageSet
+	(*ExtLargeNumber)(nil),          // 2: goproto.proto.messageset.ExtLargeNumber
+	(*messagesetpb.MessageSet)(nil), // 3: goproto.proto.messageset.MessageSet
 }
 var file_messageset_msetextpb_msetextpb_proto_depIdxs = []int32{
-	2, // 0: goproto.proto.messageset.Ext1.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
-	2, // 1: goproto.proto.messageset.Ext2.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
-	0, // 2: goproto.proto.messageset.Ext1.message_set_extension:type_name -> goproto.proto.messageset.Ext1
-	1, // 3: goproto.proto.messageset.Ext2.message_set_extension:type_name -> goproto.proto.messageset.Ext2
-	4, // [4:4] is the sub-list for method output_type
-	4, // [4:4] is the sub-list for method input_type
-	2, // [2:4] is the sub-list for extension type_name
-	0, // [0:2] is the sub-list for extension extendee
+	3, // 0: goproto.proto.messageset.Ext1.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
+	3, // 1: goproto.proto.messageset.Ext2.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
+	3, // 2: goproto.proto.messageset.ExtLargeNumber.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
+	0, // 3: goproto.proto.messageset.Ext1.message_set_extension:type_name -> goproto.proto.messageset.Ext1
+	1, // 4: goproto.proto.messageset.Ext2.message_set_extension:type_name -> goproto.proto.messageset.Ext2
+	2, // 5: goproto.proto.messageset.ExtLargeNumber.message_set_extension:type_name -> goproto.proto.messageset.ExtLargeNumber
+	6, // [6:6] is the sub-list for method output_type
+	6, // [6:6] is the sub-list for method input_type
+	3, // [3:6] is the sub-list for extension type_name
+	0, // [0:3] is the sub-list for extension extendee
 	0, // [0:0] is the sub-list for field type_name
 }
 
@@ -243,6 +304,18 @@
 				return nil
 			}
 		}
+		file_messageset_msetextpb_msetextpb_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+			switch v := v.(*ExtLargeNumber); i {
+			case 0:
+				return &v.state
+			case 1:
+				return &v.sizeCache
+			case 2:
+				return &v.unknownFields
+			default:
+				return nil
+			}
+		}
 	}
 	type x struct{}
 	out := protoimpl.TypeBuilder{
@@ -250,8 +323,8 @@
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
 			RawDescriptor: file_messageset_msetextpb_msetextpb_proto_rawDesc,
 			NumEnums:      0,
-			NumMessages:   2,
-			NumExtensions: 2,
+			NumMessages:   3,
+			NumExtensions: 3,
 			NumServices:   0,
 		},
 		GoTypes:           file_messageset_msetextpb_msetextpb_proto_goTypes,
diff --git a/internal/testprotos/messageset/msetextpb/msetextpb.proto b/internal/testprotos/messageset/msetextpb/msetextpb.proto
index 5d1bf08..b5b50d8 100644
--- a/internal/testprotos/messageset/msetextpb/msetextpb.proto
+++ b/internal/testprotos/messageset/msetextpb/msetextpb.proto
@@ -24,3 +24,9 @@
   }
   optional int32 ext2_field1 = 1;
 }
+
+message ExtLargeNumber {
+  extend MessageSet {
+    optional ExtLargeNumber message_set_extension = 536870912; // 1<<29
+  }
+}
diff --git a/proto/decode.go b/proto/decode.go
index f64f887..07ae467 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -88,6 +88,9 @@
 		if tagLen < 0 {
 			return wire.ParseError(tagLen)
 		}
+		if num > wire.MaxValidNumber {
+			return errors.New("invalid field number")
+		}
 
 		// Find the field descriptor for this field number.
 		fd := fields.ByNumber(num)
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 2c50dde..13db09d 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -1762,14 +1762,12 @@
 			pack.Tag{pack.MaxValidNumber, pack.VarintType}, pack.Varint(1006),
 			pack.Tag{pack.MaxValidNumber + 1, pack.VarintType}, pack.Varint(1007),
 		}.Marshal(),
-		allowed: flags.ProtoLegacy,
 	},
 	{
 		desc: "max+1",
 		wire: pack.Message{
 			pack.Tag{pack.MaxValidNumber + 1, pack.VarintType}, pack.Varint(1008),
 		}.Marshal(),
-		allowed: flags.ProtoLegacy,
 	},
 }
 
diff --git a/proto/messageset_test.go b/proto/messageset_test.go
index b7c4c72..3ef47a4 100644
--- a/proto/messageset_test.go
+++ b/proto/messageset_test.go
@@ -6,6 +6,7 @@
 
 import (
 	"google.golang.org/protobuf/internal/encoding/pack"
+	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/proto"
 
@@ -180,4 +181,40 @@
 			pack.Tag{1, pack.EndGroupType},
 		}.Marshal(),
 	},
+	{
+		desc: "MessageSet with type id out of valid field number range",
+		decodeTo: []proto.Message{func() proto.Message {
+			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
+			proto.SetExtension(m.MessageSet, msetextpb.E_ExtLargeNumber_MessageSetExtension, &msetextpb.ExtLargeNumber{})
+			return m
+		}()},
+		wire: pack.Message{
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{2, pack.VarintType}, pack.Varint(wire.MaxValidNumber + 1),
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+				pack.Tag{1, pack.EndGroupType},
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet with unknown type id out of valid field number range",
+		decodeTo: []proto.Message{func() proto.Message {
+			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
+			m.MessageSet.ProtoReflect().SetUnknown(
+				pack.Message{
+					pack.Tag{wire.MaxValidNumber + 2, pack.BytesType}, pack.LengthPrefix{},
+				}.Marshal(),
+			)
+			return m
+		}()},
+		wire: pack.Message{
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{2, pack.VarintType}, pack.Varint(wire.MaxValidNumber + 2),
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+				pack.Tag{1, pack.EndGroupType},
+			}),
+		}.Marshal(),
+	},
 }