| // Copyright 2019 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package impl |
| |
| import ( |
| "fmt" |
| "math" |
| "math/bits" |
| "reflect" |
| "unicode/utf8" |
| |
| "google.golang.org/protobuf/encoding/protowire" |
| "google.golang.org/protobuf/internal/encoding/messageset" |
| "google.golang.org/protobuf/internal/flags" |
| "google.golang.org/protobuf/internal/genid" |
| "google.golang.org/protobuf/internal/strs" |
| "google.golang.org/protobuf/reflect/protoreflect" |
| "google.golang.org/protobuf/reflect/protoregistry" |
| "google.golang.org/protobuf/runtime/protoiface" |
| ) |
| |
// ValidationStatus reports the outcome of checking whether a buffer holds a
// valid wire-format encoding of a message.
type ValidationStatus int

const (
	// ValidationUnknown means the validator could not reach a verdict:
	// unmarshaling the message might succeed or it might fail.
	//
	// This status is only produced when an aberrant message type appears
	// somewhere in the message, or when the extension resolver fails.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid means unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid means unmarshaling the message will succeed.
	ValidationValid
)

// String returns the name of the status constant. Values that do not match a
// named constant are rendered as "ValidationStatus(N)".
func (v ValidationStatus) String() string {
	if v < ValidationUnknown || v > ValidationValid {
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
	// Index 0 is unused; it is excluded by the range check above.
	names := [...]string{
		ValidationUnknown: "ValidationUnknown",
		ValidationInvalid: "ValidationInvalid",
		ValidationValid:   "ValidationValid",
	}
	return names[v]
}
| |
| // Validate determines whether the contents of the buffer are a valid wire encoding |
| // of the message type. |
| // |
| // This function is exposed for testing. |
| func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) { |
| mi, ok := mt.(*MessageInfo) |
| if !ok { |
| return out, ValidationUnknown |
| } |
| if in.Resolver == nil { |
| in.Resolver = protoregistry.GlobalTypes |
| } |
| o, st := mi.validate(in.Buf, 0, unmarshalOptions{ |
| flags: in.Flags, |
| resolver: in.Resolver, |
| }) |
| if o.initialized { |
| out.Flags |= protoiface.UnmarshalInitialized |
| } |
| return out, st |
| } |
| |
| type validationInfo struct { |
| mi *MessageInfo |
| typ validationType |
| keyType, valType validationType |
| |
| // For non-required fields, requiredBit is 0. |
| // |
| // For required fields, requiredBit's nth bit is set, where n is a |
| // unique index in the range [0, MessageInfo.numRequiredFields). |
| // |
| // If there are more than 64 required fields, requiredBit is 0. |
| requiredBit uint64 |
| } |
| |
// validationType identifies the rule the validator applies to a field's
// encoded value.
type validationType uint8

