|  | // Copyright 2019 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | package impl | 
|  |  | 
|  | import ( | 
|  | "fmt" | 
|  | "math" | 
|  | "math/bits" | 
|  | "reflect" | 
|  | "unicode/utf8" | 
|  |  | 
|  | "google.golang.org/protobuf/encoding/protowire" | 
|  | "google.golang.org/protobuf/internal/encoding/messageset" | 
|  | "google.golang.org/protobuf/internal/flags" | 
|  | "google.golang.org/protobuf/internal/genid" | 
|  | "google.golang.org/protobuf/internal/strs" | 
|  | "google.golang.org/protobuf/reflect/protoreflect" | 
|  | "google.golang.org/protobuf/reflect/protoregistry" | 
|  | "google.golang.org/protobuf/runtime/protoiface" | 
|  | ) | 
|  |  | 
// ValidationStatus reports the outcome of checking the wire-format encoding
// of a message.
type ValidationStatus int

const (
	// ValidationUnknown means the validator could not reach a verdict:
	// unmarshaling the message may either succeed or fail.
	//
	// This status is produced only when an aberrant message type appears
	// somewhere in the message, or when the extension resolver fails.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid means that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid means that unmarshaling the message will succeed.
	ValidationValid
)

// String returns the name of the status constant. A value that is not one of
// the declared constants is rendered as "ValidationStatus(N)".
func (v ValidationStatus) String() string {
	if v == ValidationUnknown {
		return "ValidationUnknown"
	}
	if v == ValidationInvalid {
		return "ValidationInvalid"
	}
	if v == ValidationValid {
		return "ValidationValid"
	}
	return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
|  |  | 
|  | // Validate determines whether the contents of the buffer are a valid wire encoding | 
|  | // of the message type. | 
|  | // | 
|  | // This function is exposed for testing. | 
|  | func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) { | 
|  | mi, ok := mt.(*MessageInfo) | 
|  | if !ok { | 
|  | return out, ValidationUnknown | 
|  | } | 
|  | if in.Resolver == nil { | 
|  | in.Resolver = protoregistry.GlobalTypes | 
|  | } | 
|  | o, st := mi.validate(in.Buf, 0, unmarshalOptions{ | 
|  | flags:    in.Flags, | 
|  | resolver: in.Resolver, | 
|  | }) | 
|  | if o.initialized { | 
|  | out.Flags |= protoiface.UnmarshalInitialized | 
|  | } | 
|  | return out, st | 
|  | } | 
|  |  | 
|  | type validationInfo struct { | 
|  | mi               *MessageInfo | 
|  | typ              validationType | 
|  | keyType, valType validationType | 
|  |  | 
|  | // For non-required fields, requiredBit is 0. | 
|  | // | 
|  | // For required fields, requiredBit's nth bit is set, where n is a | 
|  | // unique index in the range [0, MessageInfo.numRequiredFields). | 
|  | // | 
|  | // If there are more than 64 required fields, requiredBit is 0. | 
|  | requiredBit uint64 | 
|  | } | 
|  |  | 
// validationType classifies a field for validation and determines which
// checks are applied to its encoded payload.
type validationType uint8

