internal/encoding/messageset: fix decoding of some invalid data

For historical reasons, MessageSets items are allowed to have field
numbers outside the usual valid range. Detect the case where the field
number cannot fit in an int32 and report an error. Also check for
a field number of 0 (always invalid).

Handle the case where a MessageSet item includes an unknown field.
We have no place to put the contents of the field, so drop it. This is,
I believe, consistent with other implementations.

Change-Id: Ic403427e1c276cbfa232ca577e7a799cce706bc7
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/221939
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go
index 55526cf..4bd0e4e 100644
--- a/internal/encoding/messageset/messageset.go
+++ b/internal/encoding/messageset/messageset.go
@@ -6,6 +6,8 @@
 package messageset
 
 import (
+	"math"
+
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/errors"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
@@ -146,6 +148,9 @@
 				return 0, nil, 0, wire.ParseError(n)
 			}
 			b = b[n:]
+			if v < 1 || v > math.MaxInt32 {
+				return 0, nil, 0, errors.New("invalid type_id in message set")
+			}
 			typeid = wire.Number(v)
 		case num == FieldMessage && wtyp == wire.BytesType:
 			m, n := wire.ConsumeBytes(b)
@@ -178,6 +183,13 @@
 				}
 			}
 			b = b[n:]
+		default:
+			// We have no place to put it, so we just ignore unknown fields.
+			n := wire.ConsumeFieldValue(num, wtyp, b)
+			if n < 0 {
+				return 0, nil, 0, wire.ParseError(n)
+			}
+			b = b[n:]
 		}
 	}
 }
diff --git a/proto/messageset_test.go b/proto/messageset_test.go
index ff41342..1eedcc9 100644
--- a/proto/messageset_test.go
+++ b/proto/messageset_test.go
@@ -17,6 +17,7 @@
 func init() {
 	if flags.ProtoLegacy {
 		testValidMessages = append(testValidMessages, messageSetTestProtos...)
+		testInvalidMessages = append(testInvalidMessages, messageSetInvalidTestProtos...)
 	}
 }
 
@@ -218,6 +219,27 @@
 		}.Marshal(),
 	},
 	{
+		desc: "MessageSet with unknown field",
+		decodeTo: []proto.Message{func() proto.Message {
+			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
+			proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+			})
+			return m
+		}()},
+		wire: pack.Message{
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(10),
+				}),
+				pack.Tag{4, pack.VarintType}, pack.Varint(0),
+				pack.Tag{1, pack.EndGroupType},
+			}),
+		}.Marshal(),
+	},
+	{
 		desc:          "MessageSet with required field set",
 		checkFastInit: true,
 		decodeTo: []proto.Message{func() proto.Message {
@@ -257,3 +279,34 @@
 		}.Marshal(),
 	},
 }
+
+var messageSetInvalidTestProtos = []testProto{
+	{
+		desc: "MessageSet with type id 0",
+		decodeTo: []proto.Message{
+			(*messagesetpb.MessageSetContainer)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{2, pack.VarintType}, pack.Uvarint(0),
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+				pack.Tag{1, pack.EndGroupType},
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet with type id overflowing int32",
+		decodeTo: []proto.Message{
+			(*messagesetpb.MessageSetContainer)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{2, pack.VarintType}, pack.Uvarint(0x80000000),
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+				pack.Tag{1, pack.EndGroupType},
+			}),
+		}.Marshal(),
+	},
+}