[release-branch.go1.15-security] encoding/xml: prevent infinite loop while decoding

This change properly handles a TokenReader which
returns an EOF in the middle of an open XML
element.

Thanks to Sam Whited for reporting this.

Fixes CVE-2021-27918

Change-Id: Id02a3f3def4a1b415fa2d9a8e3b373eb6cb0f433
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1004594
Reviewed-by: Russ Cox <rsc@google.com>
Reviewed-by: Roland Shoemaker <bracewell@google.com>
Reviewed-by: Filippo Valsorda <valsorda@google.com>
(cherry picked from commit e7ce1f6746223ec7b4caa3b1ece25d9be3864710)
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1014236
diff --git a/src/encoding/xml/xml.go b/src/encoding/xml/xml.go
index adaf4da..6f9594d 100644
--- a/src/encoding/xml/xml.go
+++ b/src/encoding/xml/xml.go
@@ -271,7 +271,7 @@
 // it will return an error.
 //
 // Token implements XML name spaces as described by
-// https://www.w3.org/TR/REC-xml-names/.  Each of the
+// https://www.w3.org/TR/REC-xml-names/. Each of the
 // Name structures contained in the Token has the Space
 // set to the URL identifying its name space when known.
 // If Token encounters an unrecognized name space prefix,
@@ -285,16 +285,17 @@
 	if d.nextToken != nil {
 		t = d.nextToken
 		d.nextToken = nil
-	} else if t, err = d.rawToken(); err != nil {
-		switch {
-		case err == io.EOF && d.t != nil:
-			err = nil
-		case err == io.EOF && d.stk != nil && d.stk.kind != stkEOF:
-			err = d.syntaxError("unexpected EOF")
+	} else {
+		if t, err = d.rawToken(); t == nil && err != nil {
+			if err == io.EOF && d.stk != nil && d.stk.kind != stkEOF {
+				err = d.syntaxError("unexpected EOF")
+			}
+			return nil, err
 		}
-		return t, err
+		// We still have a token to process, so clear any
+		// errors (e.g. EOF) and proceed.
+		err = nil
 	}
-
 	if !d.Strict {
 		if t1, ok := d.autoClose(t); ok {
 			d.nextToken = t
diff --git a/src/encoding/xml/xml_test.go b/src/encoding/xml/xml_test.go
index efddca4..5672ebb 100644
--- a/src/encoding/xml/xml_test.go
+++ b/src/encoding/xml/xml_test.go
@@ -33,30 +33,90 @@
 
 func TestDecodeEOF(t *testing.T) {
 	start := StartElement{Name: Name{Local: "test"}}
-	t.Run("EarlyEOF", func(t *testing.T) {
-		d := NewTokenDecoder(&toks{earlyEOF: true, t: []Token{
-			start,
-			start.End(),
-		}})
-		err := d.Decode(&struct {
-			XMLName Name `xml:"test"`
-		}{})
-		if err != nil {
-			t.Error(err)
+	tests := []struct {
+		name   string
+		tokens []Token
+		ok     bool
+	}{
+		{
+			name: "OK",
+			tokens: []Token{
+				start,
+				start.End(),
+			},
+			ok: true,
+		},
+		{
+			name: "Malformed",
+			tokens: []Token{
+				start,
+				StartElement{Name: Name{Local: "bad"}},
+				start.End(),
+			},
+			ok: false,
+		},
+	}
+	for _, tc := range tests {
+		for _, eof := range []bool{true, false} {
+			name := fmt.Sprintf("%s/earlyEOF=%v", tc.name, eof)
+			t.Run(name, func(t *testing.T) {
+				d := NewTokenDecoder(&toks{
+					earlyEOF: eof,
+					t:        tc.tokens,
+				})
+				err := d.Decode(&struct {
+					XMLName Name `xml:"test"`
+				}{})
+				if tc.ok && err != nil {
+					t.Fatalf("d.Decode: expected nil error, got %v", err)
+				}
+				if _, ok := err.(*SyntaxError); !tc.ok && !ok {
+					t.Errorf("d.Decode: expected syntax error, got %v", err)
+				}
+			})
 		}
-	})
-	t.Run("LateEOF", func(t *testing.T) {
-		d := NewTokenDecoder(&toks{t: []Token{
-			start,
-			start.End(),
-		}})
-		err := d.Decode(&struct {
-			XMLName Name `xml:"test"`
-		}{})
-		if err != nil {
-			t.Error(err)
+	}
+}
+
+type toksNil struct {
+	returnEOF bool
+	t         []Token
+}
+
+func (t *toksNil) Token() (Token, error) {
+	if len(t.t) == 0 {
+		if !t.returnEOF {
+			// Return nil, nil before returning an EOF. It's legal, but
+			// discouraged.
+			t.returnEOF = true
+			return nil, nil
 		}
-	})
+		return nil, io.EOF
+	}
+	var tok Token
+	tok, t.t = t.t[0], t.t[1:]
+	return tok, nil
+}
+
+func TestDecodeNilToken(t *testing.T) {
+	for _, strict := range []bool{true, false} {
+		name := fmt.Sprintf("Strict=%v", strict)
+		t.Run(name, func(t *testing.T) {
+			start := StartElement{Name: Name{Local: "test"}}
+			bad := StartElement{Name: Name{Local: "bad"}}
+			d := NewTokenDecoder(&toksNil{
+				// Malformed
+				t: []Token{start, bad, start.End()},
+			})
+			d.Strict = strict
+			err := d.Decode(&struct {
+				XMLName Name `xml:"test"`
+			}{})
+			if _, ok := err.(*SyntaxError); !ok {
+				t.Errorf("d.Decode: expected syntax error, got %v", err)
+			}
+		})
+	}
 }
 
 const testInput = `