const (
	// validationTypeOther is the zero value: no field-specific rule applies.
	validationTypeOther validationType = iota

	// Fields whose payload is itself parsed as a nested set of fields.
	validationTypeMessage
	validationTypeGroup
	validationTypeMap

	// Repeated scalar fields, which may arrive packed in a length-delimited
	// record.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64

	// Singular scalar and length-delimited fields.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	validationTypeBytes
	validationTypeUTF8String

	// validationTypeMessageSetItem marks the item group of a MessageSet.
	validationTypeMessageSetItem
)
| |
| func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo { |
| var vi validationInfo |
| switch { |
| case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic(): |
| switch fd.Kind() { |
| case protoreflect.MessageKind: |
| vi.typ = validationTypeMessage |
| if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok { |
| vi.mi = getMessageInfo(ot.Field(0).Type) |
| } |
| case protoreflect.GroupKind: |
| vi.typ = validationTypeGroup |
| if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok { |
| vi.mi = getMessageInfo(ot.Field(0).Type) |
| } |
| case protoreflect.StringKind: |
| if strs.EnforceUTF8(fd) { |
| vi.typ = validationTypeUTF8String |
| } |
| } |
| default: |
| vi = newValidationInfo(fd, ft) |
| } |
| if fd.Cardinality() == protoreflect.Required { |
| // Avoid overflow. The required field check is done with a 64-bit mask, with |
| // any message containing more than 64 required fields always reported as |
| // potentially uninitialized, so it is not important to get a precise count |
| // of the required fields past 64. |
| if mi.numRequiredFields < math.MaxUint8 { |
| mi.numRequiredFields++ |
| vi.requiredBit = 1 << (mi.numRequiredFields - 1) |
| } |
| } |
| return vi |
| } |
| |
| func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo { |
| var vi validationInfo |
| switch { |
| case fd.IsList(): |
| switch fd.Kind() { |
| case protoreflect.MessageKind: |
| vi.typ = validationTypeMessage |
| if ft.Kind() == reflect.Slice { |
| vi.mi = getMessageInfo(ft.Elem()) |
| } |
| case protoreflect.GroupKind: |
| vi.typ = validationTypeGroup |
| if ft.Kind() == reflect.Slice { |
| vi.mi = getMessageInfo(ft.Elem()) |
| } |
| case protoreflect.StringKind: |
| vi.typ = validationTypeBytes |
| if strs.EnforceUTF8(fd) { |
| vi.typ = validationTypeUTF8String |
| } |
| default: |
| switch wireTypes[fd.Kind()] { |
| case protowire.VarintType: |
| vi.typ = validationTypeRepeatedVarint |
| case protowire.Fixed32Type: |
| vi.typ = validationTypeRepeatedFixed32 |
| case protowire.Fixed64Type: |
| vi.typ = validationTypeRepeatedFixed64 |
| } |
| } |
| case fd.IsMap(): |
| vi.typ = validationTypeMap |
| switch fd.MapKey().Kind() { |
| case protoreflect.StringKind: |
| if strs.EnforceUTF8(fd) { |
| vi.keyType = validationTypeUTF8String |
| } |
| } |
| switch fd.MapValue().Kind() { |
| case protoreflect.MessageKind: |
| vi.valType = validationTypeMessage |
| if ft.Kind() == reflect.Map { |
| vi.mi = getMessageInfo(ft.Elem()) |
| } |
| case protoreflect.StringKind: |
| if strs.EnforceUTF8(fd) { |
| vi.valType = validationTypeUTF8String |
| } |
| } |
| default: |
| switch fd.Kind() { |
| case protoreflect.MessageKind: |
| vi.typ = validationTypeMessage |
| if !fd.IsWeak() { |
| vi.mi = getMessageInfo(ft) |
| } |
| case protoreflect.GroupKind: |
| vi.typ = validationTypeGroup |
| vi.mi = getMessageInfo(ft) |
| case protoreflect.StringKind: |
| vi.typ = validationTypeBytes |
| if strs.EnforceUTF8(fd) { |
| vi.typ = validationTypeUTF8String |
| } |
| default: |
| switch wireTypes[fd.Kind()] { |
| case protowire.VarintType: |
| vi.typ = validationTypeVarint |
| case protowire.Fixed32Type: |
| vi.typ = validationTypeFixed32 |
| case protowire.Fixed64Type: |
| vi.typ = validationTypeFixed64 |
| case protowire.BytesType: |
| vi.typ = validationTypeBytes |
| } |
| } |
| } |
| return vi |
| } |
| |
// validate checks whether b is a well-formed wire encoding of the message
// described by mi.
//
// Nested messages, groups, map entries and MessageSet items are handled with
// an explicit stack of validationState frames rather than by recursion.
//
// If groupTag is greater than zero, b is treated as the body of a group that
// must be closed by an end-group tag carrying that field number.
//
// On success it returns ValidationValid, with out.n set to the number of bytes
// of b that were consumed and out.initialized set when every required-field
// check passed. It returns ValidationInvalid when the encoding is malformed,
// and ValidationUnknown when no verdict can be reached: a message or group
// field without a MessageInfo, or a registry/resolver lookup that fails with
// an error other than protoregistry.NotFound.
func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
	mi.init()
	// validationState is one frame of the parse stack:
	//   - typ, keyType, valType: how the fields of this frame are interpreted
	//     (keyType/valType are only used by map-entry frames);
	//   - endGroup: field number of the end-group tag that closes this frame,
	//     or 0 if the frame is not a group;
	//   - mi: message info used to look up fields (for a map-entry frame, the
	//     message info of the value, which may be nil);
	//   - tail: bytes to resume parsing from once this frame is finished;
	//   - requiredMask: bits of the required fields seen so far in this frame.
	type validationState struct {
		typ              validationType
		keyType, valType validationType
		endGroup         protowire.Number
		mi               *MessageInfo
		tail             []byte
		requiredMask     uint64
	}

	// Pre-allocate some slots to avoid repeated slice reallocation.
	states := make([]validationState, 0, 16)
	states = append(states, validationState{
		typ: validationTypeMessage,
		mi:  mi,
	})
	if groupTag > 0 {
		// The top-level frame is a group body terminated by groupTag.
		states[0].typ = validationTypeGroup
		states[0].endGroup = groupTag
	}
	initialized := true
	start := len(b)
