internal/impl: validate messagesets

Change-Id: Id90bb386e7481bb9dee5a07889f308f1e1810825
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/218438
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go
index 77522de..837e5c4 100644
--- a/internal/encoding/messageset/messageset.go
+++ b/internal/encoding/messageset/messageset.go
@@ -99,7 +99,7 @@
 			b = b[n:]
 			continue
 		}
-		typeID, value, n, err := consumeFieldValue(b, wantLen)
+		typeID, value, n, err := ConsumeFieldValue(b, wantLen)
 		if err != nil {
 			return err
 		}
@@ -114,13 +114,13 @@
 	return nil
 }
 
-// consumeFieldValue parses b as a MessageSet item field value until and including
+// ConsumeFieldValue parses b as a MessageSet item field value until and including
 // the trailing end group marker. It assumes the start group tag has already been parsed.
 // It returns the contents of the type_id and message subfields and the total
 // item length.
 //
 // If wantLen is true, the returned message value includes the length prefix.
-func consumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) {
+func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) {
 	ilen := len(b)
 	for {
 		num, wtyp, n := wire.ConsumeTag(b)
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index 06acc78..bb00cd0 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -11,6 +11,7 @@
 	"reflect"
 	"unicode/utf8"
 
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/internal/strs"
@@ -93,6 +94,7 @@
 	validationTypeFixed64
 	validationTypeBytes
 	validationTypeUTF8String
+	validationTypeMessageSetItem
 )
 
 func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
@@ -237,11 +239,6 @@
 State:
 	for len(states) > 0 {
 		st := &states[len(states)-1]
-		if st.mi != nil {
-			if flags.ProtoLegacy && st.mi.isMessageSet {
-				return out, ValidationUnknown
-			}
-		}
 		for len(b) > 0 {
 			// Parse the tag (field number and wire type).
 			var tag uint64
@@ -274,8 +271,8 @@
 				return out, ValidationInvalid
 			}
 			var vi validationInfo
-			switch st.typ {
-			case validationTypeMap:
+			switch {
+			case st.typ == validationTypeMap:
 				switch num {
 				case 1:
 					vi.typ = st.keyType
@@ -284,6 +281,11 @@
 					vi.mi = st.mi
 					vi.requiredBit = 1
 				}
+			case flags.ProtoLegacy && st.mi.isMessageSet:
+				switch num {
+				case messageset.FieldItem:
+					vi.typ = validationTypeMessageSetItem
+				}
 			default:
 				var f *coderFieldInfo
 				if int(num) < len(st.mi.denseCoderFields) {
@@ -483,8 +485,8 @@
 				}
 				b = b[8:]
 			case wire.StartGroupType:
-				switch vi.typ {
-				case validationTypeGroup:
+				switch {
+				case vi.typ == validationTypeGroup:
 					if vi.mi == nil {
 						return out, ValidationUnknown
 					}
@@ -495,6 +497,27 @@
 						endGroup: num,
 					})
 					continue State
+				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
+					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
+					if err != nil {
+						return out, ValidationInvalid
+					}
+					xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
+					switch {
+					case err == preg.NotFound:
+						b = b[n:]
+					case err != nil:
+						return out, ValidationUnknown
+					default:
+						xvi := getExtensionFieldInfo(xt).validation
+						states = append(states, validationState{
+							typ:  xvi.typ,
+							mi:   xvi.mi,
+							tail: b[n:],
+						})
+						b = v
+						continue State
+					}
 				default:
 					n := wire.ConsumeFieldValue(num, wtyp, b)
 					if n < 0 {
diff --git a/proto/messageset_test.go b/proto/messageset_test.go
index 9e70e59..c901800 100644
--- a/proto/messageset_test.go
+++ b/proto/messageset_test.go
@@ -8,7 +8,6 @@
 	"google.golang.org/protobuf/internal/encoding/pack"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/flags"
-	"google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/proto"
 
 	messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
@@ -41,7 +40,6 @@
 				pack.Tag{1, pack.EndGroupType},
 			}),
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "MessageSet type_id after message content",
@@ -62,7 +60,6 @@
 				pack.Tag{1, pack.EndGroupType},
 			}),
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "MessageSet does not preserve unknown field",
@@ -82,7 +79,6 @@
 			// Unknown field
 			pack.Tag{4, pack.VarintType}, pack.Varint(30),
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "MessageSet with unknown type_id",
@@ -102,7 +98,6 @@
 			}),
 			pack.Tag{1, pack.EndGroupType},
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "MessageSet merges repeated message fields in item",
@@ -124,7 +119,6 @@
 			}),
 			pack.Tag{1, pack.EndGroupType},
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "MessageSet merges message fields in repeated items",
@@ -161,7 +155,6 @@
 			}),
 			pack.Tag{1, pack.EndGroupType},
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "MessageSet with missing type_id",
@@ -175,7 +168,6 @@
 			}),
 			pack.Tag{1, pack.EndGroupType},
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "MessageSet with missing message",
@@ -188,7 +180,6 @@
 			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
 			pack.Tag{1, pack.EndGroupType},
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "MessageSet with type id out of valid field number range",
@@ -205,7 +196,6 @@
 				pack.Tag{1, pack.EndGroupType},
 			}),
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "MessageSet with unknown type id out of valid field number range",
@@ -226,7 +216,6 @@
 				pack.Tag{1, pack.EndGroupType},
 			}),
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc:          "MessageSet with required field set",
@@ -248,7 +237,6 @@
 				pack.Tag{1, pack.EndGroupType},
 			}),
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc:          "MessageSet with required field unset",
@@ -267,6 +255,5 @@
 				pack.Tag{1, pack.EndGroupType},
 			}),
 		}.Marshal(),
-		validationStatus: impl.ValidationUnknown,
 	},
 }