internal/impl: check for required fields in missing map value
If a map value is a message with required fields, the validator should
note that it is uninitialized if a map item contains no value. In this
case, the value is an empty message which obviously does not have the
required field set.
Change-Id: I7698e60765e3c95478f293e121bba3ad7fc88e27
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/213900
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index e6b3d23..5f3c3f5 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -266,6 +266,7 @@
case 2:
vi.typ = st.valType
vi.mi = st.mi
+ vi.requiredIndex = 1
}
default:
var f *coderFieldInfo
@@ -436,15 +437,23 @@
}
b = st.tail
PopState:
+ numRequiredFields := 0
switch st.typ {
case validationTypeMessage, validationTypeGroup:
- // If there are more than 64 required fields, this check will
- // always fail and we will report that the message is potentially
- // uninitialized.
- if st.mi.numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != int(st.mi.numRequiredFields) {
- initialized = false
+ numRequiredFields = int(st.mi.numRequiredFields)
+ case validationTypeMap:
+ // If this is a map field with a message value that contains
+ // required fields, require that the value be present.
+ if st.mi != nil && st.mi.numRequiredFields > 0 {
+ numRequiredFields = 1
}
}
+ // If there are more than 64 required fields, this check will
+ // always fail and we will report that the message is potentially
+ // uninitialized.
+ if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
+ initialized = false
+ }
states = states[:len(states)-1]
}
if !initialized {
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index 8c62dbb..de85421 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -1271,6 +1271,20 @@
}.Marshal(),
},
{
+ desc: "required field in absent map message value",
+ partial: true,
+ decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+ MapMessage: map[int32]*testpb.TestRequired{
+ 2: {},
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(2),
+ }),
+ }.Marshal(),
+ },
+ {
desc: "required field in map message set",
decodeTo: []proto.Message{&testpb.TestRequiredForeign{
MapMessage: map[int32]*testpb.TestRequired{