State:
	for len(states) > 0 {
		// st points into states, so it must be re-derived after every append;
		// each push below is followed by "continue State" for that reason.
		st := &states[len(states)-1]
		for len(b) > 0 {
			// Parse the tag (field number and wire type).
			// One- and two-byte varints are decoded inline; longer ones go
			// through protowire.ConsumeVarint.
			var tag uint64
			if b[0] < 0x80 {
				tag = uint64(b[0])
				b = b[1:]
			} else if len(b) >= 2 && b[1] < 128 {
				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
				b = b[2:]
			} else {
				var n int
				tag, n = protowire.ConsumeVarint(b)
				if n < 0 {
					return out, ValidationInvalid
				}
				b = b[n:]
			}
			// Reject field numbers outside the valid range.
			var num protowire.Number
			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
				return out, ValidationInvalid
			} else {
				num = protowire.Number(n)
			}
			wtyp := protowire.Type(tag & 7)

			if wtyp == protowire.EndGroupType {
				// An end-group tag is only valid if it closes the current
				// frame. b is left as is: parsing continues right after the tag.
				if st.endGroup == num {
					goto PopState
				}
				return out, ValidationInvalid
			}
			// Determine the validation rule for this field. If nothing below
			// matches, vi stays zero and the field is only checked for
			// well-formedness of its wire type.
			var vi validationInfo
			switch {
			case st.typ == validationTypeMap:
				// Inside a map entry only the key and value fields are
				// interpreted.
				switch num {
				case genid.MapEntry_Key_field_number:
					vi.typ = st.keyType
				case genid.MapEntry_Value_field_number:
					vi.typ = st.valType
					vi.mi = st.mi
					// Bit 0 of the entry's mask records that the value was seen.
					vi.requiredBit = 1
				}
			case flags.ProtoLegacy && st.mi.isMessageSet:
				// In a MessageSet only the item field is interpreted.
				switch num {
				case messageset.FieldItem:
					vi.typ = validationTypeMessageSetItem
				}
			default:
				// Look the field up by number: in denseCoderFields when the
				// number is within its length, otherwise in coderFields.
				var f *coderFieldInfo
				if int(num) < len(st.mi.denseCoderFields) {
					f = st.mi.denseCoderFields[num]
				} else {
					f = st.mi.coderFields[num]
				}
				if f != nil {
					vi = f.validation
					if vi.typ == validationTypeMessage && vi.mi == nil {
						// Probable weak field.
						//
						// TODO: Consider storing the results of this lookup somewhere
						// rather than recomputing it on every validation.
						fd := st.mi.Desc.Fields().ByNumber(num)
						if fd == nil || !fd.IsWeak() {
							// Not a weak field: leave vi.mi nil. This break
							// exits the enclosing switch.
							break
						}
						messageName := fd.Message().FullName()
						messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
						switch err {
						case nil:
							vi.mi, _ = messageType.(*MessageInfo)
						case protoregistry.NotFound:
							// The weak message type is not registered: treat
							// the field as opaque bytes.
							vi.typ = validationTypeBytes
						default:
							return out, ValidationUnknown
						}
					}
					// Known field: skip the extension lookup below.
					break
				}
				// Possible extension field.
				//
				// TODO: We should return ValidationUnknown when:
				//   1. The resolver is not frozen. (More extensions may be added to it.)
				//   2. The resolver returns preg.NotFound.
				// In this case, a type added to the resolver in the future could cause
				// unmarshaling to begin failing. Supporting this requires some way to
				// determine if the resolver is frozen.
				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
				if err != nil && err != protoregistry.NotFound {
					return out, ValidationUnknown
				}
				if err == nil {
					vi = getExtensionFieldInfo(xt).validation
				}
			}
			if vi.requiredBit != 0 {
				// Check that the field has a compatible wire type.
				// We only need to consider non-repeated field types,
				// since repeated fields (and maps) can never be required.
				ok := false
				switch vi.typ {
				case validationTypeVarint:
					ok = wtyp == protowire.VarintType
				case validationTypeFixed32:
					ok = wtyp == protowire.Fixed32Type
				case validationTypeFixed64:
					ok = wtyp == protowire.Fixed64Type
				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
					ok = wtyp == protowire.BytesType
				case validationTypeGroup:
					ok = wtyp == protowire.StartGroupType
				}
				if ok {
					st.requiredMask |= vi.requiredBit
				}
			}

			// Consume the field value according to its wire type.
			switch wtyp {
			case protowire.VarintType:
				// Skip over the varint without decoding it. A varint is at
				// most 10 bytes long, and the 10th byte may only be 0 or 1.
				if len(b) >= 10 {
					// At least 10 bytes remain, so no length checks are needed.
					switch {
					case b[0] < 0x80:
						b = b[1:]
					case b[1] < 0x80:
						b = b[2:]
					case b[2] < 0x80:
						b = b[3:]
					case b[3] < 0x80:
						b = b[4:]
					case b[4] < 0x80:
						b = b[5:]
					case b[5] < 0x80:
						b = b[6:]
					case b[6] < 0x80:
						b = b[7:]
					case b[7] < 0x80:
						b = b[8:]
					case b[8] < 0x80:
						b = b[9:]
					case b[9] < 0x80 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				} else {
					// Fewer than 10 bytes remain: check the length before
					// each access. Running out of bytes is an error.
					switch {
					case len(b) > 0 && b[0] < 0x80:
						b = b[1:]
					case len(b) > 1 && b[1] < 0x80:
						b = b[2:]
					case len(b) > 2 && b[2] < 0x80:
						b = b[3:]
					case len(b) > 3 && b[3] < 0x80:
						b = b[4:]
					case len(b) > 4 && b[4] < 0x80:
						b = b[5:]
					case len(b) > 5 && b[5] < 0x80:
						b = b[6:]
					case len(b) > 6 && b[6] < 0x80:
						b = b[7:]
					case len(b) > 7 && b[7] < 0x80:
						b = b[8:]
					case len(b) > 8 && b[8] < 0x80:
						b = b[9:]
					case len(b) > 9 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				}
				// Restart the outer loop; the stack is unchanged, so st is
				// re-derived to the same frame.
				continue State
			case protowire.BytesType:
				// Decode the length prefix, with the same inline fast paths
				// used for the tag.
				var size uint64
				if len(b) >= 1 && b[0] < 0x80 {
					size = uint64(b[0])
					b = b[1:]
				} else if len(b) >= 2 && b[1] < 128 {
					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
					b = b[2:]
				} else {
					var n int
					size, n = protowire.ConsumeVarint(b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
				if size > uint64(len(b)) {
					return out, ValidationInvalid
				}
				// v is the payload; b is what follows it in the current frame.
				v := b[:size]
				b = b[size:]
				switch vi.typ {
				case validationTypeMessage:
					// A message without a MessageInfo cannot be judged.
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					fallthrough
				case validationTypeMap:
					if vi.mi != nil {
						vi.mi.init()
					}
					// Push a frame for the payload, remembering where to
					// resume in the current frame once it is done.
					states = append(states, validationState{
						typ:     vi.typ,
						keyType: vi.keyType,
						valType: vi.valType,
						mi:      vi.mi,
						tail:    b,
					})
					b = v
					continue State
				case validationTypeRepeatedVarint:
					// Packed field.
					// The payload must be a sequence of complete varints.
					for len(v) > 0 {
						_, n := protowire.ConsumeVarint(v)
						if n < 0 {
							return out, ValidationInvalid
						}
						v = v[n:]
					}
				case validationTypeRepeatedFixed32:
					// Packed field.
					// The payload must be a whole number of 4-byte values.
					if len(v)%4 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeRepeatedFixed64:
					// Packed field.
					// The payload must be a whole number of 8-byte values.
					if len(v)%8 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeUTF8String:
					if !utf8.Valid(v) {
						return out, ValidationInvalid
					}
				}
			case protowire.Fixed32Type:
				if len(b) < 4 {
					return out, ValidationInvalid
				}
				b = b[4:]
			case protowire.Fixed64Type:
				if len(b) < 8 {
					return out, ValidationInvalid
				}
				b = b[8:]
			case protowire.StartGroupType:
				switch {
				case vi.typ == validationTypeGroup:
					// A group without a MessageInfo cannot be judged.
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					// Push a frame for the group. No tail is stored: the
					// group body is inline in b and ends at the matching
					// end-group tag.
					states = append(states, validationState{
						typ:      validationTypeGroup,
						mi:       vi.mi,
						endGroup: num,
					})
					continue State
				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
					// MessageSet item: extract the type id and payload, then
					// validate the payload as the extension with that number.
					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
					if err != nil {
						return out, ValidationInvalid
					}
					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
					switch {
					case err == protoregistry.NotFound:
						// Unknown extension: skip the whole item.
						b = b[n:]
					case err != nil:
						return out, ValidationUnknown
					default:
						xvi := getExtensionFieldInfo(xt).validation
						if xvi.mi != nil {
							xvi.mi.init()
						}
						// Push a frame for the payload and resume after the
						// item once it has been validated.
						states = append(states, validationState{
							typ:  xvi.typ,
							mi:   xvi.mi,
							tail: b[n:],
						})
						b = v
						continue State
					}
				default:
					// Group with no validation rule: skip it, checking only
					// that it is well-formed.
					n := protowire.ConsumeFieldValue(num, wtyp, b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
			default:
				// Any other wire type is invalid here (end-group tags were
				// handled above).
				return out, ValidationInvalid
			}
		}
		// The input for this frame is exhausted. A group frame must instead
		// end with its end-group tag.
		if st.endGroup != 0 {
			return out, ValidationInvalid
		}
		// The loop above only falls through when b is empty, so this is a
		// defensive check.
		if len(b) != 0 {
			return out, ValidationInvalid
		}
		// Resume the parent frame where it left off.
		b = st.tail
	PopState:
		// Work out how many required bits this frame must have collected.
		numRequiredFields := 0
		switch st.typ {
		case validationTypeMessage, validationTypeGroup:
			numRequiredFields = int(st.mi.numRequiredFields)
		case validationTypeMap:
			// If this is a map field with a message value that contains
			// required fields, require that the value be present.
			if st.mi != nil && st.mi.numRequiredFields > 0 {
				numRequiredFields = 1
			}
		}
		// If there are more than 64 required fields, this check will
		// always fail and we will report that the message is potentially
		// uninitialized.
		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
			initialized = false
		}
		states = states[:len(states)-1]
	}
	// Report how much of the original input was consumed.
	out.n = start - len(b)
	if initialized {
		out.initialized = true
	}
	return out, ValidationValid
}