all: implement depth limit for unmarshaling

+ This change introduce a default and configurable depth limit for
  proto.Unmarshal. If a message is nested deeper than the limit,
  unmarshaling will fail. There are two ways to nest messages. Either by
  having fields which are message types itself or by using groups.
+ The default limit is 10,000 for now. This might change in the future
  to align it with other language implementation (C++ and Java use 100
  as limit).
+ If pure groups (groups that don't contain message fields) are nested
  deeper than the default limit the unmarshaling fails with:
  proto: cannot parse invalid wire-format data
+ Note: the configured limit does not apply to pure groups.
+ This change is introduced to improve security and robustness. Because
  unmarshaling is implemented using recursion it can lead to stack overflows
  for certain inputs. The introduced limit protects against this.
+ A secondary motivation for this limit is the alignment with other
  languages. Protocol buffers are a language interoperability mechanism
  and thus either all implementations should accept the input or all
  implementation should reject the input.

Change-Id: I14bdb44d06e4bd1aa90d6336c2cf6446003b2037
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/385854
Trust: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
Reviewed-by: Nicolas Hillegeer <aktau@google.com>
Reviewed-by: Chressie Himpel <chressie@google.com>
diff --git a/encoding/protowire/wire.go b/encoding/protowire/wire.go
index a427f8b..9c61112 100644
--- a/encoding/protowire/wire.go
+++ b/encoding/protowire/wire.go
@@ -21,10 +21,11 @@
 type Number int32
 
 const (
-	MinValidNumber      Number = 1
-	FirstReservedNumber Number = 19000
-	LastReservedNumber  Number = 19999
-	MaxValidNumber      Number = 1<<29 - 1
+	MinValidNumber        Number = 1
+	FirstReservedNumber   Number = 19000
+	LastReservedNumber    Number = 19999
+	MaxValidNumber        Number = 1<<29 - 1
+	DefaultRecursionLimit        = 10000
 )
 
 // IsValid reports whether the field number is semantically valid.
@@ -55,6 +56,7 @@
 	errCodeOverflow
 	errCodeReserved
 	errCodeEndGroup
+	errCodeRecursionDepth
 )
 
 var (
@@ -112,6 +114,10 @@
 // When parsing a group, the length includes the end group marker and
 // the end group is verified to match the starting field number.
 func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
+	return consumeFieldValueD(num, typ, b, DefaultRecursionLimit)
+}
+
+func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) {
 	switch typ {
 	case VarintType:
 		_, n = ConsumeVarint(b)
@@ -126,6 +132,9 @@
 		_, n = ConsumeBytes(b)
 		return n
 	case StartGroupType:
+		if depth < 0 {
+			return errCodeRecursionDepth
+		}
 		n0 := len(b)
 		for {
 			num2, typ2, n := ConsumeTag(b)
@@ -140,7 +149,7 @@
 				return n0 - len(b)
 			}
 
-			n = ConsumeFieldValue(num2, typ2, b)
+			n = consumeFieldValueD(num2, typ2, b, depth-1)
 			if n < 0 {
 				return n // forward error code
 			}
diff --git a/internal/fuzz/wirefuzz/fuzz.go b/internal/fuzz/wirefuzz/fuzz.go
index f7a9b74..fd27cca 100644
--- a/internal/fuzz/wirefuzz/fuzz.go
+++ b/internal/fuzz/wirefuzz/fuzz.go
@@ -41,7 +41,7 @@
 	// Unmarshal, Validate, and CheckInitialized should agree about initialization.
 	checkInit := proto.CheckInitialized(m1) == nil
 	methods := m1.ProtoReflect().ProtoMethods()
-	in := piface.UnmarshalInput{Message: mt.New(), Resolver: protoregistry.GlobalTypes}
+	in := piface.UnmarshalInput{Message: mt.New(), Resolver: protoregistry.GlobalTypes, Depth: 10000}
 	if checkInit {
 		// If the message initialized, the both Unmarshal and Validate should
 		// report it as such. False negatives are tolerated, but have a
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 949dc49..c65b032 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -18,6 +18,7 @@
 )
 
 var errDecode = errors.New("cannot parse invalid wire-format data")
+var errRecursionDepth = errors.New("exceeded maximum recursion depth")
 
 type unmarshalOptions struct {
 	flags    protoiface.UnmarshalInputFlags
@@ -25,6 +26,7 @@
 		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
 		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
 	}
+	depth int
 }
 
 func (o unmarshalOptions) Options() proto.UnmarshalOptions {
@@ -44,6 +46,7 @@
 
 var lazyUnmarshalOptions = unmarshalOptions{
 	resolver: preg.GlobalTypes,
+	depth:    protowire.DefaultRecursionLimit,
 }
 
 type unmarshalOutput struct {
@@ -62,6 +65,7 @@
 	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
 		flags:    in.Flags,
 		resolver: in.Resolver,
+		depth:    in.Depth,
 	})
 	var flags piface.UnmarshalOutputFlags
 	if out.initialized {
@@ -82,6 +86,10 @@
 
 func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
 	mi.init()
+	opts.depth--
+	if opts.depth < 0 {
+		return out, errRecursionDepth
+	}
 	if flags.ProtoLegacy && mi.isMessageSet {
 		return unmarshalMessageSet(mi, b, p, opts)
 	}
diff --git a/proto/decode.go b/proto/decode.go
index 49f9b8c..11bf717 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -42,18 +42,25 @@
 		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
 		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
 	}
+
+	// RecursionLimit limits how deeply messages may be nested.
+	// If zero, a default limit is applied.
+	RecursionLimit int
 }
 
 // Unmarshal parses the wire-format message in b and places the result in m.
 // The provided message must be mutable (e.g., a non-nil pointer to a message).
 func Unmarshal(b []byte, m Message) error {
-	_, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
+	_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
 	return err
 }
 
 // Unmarshal parses the wire-format message in b and places the result in m.
 // The provided message must be mutable (e.g., a non-nil pointer to a message).
 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
