internal/impl: refactor validation a bit
Return the size of the field read from the validator, permitting us to
avoid an extra parse when skipping over groups.
Return an UnmarshalOutput from the validator, since it already combines
two of the validator outputs: bytes read and initialization status.
Remove initialization status from the ValidationStatus enum, since it's
covered by the UnmarshalOutput.
Change-Id: I3e684c45d15aa1992d8dc3bde0f608880d34a94b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217763
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/benchmarks/micro/micro_test.go b/internal/benchmarks/micro/micro_test.go
index 097a326..b36dc78 100644
--- a/internal/benchmarks/micro/micro_test.go
+++ b/internal/benchmarks/micro/micro_test.go
@@ -55,7 +55,9 @@
Resolver: protoregistry.GlobalTypes,
}
for pb.Next() {
- if got, want := impl.Validate([]byte{}, mt, opts), impl.ValidationValidInitialized; got != want {
+ _, got := impl.Validate([]byte{}, mt, opts)
+ want := impl.ValidationValid
+ if got != want {
b.Fatalf("Validate = %v, want %v", got, want)
}
}
@@ -106,7 +108,9 @@
Resolver: protoregistry.GlobalTypes,
}
for pb.Next() {
- if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
+ _, got := impl.Validate(w, mt, opts)
+ want := impl.ValidationValid
+ if got != want {
b.Fatalf("Validate = %v, want %v", got, want)
}
}
@@ -167,7 +171,9 @@
Resolver: protoregistry.GlobalTypes,
}
for pb.Next() {
- if got, want := impl.Validate(w, mt, opts), impl.ValidationValidInitialized; got != want {
+ _, got := impl.Validate(w, mt, opts)
+ want := impl.ValidationValid
+ if got != want {
b.Fatalf("Validate = %v, want %v", got, want)
}
}
diff --git a/internal/fuzz/wirefuzz/fuzz.go b/internal/fuzz/wirefuzz/fuzz.go
index 7ca46ba..28aed57 100644
--- a/internal/fuzz/wirefuzz/fuzz.go
+++ b/internal/fuzz/wirefuzz/fuzz.go
@@ -19,7 +19,7 @@
// Fuzz is a fuzzer for proto.Marshal and proto.Unmarshal.
func Fuzz(data []byte) (score int) {
m1 := &fuzzpb.Fuzz{}
- valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
+ vout, valid := impl.Validate(data, m1.ProtoReflect().Type(), piface.UnmarshalOptions{
Resolver: protoregistry.GlobalTypes,
})
if err := (proto.UnmarshalOptions{
@@ -33,21 +33,14 @@
}
return 0
}
- if proto.IsInitialized(m1) == nil {
- switch valid {
- case impl.ValidationUnknown:
- case impl.ValidationValidInitialized:
- case impl.ValidationValidMaybeUninitalized:
- default:
- panic("unmarshal ok with validation status: " + valid.String())
- }
- } else {
- switch valid {
- case impl.ValidationUnknown:
- case impl.ValidationValidMaybeUninitalized:
- default:
- panic("partial unmarshal ok with validation status: " + valid.String())
- }
+ switch valid {
+ case impl.ValidationUnknown:
+ case impl.ValidationValid:
+ default:
+ panic("unmarshal ok with validation status: " + valid.String())
+ }
+ if proto.IsInitialized(m1) != nil && vout.Initialized {
+ panic("validation reports partial message is initialized")
}
data1, err := proto.MarshalOptions{
AllowPartial: true,
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 290fc41..3155bc5 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -196,11 +196,9 @@
}
if flags.LazyUnmarshalExtensions {
if opts.IsDefault() && x.canLazy(xt) {
- if n, ok := skipExtension(b, xi, num, wtyp, opts); ok {
- x.appendLazyBytes(xt, xi, num, wtyp, b[:n])
+ if out, ok := skipExtension(b, xi, num, wtyp, opts); ok && out.initialized {
+ x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
exts[int32(num)] = x
- out.n = n
- out.initialized = true
return out, nil
}
}
@@ -224,35 +222,36 @@
return out, nil
}
-func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (n int, ok bool) {
+// skipExtension runs the validator over the extension field value at the
+// start of b instead of unmarshaling it. It reports ok only when the
+// validator returns ValidationValid; out.n is then the number of bytes of b
+// occupied by the value and out.initialized is the validator's verdict.
+func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, ok bool) {
if xi.validation.mi == nil {
- return 0, false
+ return out, false
}
xi.validation.mi.init()
- var v []byte
switch xi.validation.typ {
case validationTypeMessage:
if wtyp != wire.BytesType {
- return 0, false
+ return out, false
}
- v, n = wire.ConsumeBytes(b)
+ v, n := wire.ConsumeBytes(b)
if n < 0 {
- return 0, false
+ return out, false
}
+ out, st := xi.validation.mi.validate(v, 0, opts)
+ out.n = n
+ return out, st == ValidationValid
case validationTypeGroup:
if wtyp != wire.StartGroupType {
- return 0, false
+ return out, false
}
- v, n = wire.ConsumeGroup(num, b)
- if n < 0 {
- return 0, false
- }
+ // Validate b itself: the validator stops at the matching end-group tag
+ // and reports the bytes it consumed in out.n, so no ConsumeGroup pre-parse.
+ out, st := xi.validation.mi.validate(b, num, opts)
+ return out, st == ValidationValid
default:
- return 0, false
+ return out, false
}
- if xi.validation.mi.validate(v, 0, opts) != ValidationValidInitialized {
- return 0, false
- }
- return n, true
-
}
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index eab8ec0..0c32026 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -33,16 +33,8 @@
// ValidationInvalid indicates that unmarshaling the message will fail.
ValidationInvalid
- // ValidationValidInitialized indicates that unmarshaling the message will succeed
- // and IsInitialized on the result will report success.
- ValidationValidInitialized
-
- // ValidationValidMaybeUninitalized indicates unmarshaling the message will succeed,
- // but the output of IsInitialized on the result is unknown.
- //
- // This status may be returned for an initialized message when a message value
- // is split across multiple fields.
- ValidationValidMaybeUninitalized
+ // ValidationValid indicates that unmarshaling the message will succeed.
+ ValidationValid
)
func (v ValidationStatus) String() string {
@@ -51,10 +43,8 @@
return "ValidationUnknown"
case ValidationInvalid:
return "ValidationInvalid"
- case ValidationValidInitialized:
- return "ValidationValidInitialized"
- case ValidationValidMaybeUninitalized:
- return "ValidationValidMaybeUninitalized"
+ case ValidationValid:
+ return "ValidationValid"
default:
return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
@@ -64,12 +54,14 @@
// of the message type.
//
// This function is exposed for testing.
-func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
+func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) (out piface.UnmarshalOutput, _ ValidationStatus) {
mi, ok := mt.(*MessageInfo)
if !ok {
- return ValidationUnknown
+ return out, ValidationUnknown
}
- return mi.validate(b, 0, unmarshalOptions(opts))
+ o, st := mi.validate(b, 0, unmarshalOptions(opts))
+ out.Initialized = o.initialized
+ return out, st
}
type validationInfo struct {
@@ -219,7 +211,7 @@
return vi
}
-func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
+func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
mi.init()
type validationState struct {
typ validationType
@@ -241,12 +233,13 @@
states[0].endGroup = groupTag
}
initialized := true
+ start := len(b)
State:
for len(states) > 0 {
st := &states[len(states)-1]
if st.mi != nil {
if flags.ProtoLegacy && st.mi.isMessageSet {
- return ValidationUnknown
+ return out, ValidationUnknown
}
}
for len(b) > 0 {
@@ -262,13 +255,13 @@
var n int
tag, n = wire.ConsumeVarint(b)
if n < 0 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
b = b[n:]
}
var num wire.Number
if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
- return ValidationInvalid
+ return out, ValidationInvalid
} else {
num = wire.Number(n)
}
@@ -278,7 +271,7 @@
if st.endGroup == num {
goto PopState
}
- return ValidationInvalid
+ return out, ValidationInvalid
}
var vi validationInfo
switch st.typ {
@@ -317,7 +310,7 @@
case preg.NotFound:
vi.typ = validationTypeBytes
default:
- return ValidationUnknown
+ return out, ValidationUnknown
}
}
break
@@ -332,7 +325,7 @@
// determine if the resolver is frozen.
xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
if err != nil && err != preg.NotFound {
- return ValidationUnknown
+ return out, ValidationUnknown
}
if err == nil {
vi = getExtensionFieldInfo(xt).validation
@@ -383,7 +376,7 @@
case b[9] < 0x80 && b[9] < 2:
b = b[10:]
default:
- return ValidationInvalid
+ return out, ValidationInvalid
}
} else {
switch {
@@ -408,7 +401,7 @@
case len(b) > 9 && b[9] < 2:
b = b[10:]
default:
- return ValidationInvalid
+ return out, ValidationInvalid
}
}
continue State
@@ -424,19 +417,19 @@
var n int
size, n = wire.ConsumeVarint(b)
if n < 0 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
b = b[n:]
}
if size > uint64(len(b)) {
- return ValidationInvalid
+ return out, ValidationInvalid
}
v := b[:size]
b = b[size:]
switch vi.typ {
case validationTypeMessage:
if vi.mi == nil {
- return ValidationUnknown
+ return out, ValidationUnknown
}
vi.mi.init()
fallthrough
@@ -455,40 +448,40 @@
for len(v) > 0 {
_, n := wire.ConsumeVarint(v)
if n < 0 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
v = v[n:]
}
case validationTypeRepeatedFixed32:
// Packed field.
if len(v)%4 != 0 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
case validationTypeRepeatedFixed64:
// Packed field.
if len(v)%8 != 0 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
case validationTypeUTF8String:
if !utf8.Valid(v) {
- return ValidationInvalid
+ return out, ValidationInvalid
}
}
case wire.Fixed32Type:
if len(b) < 4 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
b = b[4:]
case wire.Fixed64Type:
if len(b) < 8 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
b = b[8:]
case wire.StartGroupType:
switch vi.typ {
case validationTypeGroup:
if vi.mi == nil {
- return ValidationUnknown
+ return out, ValidationUnknown
}
vi.mi.init()
states = append(states, validationState{
@@ -500,19 +493,19 @@
default:
n := wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
b = b[n:]
}
default:
- return ValidationInvalid
+ return out, ValidationInvalid
}
}
if st.endGroup != 0 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
if len(b) != 0 {
- return ValidationInvalid
+ return out, ValidationInvalid
}
b = st.tail
PopState:
@@ -535,8 +528,9 @@
}
states = states[:len(states)-1]
}
- if !initialized {
- return ValidationValidMaybeUninitalized
+ out.n = start - len(b)
+ if initialized {
+ out.initialized = true
}
- return ValidationValidInitialized
+ return out, ValidationValid
}
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index d980fd1..3847c21 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -28,6 +28,7 @@
checkFastInit bool
unmarshalOptions proto.UnmarshalOptions
validationStatus impl.ValidationStatus
+ nocheckValidInit bool
}
func makeMessages(in protobuild.Message, messages ...proto.Message) []proto.Message {
@@ -1045,8 +1046,9 @@
}.Marshal(),
},
{
- desc: "required field in optional message set (split across multiple tags)",
- checkFastInit: false, // fast init checks don't handle split messages
+ desc: "required field in optional message set (split across multiple tags)",
+ checkFastInit: false, // fast init checks don't handle split messages
+ nocheckValidInit: true, // validation doesn't either
decodeTo: makeMessages(protobuild.Message{
"optional_message": protobuild.Message{
"required_field": 1,
@@ -1058,7 +1060,6 @@
pack.Tag{1, pack.VarintType}, pack.Varint(1),
}),
}.Marshal(),
- validationStatus: impl.ValidationValidMaybeUninitalized,
},
{
desc: "required field in repeated message unset",
diff --git a/proto/validate_test.go b/proto/validate_test.go
index bd4b811..490115a 100644
--- a/proto/validate_test.go
+++ b/proto/validate_test.go
@@ -23,16 +23,18 @@
for _, m := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
mt := m.ProtoReflect().Type()
- want := impl.ValidationValidInitialized
+ want := impl.ValidationValid
if test.validationStatus != 0 {
want = test.validationStatus
- } else if test.partial {
- want = impl.ValidationValidMaybeUninitalized
}
var opts piface.UnmarshalOptions
opts.Resolver = protoregistry.GlobalTypes
- if got, want := impl.Validate(test.wire, mt, opts), want; got != want {
- t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
+ out, status := impl.Validate(test.wire, mt, opts)
+ if status != want {
+ t.Errorf("Validate(%x) = %v, want %v", test.wire, status, want)
+ }
+ if got, want := out.Initialized, !test.partial; got != want && !test.nocheckValidInit && status == impl.ValidationValid {
+ t.Errorf("Validate(%x): initialized = %v, want %v", test.wire, got, want)
}
})
}
@@ -46,7 +48,9 @@
mt := m.ProtoReflect().Type()
var opts piface.UnmarshalOptions
opts.Resolver = protoregistry.GlobalTypes
- if got, want := impl.Validate(test.wire, mt, opts), impl.ValidationInvalid; got != want {
+ _, got := impl.Validate(test.wire, mt, opts)
+ want := impl.ValidationInvalid
+ if got != want {
t.Errorf("Validate(%x) = %v, want %v", test.wire, got, want)
}
})