internal/impl: validate messagesets
Change-Id: Id90bb386e7481bb9dee5a07889f308f1e1810825
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/218438
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go
index 77522de..837e5c4 100644
--- a/internal/encoding/messageset/messageset.go
+++ b/internal/encoding/messageset/messageset.go
@@ -99,7 +99,7 @@
b = b[n:]
continue
}
- typeID, value, n, err := consumeFieldValue(b, wantLen)
+ typeID, value, n, err := ConsumeFieldValue(b, wantLen)
if err != nil {
return err
}
@@ -114,13 +114,13 @@
return nil
}
-// consumeFieldValue parses b as a MessageSet item field value until and including
+// ConsumeFieldValue parses b as a MessageSet item field value until and including
// the trailing end group marker. It assumes the start group tag has already been parsed.
// It returns the contents of the type_id and message subfields and the total
// item length.
//
// If wantLen is true, the returned message value includes the length prefix.
-func consumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) {
+func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) {
ilen := len(b)
for {
num, wtyp, n := wire.ConsumeTag(b)
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index 06acc78..bb00cd0 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -11,6 +11,7 @@
"reflect"
"unicode/utf8"
+ "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/strs"
@@ -93,6 +94,7 @@
validationTypeFixed64
validationTypeBytes
validationTypeUTF8String
+ validationTypeMessageSetItem
)
func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
@@ -237,11 +239,6 @@
State:
for len(states) > 0 {
st := &states[len(states)-1]
- if st.mi != nil {
- if flags.ProtoLegacy && st.mi.isMessageSet {
- return out, ValidationUnknown
- }
- }
for len(b) > 0 {
// Parse the tag (field number and wire type).
var tag uint64
@@ -274,8 +271,8 @@
return out, ValidationInvalid
}
var vi validationInfo
- switch st.typ {
- case validationTypeMap:
+ switch {
+ case st.typ == validationTypeMap:
switch num {
case 1:
vi.typ = st.keyType
@@ -284,6 +281,11 @@
vi.mi = st.mi
vi.requiredBit = 1
}
+ case flags.ProtoLegacy && st.mi.isMessageSet:
+ switch num {
+ case messageset.FieldItem:
+ vi.typ = validationTypeMessageSetItem
+ }
default:
var f *coderFieldInfo
if int(num) < len(st.mi.denseCoderFields) {
@@ -483,8 +485,8 @@
}
b = b[8:]
case wire.StartGroupType:
- switch vi.typ {
- case validationTypeGroup:
+ switch {
+ case vi.typ == validationTypeGroup:
if vi.mi == nil {
return out, ValidationUnknown
}
@@ -495,6 +497,27 @@
endGroup: num,
})
continue State
+ case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
+ typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
+ if err != nil {
+ return out, ValidationInvalid
+ }
+ xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
+ switch {
+ case err == preg.NotFound:
+ b = b[n:]
+ case err != nil:
+ return out, ValidationUnknown
+ default:
+ xvi := getExtensionFieldInfo(xt).validation
+ states = append(states, validationState{
+ typ: xvi.typ,
+ mi: xvi.mi,
+ tail: b[n:],
+ })
+ b = v
+ continue State
+ }
default:
n := wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
diff --git a/proto/messageset_test.go b/proto/messageset_test.go
index 9e70e59..c901800 100644
--- a/proto/messageset_test.go
+++ b/proto/messageset_test.go
@@ -8,7 +8,6 @@
"google.golang.org/protobuf/internal/encoding/pack"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/flags"
- "google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
@@ -41,7 +40,6 @@
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet type_id after message content",
@@ -62,7 +60,6 @@
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet does not preserve unknown field",
@@ -82,7 +79,6 @@
// Unknown field
pack.Tag{4, pack.VarintType}, pack.Varint(30),
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with unknown type_id",
@@ -102,7 +98,6 @@
}),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet merges repeated message fields in item",
@@ -124,7 +119,6 @@
}),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet merges message fields in repeated items",
@@ -161,7 +155,6 @@
}),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with missing type_id",
@@ -175,7 +168,6 @@
}),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with missing message",
@@ -188,7 +180,6 @@
pack.Tag{2, pack.VarintType}, pack.Varint(1000),
pack.Tag{1, pack.EndGroupType},
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with type id out of valid field number range",
@@ -205,7 +196,6 @@
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with unknown type id out of valid field number range",
@@ -226,7 +216,6 @@
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with required field set",
@@ -248,7 +237,6 @@
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
{
desc: "MessageSet with required field unset",
@@ -267,6 +255,5 @@
pack.Tag{1, pack.EndGroupType},
}),
}.Marshal(),
- validationStatus: impl.ValidationUnknown,
},
}