all: implement depth limit for unmarshaling
+ This change introduce a default and configurable depth limit for
proto.Unmarshal. If a message is nested deeper than the limit,
unmarshaling will fail. There are two ways to nest messages. Either by
having fields which are message types itself or by using groups.
+ The default limit is 10,000 for now. This might change in the future
to align it with other language implementation (C++ and Java use 100
as limit).
+ If pure groups (groups that don't contain message fields) are nested
deeper than the default limit the unmarshaling fails with:
proto: cannot parse invalid wire-format data
+ Note: the configured limit does not apply to pure groups.
+ This change is introduced to improve security and robustness. Because
unmarshaling is implemented using recursion it can lead to stack overflows
for certain inputs. The introduced limit protects against this.
+ A secondary motivation for this limit is the alignment with other
languages. Protocol buffers are a language interoperability mechanism
and thus either all implementations should accept the input or all
implementation should reject the input.
Change-Id: I14bdb44d06e4bd1aa90d6336c2cf6446003b2037
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/385854
Trust: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
Reviewed-by: Nicolas Hillegeer <aktau@google.com>
Reviewed-by: Chressie Himpel <chressie@google.com>
diff --git a/encoding/protowire/wire.go b/encoding/protowire/wire.go
index a427f8b..9c61112 100644
--- a/encoding/protowire/wire.go
+++ b/encoding/protowire/wire.go
@@ -21,10 +21,11 @@
type Number int32
const (
- MinValidNumber Number = 1
- FirstReservedNumber Number = 19000
- LastReservedNumber Number = 19999
- MaxValidNumber Number = 1<<29 - 1
+ MinValidNumber Number = 1
+ FirstReservedNumber Number = 19000
+ LastReservedNumber Number = 19999
+ MaxValidNumber Number = 1<<29 - 1
+ DefaultRecursionLimit = 10000
)
// IsValid reports whether the field number is semantically valid.
@@ -55,6 +56,7 @@
errCodeOverflow
errCodeReserved
errCodeEndGroup
+ errCodeRecursionDepth
)
var (
@@ -112,6 +114,10 @@
// When parsing a group, the length includes the end group marker and
// the end group is verified to match the starting field number.
func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
+ return consumeFieldValueD(num, typ, b, DefaultRecursionLimit)
+}
+
+func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) {
switch typ {
case VarintType:
_, n = ConsumeVarint(b)
@@ -126,6 +132,9 @@
_, n = ConsumeBytes(b)
return n
case StartGroupType:
+ if depth < 0 {
+ return errCodeRecursionDepth
+ }
n0 := len(b)
for {
num2, typ2, n := ConsumeTag(b)
@@ -140,7 +149,7 @@
return n0 - len(b)
}
- n = ConsumeFieldValue(num2, typ2, b)
+ n = consumeFieldValueD(num2, typ2, b, depth-1)
if n < 0 {
return n // forward error code
}
diff --git a/internal/fuzz/wirefuzz/fuzz.go b/internal/fuzz/wirefuzz/fuzz.go
index f7a9b74..fd27cca 100644
--- a/internal/fuzz/wirefuzz/fuzz.go
+++ b/internal/fuzz/wirefuzz/fuzz.go
@@ -41,7 +41,7 @@
// Unmarshal, Validate, and CheckInitialized should agree about initialization.
checkInit := proto.CheckInitialized(m1) == nil
methods := m1.ProtoReflect().ProtoMethods()
- in := piface.UnmarshalInput{Message: mt.New(), Resolver: protoregistry.GlobalTypes}
+ in := piface.UnmarshalInput{Message: mt.New(), Resolver: protoregistry.GlobalTypes, Depth: 10000}
if checkInit {
// If the message initialized, the both Unmarshal and Validate should
// report it as such. False negatives are tolerated, but have a
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 949dc49..c65b032 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -18,6 +18,7 @@
)
var errDecode = errors.New("cannot parse invalid wire-format data")
+var errRecursionDepth = errors.New("exceeded maximum recursion depth")
type unmarshalOptions struct {
flags protoiface.UnmarshalInputFlags
@@ -25,6 +26,7 @@
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
}
+ depth int
}
func (o unmarshalOptions) Options() proto.UnmarshalOptions {
@@ -44,6 +46,7 @@
var lazyUnmarshalOptions = unmarshalOptions{
resolver: preg.GlobalTypes,
+ depth: protowire.DefaultRecursionLimit,
}
type unmarshalOutput struct {
@@ -62,6 +65,7 @@
out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
flags: in.Flags,
resolver: in.Resolver,
+ depth: in.Depth,
})
var flags piface.UnmarshalOutputFlags
if out.initialized {
@@ -82,6 +86,10 @@
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
mi.init()
+ opts.depth--
+ if opts.depth < 0 {
+ return out, errRecursionDepth
+ }
if flags.ProtoLegacy && mi.isMessageSet {
return unmarshalMessageSet(mi, b, p, opts)
}
diff --git a/proto/decode.go b/proto/decode.go
index 49f9b8c..11bf717 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -42,18 +42,25 @@
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
}
+
+ // RecursionLimit limits how deeply messages may be nested.
+ // If zero, a default limit is applied.
+ RecursionLimit int
}
// Unmarshal parses the wire-format message in b and places the result in m.
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func Unmarshal(b []byte, m Message) error {
- _, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
+ _, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
return err
}
// Unmarshal parses the wire-format message in b and places the result in m.
// The provided message must be mutable (e.g., a non-nil pointer to a message).
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
+ if o.RecursionLimit == 0 {
+ o.RecursionLimit = protowire.DefaultRecursionLimit
+ }
_, err := o.unmarshal(b, m.ProtoReflect())
return err
}
@@ -63,6 +70,9 @@
// This method permits fine-grained control over the unmarshaler.
// Most users should use Unmarshal instead.
func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
+ if o.RecursionLimit == 0 {
+ o.RecursionLimit = protowire.DefaultRecursionLimit
+ }
return o.unmarshal(in.Buf, in.Message)
}
@@ -86,12 +96,17 @@
Message: m,
Buf: b,
Resolver: o.Resolver,
+ Depth: o.RecursionLimit,
}
if o.DiscardUnknown {
in.Flags |= protoiface.UnmarshalDiscardUnknown
}
out, err = methods.Unmarshal(in)
} else {
+ o.RecursionLimit--
+ if o.RecursionLimit < 0 {
+ return out, errors.New("exceeded max recursion depth")
+ }
err = o.unmarshalMessageSlow(b, m)
}
if err != nil {
diff --git a/reflect/protoreflect/methods.go b/reflect/protoreflect/methods.go
index 6be5d16..d5d5af6 100644
--- a/reflect/protoreflect/methods.go
+++ b/reflect/protoreflect/methods.go
@@ -53,6 +53,7 @@
FindExtensionByName(field FullName) (ExtensionType, error)
FindExtensionByNumber(message FullName, field FieldNumber) (ExtensionType, error)
}
+ Depth int
}
unmarshalOutput = struct {
pragma.NoUnkeyedLiterals
diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go
index 32c04f6..44cf467 100644
--- a/runtime/protoiface/methods.go
+++ b/runtime/protoiface/methods.go
@@ -103,6 +103,7 @@
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
}
+ Depth int
}
// UnmarshalOutput is output from the Unmarshal method.