const (
	// validationTypeOther is the zero value: no type-specific check applies.
	validationTypeOther validationType = iota

	// validationTypeMessage is a length-delimited embedded message.
	validationTypeMessage

	// validationTypeGroup is a group delimited by start and end group tags.
	validationTypeGroup

	// validationTypeMap is a map field, whose entries carry a key and a value.
	validationTypeMap

	// validationTypeRepeatedVarint is a repeated field of varint elements;
	// a length-delimited payload is checked as packed varints.
	validationTypeRepeatedVarint

	// validationTypeRepeatedFixed32 is a repeated field of 4-byte elements;
	// a length-delimited payload must have a length divisible by 4.
	validationTypeRepeatedFixed32

	// validationTypeRepeatedFixed64 is a repeated field of 8-byte elements;
	// a length-delimited payload must have a length divisible by 8.
	validationTypeRepeatedFixed64

	// validationTypeVarint is a singular varint field.
	validationTypeVarint

	// validationTypeFixed32 is a singular 4-byte field.
	validationTypeFixed32

	// validationTypeFixed64 is a singular 8-byte field.
	validationTypeFixed64

	// validationTypeBytes is a length-delimited payload whose contents are
	// not inspected.
	validationTypeBytes

	// validationTypeUTF8String is a length-delimited payload that must be
	// valid UTF-8.
	validationTypeUTF8String

	// validationTypeMessageSetItem is the item group of a MessageSet.
	validationTypeMessageSetItem
)
|  |  | 
|  | func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo { | 
|  | var vi validationInfo | 
|  | switch { | 
|  | case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic(): | 
|  | switch fd.Kind() { | 
|  | case protoreflect.MessageKind: | 
|  | vi.typ = validationTypeMessage | 
|  | if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok { | 
|  | vi.mi = getMessageInfo(ot.Field(0).Type) | 
|  | } | 
|  | case protoreflect.GroupKind: | 
|  | vi.typ = validationTypeGroup | 
|  | if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok { | 
|  | vi.mi = getMessageInfo(ot.Field(0).Type) | 
|  | } | 
|  | case protoreflect.StringKind: | 
|  | if strs.EnforceUTF8(fd) { | 
|  | vi.typ = validationTypeUTF8String | 
|  | } | 
|  | } | 
|  | default: | 
|  | vi = newValidationInfo(fd, ft) | 
|  | } | 
|  | if fd.Cardinality() == protoreflect.Required { | 
|  | // Avoid overflow. The required field check is done with a 64-bit mask, with | 
|  | // any message containing more than 64 required fields always reported as | 
|  | // potentially uninitialized, so it is not important to get a precise count | 
|  | // of the required fields past 64. | 
|  | if mi.numRequiredFields < math.MaxUint8 { | 
|  | mi.numRequiredFields++ | 
|  | vi.requiredBit = 1 << (mi.numRequiredFields - 1) | 
|  | } | 
|  | } | 
|  | return vi | 
|  | } | 
|  |  | 
|  | func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo { | 
|  | var vi validationInfo | 
|  | switch { | 
|  | case fd.IsList(): | 
|  | switch fd.Kind() { | 
|  | case protoreflect.MessageKind: | 
|  | vi.typ = validationTypeMessage | 
|  | if ft.Kind() == reflect.Slice { | 
|  | vi.mi = getMessageInfo(ft.Elem()) | 
|  | } | 
|  | case protoreflect.GroupKind: | 
|  | vi.typ = validationTypeGroup | 
|  | if ft.Kind() == reflect.Slice { | 
|  | vi.mi = getMessageInfo(ft.Elem()) | 
|  | } | 
|  | case protoreflect.StringKind: | 
|  | vi.typ = validationTypeBytes | 
|  | if strs.EnforceUTF8(fd) { | 
|  | vi.typ = validationTypeUTF8String | 
|  | } | 
|  | default: | 
|  | switch wireTypes[fd.Kind()] { | 
|  | case protowire.VarintType: | 
|  | vi.typ = validationTypeRepeatedVarint | 
|  | case protowire.Fixed32Type: | 
|  | vi.typ = validationTypeRepeatedFixed32 | 
|  | case protowire.Fixed64Type: | 
|  | vi.typ = validationTypeRepeatedFixed64 | 
|  | } | 
|  | } | 
|  | case fd.IsMap(): | 
|  | vi.typ = validationTypeMap | 
|  | switch fd.MapKey().Kind() { | 
|  | case protoreflect.StringKind: | 
|  | if strs.EnforceUTF8(fd) { | 
|  | vi.keyType = validationTypeUTF8String | 
|  | } | 
|  | } | 
|  | switch fd.MapValue().Kind() { | 
|  | case protoreflect.MessageKind: | 
|  | vi.valType = validationTypeMessage | 
|  | if ft.Kind() == reflect.Map { | 
|  | vi.mi = getMessageInfo(ft.Elem()) | 
|  | } | 
|  | case protoreflect.StringKind: | 
|  | if strs.EnforceUTF8(fd) { | 
|  | vi.valType = validationTypeUTF8String | 
|  | } | 
|  | } | 
|  | default: | 
|  | switch fd.Kind() { | 
|  | case protoreflect.MessageKind: | 
|  | vi.typ = validationTypeMessage | 
|  | if !fd.IsWeak() { | 
|  | vi.mi = getMessageInfo(ft) | 
|  | } | 
|  | case protoreflect.GroupKind: | 
|  | vi.typ = validationTypeGroup | 
|  | vi.mi = getMessageInfo(ft) | 
|  | case protoreflect.StringKind: | 
|  | vi.typ = validationTypeBytes | 
|  | if strs.EnforceUTF8(fd) { | 
|  | vi.typ = validationTypeUTF8String | 
|  | } | 
|  | default: | 
|  | switch wireTypes[fd.Kind()] { | 
|  | case protowire.VarintType: | 
|  | vi.typ = validationTypeVarint | 
|  | case protowire.Fixed32Type: | 
|  | vi.typ = validationTypeFixed32 | 
|  | case protowire.Fixed64Type: | 
|  | vi.typ = validationTypeFixed64 | 
|  | case protowire.BytesType: | 
|  | vi.typ = validationTypeBytes | 
|  | } | 
|  | } | 
|  | } | 
|  | return vi | 
|  | } | 
|  |  | 
|  | func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) { | 
|  | mi.init() | 
|  | type validationState struct { | 
|  | typ              validationType | 
|  | keyType, valType validationType | 
|  | endGroup         protowire.Number | 
|  | mi               *MessageInfo | 
|  | tail             []byte | 
|  | requiredMask     uint64 | 
|  | } | 
|  |  | 
|  | // Pre-allocate some slots to avoid repeated slice reallocation. | 
|  | states := make([]validationState, 0, 16) | 
|  | states = append(states, validationState{ | 
|  | typ: validationTypeMessage, | 
|  | mi:  mi, | 
|  | }) | 
|  | if groupTag > 0 { | 
|  | states[0].typ = validationTypeGroup | 
|  | states[0].endGroup = groupTag | 
|  | } | 
|  | initialized := true | 
|  | start := len(b) | 
|  | State: | 
|  | for len(states) > 0 { | 
|  | st := &states[len(states)-1] | 
|  | for len(b) > 0 { | 
|  | // Parse the tag (field number and wire type). | 
|  | var tag uint64 | 
|  | if b[0] < 0x80 { | 
|  | tag = uint64(b[0]) | 
|  | b = b[1:] | 
|  | } else if len(b) >= 2 && b[1] < 128 { | 
|  | tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 | 
|  | b = b[2:] | 
|  | } else { | 
|  | var n int | 
|  | tag, n = protowire.ConsumeVarint(b) | 
|  | if n < 0 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | b = b[n:] | 
|  | } | 
|  | var num protowire.Number | 
|  | if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) { | 
|  | return out, ValidationInvalid | 
|  | } else { | 
|  | num = protowire.Number(n) | 
|  | } | 
|  | wtyp := protowire.Type(tag & 7) | 
|  |  | 
|  | if wtyp == protowire.EndGroupType { | 
|  | if st.endGroup == num { | 
|  | goto PopState | 
|  | } | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | var vi validationInfo | 
|  | switch { | 
|  | case st.typ == validationTypeMap: | 
|  | switch num { | 
|  | case genid.MapEntry_Key_field_number: | 
|  | vi.typ = st.keyType | 
|  | case genid.MapEntry_Value_field_number: | 
|  | vi.typ = st.valType | 
|  | vi.mi = st.mi | 
|  | vi.requiredBit = 1 | 
|  | } | 
|  | case flags.ProtoLegacy && st.mi.isMessageSet: | 
|  | switch num { | 
|  | case messageset.FieldItem: | 
|  | vi.typ = validationTypeMessageSetItem | 
|  | } | 
|  | default: | 
|  | var f *coderFieldInfo | 
|  | if int(num) < len(st.mi.denseCoderFields) { | 
|  | f = st.mi.denseCoderFields[num] | 
|  | } else { | 
|  | f = st.mi.coderFields[num] | 
|  | } | 
|  | if f != nil { | 
|  | vi = f.validation | 
|  | if vi.typ == validationTypeMessage && vi.mi == nil { | 
|  | // Probable weak field. | 
|  | // | 
|  | // TODO: Consider storing the results of this lookup somewhere | 
|  | // rather than recomputing it on every validation. | 
|  | fd := st.mi.Desc.Fields().ByNumber(num) | 
|  | if fd == nil || !fd.IsWeak() { | 
|  | break | 
|  | } | 
|  | messageName := fd.Message().FullName() | 
|  | messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName) | 
|  | switch err { | 
|  | case nil: | 
|  | vi.mi, _ = messageType.(*MessageInfo) | 
|  | case protoregistry.NotFound: | 
|  | vi.typ = validationTypeBytes | 
|  | default: | 
|  | return out, ValidationUnknown | 
|  | } | 
|  | } | 
|  | break | 
|  | } | 
|  | // Possible extension field. | 
|  | // | 
|  | // TODO: We should return ValidationUnknown when: | 
|  | //   1. The resolver is not frozen. (More extensions may be added to it.) | 
|  | //   2. The resolver returns preg.NotFound. | 
|  | // In this case, a type added to the resolver in the future could cause | 
|  | // unmarshaling to begin failing. Supporting this requires some way to | 
|  | // determine if the resolver is frozen. | 
|  | xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num) | 
|  | if err != nil && err != protoregistry.NotFound { | 
|  | return out, ValidationUnknown | 
|  | } | 
|  | if err == nil { | 
|  | vi = getExtensionFieldInfo(xt).validation | 
|  | } | 
|  | } | 
|  | if vi.requiredBit != 0 { | 
|  | // Check that the field has a compatible wire type. | 
|  | // We only need to consider non-repeated field types, | 
|  | // since repeated fields (and maps) can never be required. | 
|  | ok := false | 
|  | switch vi.typ { | 
|  | case validationTypeVarint: | 
|  | ok = wtyp == protowire.VarintType | 
|  | case validationTypeFixed32: | 
|  | ok = wtyp == protowire.Fixed32Type | 
|  | case validationTypeFixed64: | 
|  | ok = wtyp == protowire.Fixed64Type | 
|  | case validationTypeBytes, validationTypeUTF8String, validationTypeMessage: | 
|  | ok = wtyp == protowire.BytesType | 
|  | case validationTypeGroup: | 
|  | ok = wtyp == protowire.StartGroupType | 
|  | } | 
|  | if ok { | 
|  | st.requiredMask |= vi.requiredBit | 
|  | } | 
|  | } | 
|  |  | 
|  | switch wtyp { | 
|  | case protowire.VarintType: | 
|  | if len(b) >= 10 { | 
|  | switch { | 
|  | case b[0] < 0x80: | 
|  | b = b[1:] | 
|  | case b[1] < 0x80: | 
|  | b = b[2:] | 
|  | case b[2] < 0x80: | 
|  | b = b[3:] | 
|  | case b[3] < 0x80: | 
|  | b = b[4:] | 
|  | case b[4] < 0x80: | 
|  | b = b[5:] | 
|  | case b[5] < 0x80: | 
|  | b = b[6:] | 
|  | case b[6] < 0x80: | 
|  | b = b[7:] | 
|  | case b[7] < 0x80: | 
|  | b = b[8:] | 
|  | case b[8] < 0x80: | 
|  | b = b[9:] | 
|  | case b[9] < 0x80 && b[9] < 2: | 
|  | b = b[10:] | 
|  | default: | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | } else { | 
|  | switch { | 
|  | case len(b) > 0 && b[0] < 0x80: | 
|  | b = b[1:] | 
|  | case len(b) > 1 && b[1] < 0x80: | 
|  | b = b[2:] | 
|  | case len(b) > 2 && b[2] < 0x80: | 
|  | b = b[3:] | 
|  | case len(b) > 3 && b[3] < 0x80: | 
|  | b = b[4:] | 
|  | case len(b) > 4 && b[4] < 0x80: | 
|  | b = b[5:] | 
|  | case len(b) > 5 && b[5] < 0x80: | 
|  | b = b[6:] | 
|  | case len(b) > 6 && b[6] < 0x80: | 
|  | b = b[7:] | 
|  | case len(b) > 7 && b[7] < 0x80: | 
|  | b = b[8:] | 
|  | case len(b) > 8 && b[8] < 0x80: | 
|  | b = b[9:] | 
|  | case len(b) > 9 && b[9] < 2: | 
|  | b = b[10:] | 
|  | default: | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | } | 
|  | continue State | 
|  | case protowire.BytesType: | 
|  | var size uint64 | 
|  | if len(b) >= 1 && b[0] < 0x80 { | 
|  | size = uint64(b[0]) | 
|  | b = b[1:] | 
|  | } else if len(b) >= 2 && b[1] < 128 { | 
|  | size = uint64(b[0]&0x7f) + uint64(b[1])<<7 | 
|  | b = b[2:] | 
|  | } else { | 
|  | var n int | 
|  | size, n = protowire.ConsumeVarint(b) | 
|  | if n < 0 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | b = b[n:] | 
|  | } | 
|  | if size > uint64(len(b)) { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | v := b[:size] | 
|  | b = b[size:] | 
|  | switch vi.typ { | 
|  | case validationTypeMessage: | 
|  | if vi.mi == nil { | 
|  | return out, ValidationUnknown | 
|  | } | 
|  | vi.mi.init() | 
|  | fallthrough | 
|  | case validationTypeMap: | 
|  | if vi.mi != nil { | 
|  | vi.mi.init() | 
|  | } | 
|  | states = append(states, validationState{ | 
|  | typ:     vi.typ, | 
|  | keyType: vi.keyType, | 
|  | valType: vi.valType, | 
|  | mi:      vi.mi, | 
|  | tail:    b, | 
|  | }) | 
|  | b = v | 
|  | continue State | 
|  | case validationTypeRepeatedVarint: | 
|  | // Packed field. | 
|  | for len(v) > 0 { | 
|  | _, n := protowire.ConsumeVarint(v) | 
|  | if n < 0 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | v = v[n:] | 
|  | } | 
|  | case validationTypeRepeatedFixed32: | 
|  | // Packed field. | 
|  | if len(v)%4 != 0 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | case validationTypeRepeatedFixed64: | 
|  | // Packed field. | 
|  | if len(v)%8 != 0 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | case validationTypeUTF8String: | 
|  | if !utf8.Valid(v) { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | } | 
|  | case protowire.Fixed32Type: | 
|  | if len(b) < 4 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | b = b[4:] | 
|  | case protowire.Fixed64Type: | 
|  | if len(b) < 8 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | b = b[8:] | 
|  | case protowire.StartGroupType: | 
|  | switch { | 
|  | case vi.typ == validationTypeGroup: | 
|  | if vi.mi == nil { | 
|  | return out, ValidationUnknown | 
|  | } | 
|  | vi.mi.init() | 
|  | states = append(states, validationState{ | 
|  | typ:      validationTypeGroup, | 
|  | mi:       vi.mi, | 
|  | endGroup: num, | 
|  | }) | 
|  | continue State | 
|  | case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem: | 
|  | typeid, v, n, err := messageset.ConsumeFieldValue(b, false) | 
|  | if err != nil { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid) | 
|  | switch { | 
|  | case err == protoregistry.NotFound: | 
|  | b = b[n:] | 
|  | case err != nil: | 
|  | return out, ValidationUnknown | 
|  | default: | 
|  | xvi := getExtensionFieldInfo(xt).validation | 
|  | if xvi.mi != nil { | 
|  | xvi.mi.init() | 
|  | } | 
|  | states = append(states, validationState{ | 
|  | typ:  xvi.typ, | 
|  | mi:   xvi.mi, | 
|  | tail: b[n:], | 
|  | }) | 
|  | b = v | 
|  | continue State | 
|  | } | 
|  | default: | 
|  | n := protowire.ConsumeFieldValue(num, wtyp, b) | 
|  | if n < 0 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | b = b[n:] | 
|  | } | 
|  | default: | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | } | 
|  | if st.endGroup != 0 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | if len(b) != 0 { | 
|  | return out, ValidationInvalid | 
|  | } | 
|  | b = st.tail | 
|  | PopState: | 
|  | numRequiredFields := 0 | 
|  | switch st.typ { | 
|  | case validationTypeMessage, validationTypeGroup: | 
|  | numRequiredFields = int(st.mi.numRequiredFields) | 
|  | case validationTypeMap: | 
|  | // If this is a map field with a message value that contains | 
|  | // required fields, require that the value be present. | 
|  | if st.mi != nil && st.mi.numRequiredFields > 0 { | 
|  | numRequiredFields = 1 | 
|  | } | 
|  | } | 
|  | // If there are more than 64 required fields, this check will | 
|  | // always fail and we will report that the message is potentially | 
|  | // uninitialized. | 
|  | if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields { | 
|  | initialized = false | 
|  | } | 
|  | states = states[:len(states)-1] | 
|  | } | 
|  | out.n = start - len(b) | 
|  | if initialized { | 
|  | out.initialized = true | 
|  | } | 
|  | return out, ValidationValid | 
|  | } |