encoding/xml: limit depth of nesting in unmarshal

Prevent exhausting the stack limit when unmarshalling extremely deeply
nested structures into nested types.

Fixes #53611
Fixes CVE-2022-30633

Change-Id: Ic6c5d41674c93cfc9a316135a408db9156d39c59
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1421319
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Reviewed-on: https://go-review.googlesource.com/c/go/+/417061
Run-TryBot: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/src/encoding/xml/read.go b/src/encoding/xml/read.go
index 2575912..0161306 100644
--- a/src/encoding/xml/read.go
+++ b/src/encoding/xml/read.go
@@ -152,7 +152,7 @@
 	if val.IsNil() {
 		return errors.New("nil pointer passed to Unmarshal")
 	}
-	return d.unmarshal(val.Elem(), start)
+	return d.unmarshal(val.Elem(), start, 0)
 }
 
 // An UnmarshalError represents an error in the unmarshaling process.
@@ -308,8 +308,15 @@
 	textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
 )
 
+const maxUnmarshalDepth = 10000
+
+var errExeceededMaxUnmarshalDepth = errors.New("exceeded max depth")
+
 // Unmarshal a single XML element into val.
-func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
+func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) error {
+	if depth >= maxUnmarshalDepth {
+		return errExeceededMaxUnmarshalDepth
+	}
 	// Find start element if we need it.
 	if start == nil {
 		for {
@@ -402,7 +409,7 @@
 		v.Set(reflect.Append(val, reflect.Zero(v.Type().Elem())))
 
 		// Recur to read element into slice.
-		if err := d.unmarshal(v.Index(n), start); err != nil {
+		if err := d.unmarshal(v.Index(n), start, depth+1); err != nil {
 			v.SetLen(n)
 			return err
 		}
@@ -525,13 +532,15 @@
 		case StartElement:
 			consumed := false
 			if sv.IsValid() {
-				consumed, err = d.unmarshalPath(tinfo, sv, nil, &t)
+				// unmarshalPath can call unmarshal, so we need to pass the depth through so that
+				// we can continue to enforce the maximum recusion limit.
+				consumed, err = d.unmarshalPath(tinfo, sv, nil, &t, depth)
 				if err != nil {
 					return err
 				}
 				if !consumed && saveAny.IsValid() {
 					consumed = true
-					if err := d.unmarshal(saveAny, &t); err != nil {
+					if err := d.unmarshal(saveAny, &t, depth+1); err != nil {
 						return err
 					}
 				}
@@ -676,7 +685,7 @@
 // The consumed result tells whether XML elements have been consumed
 // from the Decoder until start's matching end element, or if it's
 // still untouched because start is uninteresting for sv's fields.
-func (d *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement) (consumed bool, err error) {
+func (d *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement, depth int) (consumed bool, err error) {
 	recurse := false
 Loop:
 	for i := range tinfo.fields {
@@ -691,7 +700,7 @@
 		}
 		if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
 			// It's a perfect match, unmarshal the field.
-			return true, d.unmarshal(finfo.value(sv, initNilPointers), start)
+			return true, d.unmarshal(finfo.value(sv, initNilPointers), start, depth+1)
 		}
 		if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
 			// It's a prefix for the field. Break and recurse
@@ -720,7 +729,9 @@
 		}
 		switch t := tok.(type) {
 		case StartElement:
-			consumed2, err := d.unmarshalPath(tinfo, sv, parents, &t)
+			// the recursion depth of unmarshalPath is limited to the path length specified
+			// by the struct field tag, so we don't increment the depth here.
+			consumed2, err := d.unmarshalPath(tinfo, sv, parents, &t, depth)
 			if err != nil {
 				return true, err
 			}
diff --git a/src/encoding/xml/read_test.go b/src/encoding/xml/read_test.go
index 6ef55de..42059db 100644
--- a/src/encoding/xml/read_test.go
+++ b/src/encoding/xml/read_test.go
@@ -5,6 +5,8 @@
 package xml
 
 import (
+	"bytes"
+	"errors"
 	"io"
 	"reflect"
 	"strings"
@@ -1094,3 +1096,16 @@
 	}
 
 }
+
+func TestCVE202228131(t *testing.T) {
+	type nested struct {
+		Parent *nested `xml:",any"`
+	}
+	var n nested
+	err := Unmarshal(bytes.Repeat([]byte("<a>"), maxUnmarshalDepth+1), &n)
+	if err == nil {
+		t.Fatal("Unmarshal did not fail")
+	} else if !errors.Is(err, errExeceededMaxUnmarshalDepth) {
+		t.Fatalf("Unmarshal unexpected error: got %q, want %q", err, errExeceededMaxUnmarshalDepth)
+	}
+}