+	if o.RecursionLimit == 0 {
+		o.RecursionLimit = protowire.DefaultRecursionLimit
+	}
 	_, err := o.unmarshal(b, m.ProtoReflect())
 	return err
 }
@@ -63,6 +70,9 @@
 // This method permits fine-grained control over the unmarshaler.
 // Most users should use Unmarshal instead.
 func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
+	if o.RecursionLimit == 0 {
+		o.RecursionLimit = protowire.DefaultRecursionLimit
+	}
 	return o.unmarshal(in.Buf, in.Message)
 }
 
@@ -86,12 +96,17 @@
 			Message:  m,
 			Buf:      b,
 			Resolver: o.Resolver,
+			Depth:    o.RecursionLimit,
 		}
 		if o.DiscardUnknown {
 			in.Flags |= protoiface.UnmarshalDiscardUnknown
 		}
 		out, err = methods.Unmarshal(in)
 	} else {
+		o.RecursionLimit--
+		if o.RecursionLimit < 0 {
+			return out, errors.New("exceeded max recursion depth")
+		}
 		err = o.unmarshalMessageSlow(b, m)
 	}
 	if err != nil {
diff --git a/reflect/protoreflect/methods.go b/reflect/protoreflect/methods.go
index 6be5d16..d5d5af6 100644
--- a/reflect/protoreflect/methods.go
+++ b/reflect/protoreflect/methods.go
@@ -53,6 +53,7 @@
 			FindExtensionByName(field FullName) (ExtensionType, error)
 			FindExtensionByNumber(message FullName, field FieldNumber) (ExtensionType, error)
 		}
+		Depth int
 	}
 	unmarshalOutput = struct {
 		pragma.NoUnkeyedLiterals
diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go
index 32c04f6..44cf467 100644
--- a/runtime/protoiface/methods.go
+++ b/runtime/protoiface/methods.go
@@ -103,6 +103,7 @@
 		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
 		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
 	}
+	Depth int
 }
 
 // UnmarshalOutput is output from the Unmarshal method.