encoding/pem: fix stack overflow in Decode

Previously, Decode called decodeError, a recursive function that was
prone to stack overflows when given a large PEM file containing errors.

Credit to Juho Nurminen of Mattermost who reported the error.

Fixes CVE-2022-24675
Fixes #51853

Change-Id: Iffe768be53c8ddc0036fea0671d290f8f797692c
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1391157
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Filippo Valsorda <valsorda@google.com>
(cherry picked from commit 794ea5e828010e8b68493b2fc6d2963263195a02)
Reviewed-on: https://go-review.googlesource.com/c/go/+/399820
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/src/encoding/pem/pem.go b/src/encoding/pem/pem.go
index 146f9d0..d26e4c8 100644
--- a/src/encoding/pem/pem.go
+++ b/src/encoding/pem/pem.go
@@ -90,122 +90,96 @@
 	// pemStart begins with a newline. However, at the very beginning of
 	// the byte array, we'll accept the start string without it.
 	rest = data
-	if bytes.HasPrefix(data, pemStart[1:]) {
-		rest = rest[len(pemStart)-1 : len(data)]
-	} else if _, after, ok := bytes.Cut(data, pemStart); ok {
-		rest = after
-	} else {
-		return nil, data
-	}
-
-	typeLine, rest := getLine(rest)
-	if !bytes.HasSuffix(typeLine, pemEndOfLine) {
-		return decodeError(data, rest)
-	}
-	typeLine = typeLine[0 : len(typeLine)-len(pemEndOfLine)]
-
-	p = &Block{
-		Headers: make(map[string]string),
-		Type:    string(typeLine),
-	}
-
 	for {
-		// This loop terminates because getLine's second result is
-		// always smaller than its argument.
-		if len(rest) == 0 {
+		if bytes.HasPrefix(rest, pemStart[1:]) {
+			rest = rest[len(pemStart)-1:]
+		} else if _, after, ok := bytes.Cut(rest, pemStart); ok {
+			rest = after
+		} else {
 			return nil, data
 		}
-		line, next := getLine(rest)
 
-		key, val, ok := bytes.Cut(line, colon)
-		if !ok {
-			break
+		var typeLine []byte
+		typeLine, rest = getLine(rest)
+		if !bytes.HasSuffix(typeLine, pemEndOfLine) {
+			continue
+		}
+		typeLine = typeLine[0 : len(typeLine)-len(pemEndOfLine)]
+
+		p = &Block{
+			Headers: make(map[string]string),
+			Type:    string(typeLine),
 		}
 
-		// TODO(agl): need to cope with values that spread across lines.
-		key = bytes.TrimSpace(key)
-		val = bytes.TrimSpace(val)
-		p.Headers[string(key)] = string(val)
-		rest = next
+		for {
+			// This loop terminates because getLine's second result is
+			// always smaller than its argument.
+			if len(rest) == 0 {
+				return nil, data
+			}
+			line, next := getLine(rest)
+
+			key, val, ok := bytes.Cut(line, colon)
+			if !ok {
+				break
+			}
+
+			// TODO(agl): need to cope with values that spread across lines.
+			key = bytes.TrimSpace(key)
+			val = bytes.TrimSpace(val)
+			p.Headers[string(key)] = string(val)
+			rest = next
+		}
+
+		var endIndex, endTrailerIndex int
+
+		// If there were no headers, the END line might occur
+		// immediately, without a leading newline.
+		if len(p.Headers) == 0 && bytes.HasPrefix(rest, pemEnd[1:]) {
+			endIndex = 0
+			endTrailerIndex = len(pemEnd) - 1
+		} else {
+			endIndex = bytes.Index(rest, pemEnd)
+			endTrailerIndex = endIndex + len(pemEnd)
+		}
+
+		if endIndex < 0 {
+			continue
+		}
+
+		// After the "-----" of the ending line, there should be the same type
+		// and then a final five dashes.
+		endTrailer := rest[endTrailerIndex:]
+		endTrailerLen := len(typeLine) + len(pemEndOfLine)
+		if len(endTrailer) < endTrailerLen {
+			continue
+		}
+
+		restOfEndLine := endTrailer[endTrailerLen:]
+		endTrailer = endTrailer[:endTrailerLen]
+		if !bytes.HasPrefix(endTrailer, typeLine) ||
+			!bytes.HasSuffix(endTrailer, pemEndOfLine) {
+			continue
+		}
+
+		// The line must end with only whitespace.
+		if s, _ := getLine(restOfEndLine); len(s) != 0 {
+			continue
+		}
+
+		base64Data := removeSpacesAndTabs(rest[:endIndex])
+		p.Bytes = make([]byte, base64.StdEncoding.DecodedLen(len(base64Data)))
+		n, err := base64.StdEncoding.Decode(p.Bytes, base64Data)
+		if err != nil {
+			continue
+		}
+		p.Bytes = p.Bytes[:n]
+
+		// the -1 is because we might have only matched pemEnd without the
+		// leading newline if the PEM block was empty.
+		_, rest = getLine(rest[endIndex+len(pemEnd)-1:])
+		return p, rest
 	}
-
-	var endIndex, endTrailerIndex int
-
-	// If there were no headers, the END line might occur
-	// immediately, without a leading newline.
-	if len(p.Headers) == 0 && bytes.HasPrefix(rest, pemEnd[1:]) {
-		endIndex = 0
-		endTrailerIndex = len(pemEnd) - 1
-	} else {
-		endIndex = bytes.Index(rest, pemEnd)
-		endTrailerIndex = endIndex + len(pemEnd)
-	}
-
-	if endIndex < 0 {
-		return decodeError(data, rest)
-	}
-
-	// After the "-----" of the ending line, there should be the same type
-	// and then a final five dashes.
-	endTrailer := rest[endTrailerIndex:]
-	endTrailerLen := len(typeLine) + len(pemEndOfLine)
-	if len(endTrailer) < endTrailerLen {
-		return decodeError(data, rest)
-	}
-
-	restOfEndLine := endTrailer[endTrailerLen:]
-	endTrailer = endTrailer[:endTrailerLen]
-	if !bytes.HasPrefix(endTrailer, typeLine) ||
-		!bytes.HasSuffix(endTrailer, pemEndOfLine) {
-		return decodeError(data, rest)
-	}
-
-	// The line must end with only whitespace.
-	if s, _ := getLine(restOfEndLine); len(s) != 0 {
-		return decodeError(data, rest)
-	}
-
-	base64Data := removeSpacesAndTabs(rest[:endIndex])
-	p.Bytes = make([]byte, base64.StdEncoding.DecodedLen(len(base64Data)))
-	n, err := base64.StdEncoding.Decode(p.Bytes, base64Data)
-	if err != nil {
-		return decodeError(data, rest)
-	}
-	p.Bytes = p.Bytes[:n]
-
-	// the -1 is because we might have only matched pemEnd without the
-	// leading newline if the PEM block was empty.
-	_, rest = getLine(rest[endIndex+len(pemEnd)-1:])
-
-	return
-}
-
-func decodeError(data, rest []byte) (*Block, []byte) {
-	// If we get here then we have rejected a likely looking, but
-	// ultimately invalid PEM block. We need to start over from a new
-	// position. We have consumed the preamble line and will have consumed
-	// any lines which could be header lines. However, a valid preamble
-	// line is not a valid header line, therefore we cannot have consumed
-	// the preamble line for the any subsequent block. Thus, we will always
-	// find any valid block, no matter what bytes precede it.
-	//
-	// For example, if the input is
-	//
-	//    -----BEGIN MALFORMED BLOCK-----
-	//    junk that may look like header lines
-	//   or data lines, but no END line
-	//
-	//    -----BEGIN ACTUAL BLOCK-----
-	//    realdata
-	//    -----END ACTUAL BLOCK-----
-	//
-	// we've failed to parse using the first BEGIN line
-	// and now will try again, using the second BEGIN line.
-	p, rest := Decode(rest)
-	if p == nil {
-		rest = data
-	}
-	return p, rest
 }
 
 const pemLineLength = 64
diff --git a/src/encoding/pem/pem_test.go b/src/encoding/pem/pem_test.go
index b2b6b15..c94b5ca 100644
--- a/src/encoding/pem/pem_test.go
+++ b/src/encoding/pem/pem_test.go
@@ -107,6 +107,12 @@
 dGVzdA==
 -----ENDBAR-----`
 
+const pemMissingEndLine = `
+-----BEGIN FOO-----
+Header: 1`
+
+var pemRepeatingBegin = strings.Repeat("-----BEGIN \n", 10)
+
 var badPEMTests = []struct {
 	name  string
 	input string
@@ -131,14 +137,34 @@
 		"missing ending space",
 		pemMissingEndingSpace,
 	},
+	{
+		"repeating begin",
+		pemRepeatingBegin,
+	},
+	{
+		"missing end line",
+		pemMissingEndLine,
+	},
 }
 
 func TestBadDecode(t *testing.T) {
 	for _, test := range badPEMTests {
-		result, _ := Decode([]byte(test.input))
+		result, rest := Decode([]byte(test.input))
 		if result != nil {
 			t.Errorf("unexpected success while parsing %q", test.name)
 		}
+		if string(rest) != test.input {
+			t.Errorf("unexpected rest: %q; want = %q", rest, test.input)
+		}
+	}
+}
+
+func TestCVE202224675(t *testing.T) {
+	// Prior to CVE-2022-24675, this input would cause a stack overflow.
+	input := []byte(strings.Repeat("-----BEGIN \n", 10000000))
+	result, rest := Decode(input)
+	if result != nil || !reflect.DeepEqual(rest, input) {
+		t.Errorf("Encode of %#v decoded as %#v", input, rest)
 	}
 }