encoding/unicode: correctly handle single-byte UTF-16 inputs (and harden transform.String)

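If a single byte is passed to a UTF-16 decoder
with atEOF set, it should not request more input
with ErrShortSrc, since none will arrive. Instead it
returns ErrMissingBOM when a BOM is required, and
otherwise decodes the stray byte as U+FFFD. Also
harden transform.String so that it does not enter an
infinite loop if a Transformer returns ErrShortSrc
with atEOF set to true.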

Fixes #39491
Fixes CVE-2020-14040

Change-Id: If8d2a9bca4eb9b4270c98a4967d356082043e17e
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/768667
Reviewed-by: Filippo Valsorda <valsorda@google.com>
Reviewed-on: https://go-review.googlesource.com/c/text/+/238238
Run-TryBot: Katie Hockman <katie@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
diff --git a/encoding/unicode/unicode.go b/encoding/unicode/unicode.go
index f2e576d..dd99ad1 100644
--- a/encoding/unicode/unicode.go
+++ b/encoding/unicode/unicode.go
@@ -368,16 +368,13 @@
 }
 
 func (u *utf16Decoder) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
+	if len(src) < 2 && atEOF && u.current.bomPolicy&requireBOM != 0 { // under two bytes left and no more input: the required BOM is missing
+		return 0, 0, ErrMissingBOM
+	}
 	if len(src) == 0 {
-		if atEOF && u.current.bomPolicy&requireBOM != 0 {
-			return 0, 0, ErrMissingBOM
-		}
 		return 0, 0, nil
 	}
-	if u.current.bomPolicy&acceptBOM != 0 {
-		if len(src) < 2 {
-			return 0, 0, transform.ErrShortSrc
-		}
+	if len(src) >= 2 && u.current.bomPolicy&acceptBOM != 0 { // a BOM is two bytes; with fewer, skip BOM detection on this call
 		switch {
 		case src[0] == 0xfe && src[1] == 0xff:
 			u.current.endianness = BigEndian
diff --git a/encoding/unicode/unicode_test.go b/encoding/unicode/unicode_test.go
index 02520c9..9611ba5 100644
--- a/encoding/unicode/unicode_test.go
+++ b/encoding/unicode/unicode_test.go
@@ -114,6 +114,19 @@
 		err:     ErrMissingBOM,
 		t:       utf16BEEB.NewDecoder(),
 	}, {
+		desc:    "utf-16 dec: Fail on single byte missing BOM when required",
+		src:     "\x00", // one byte at EOF: too short to hold the required BOM
+		sizeDst: 4,
+		t:       utf16BEEB.NewDecoder(),
+		err:     ErrMissingBOM,
+	}, {
+		desc:    "utf-16 dec: Fail on short src missing BOM when required",
+		src:     "\x00",
+		notEOF:  true, // more input may follow, so ErrShortSrc rather than ErrMissingBOM
+		sizeDst: 4,
+		t:       utf16BEEB.NewDecoder(),
+		err:     transform.ErrShortSrc,
+	}, {
 		desc:    "utf-16 dec: SHOULD interpret text as big-endian when BOM not present (RFC 2781:4.3)",
 		src:     "\xD8\x08\xDF\x45\x00\x3D\x00\x52\x00\x61",
 		sizeDst: 100,
@@ -121,6 +134,20 @@
 		nSrc:    10,
 		t:       utf16BEUB.NewDecoder(),
 	}, {
+		desc:    "utf-16 dec: incorrect UTF-16: odd bytes",
+		src:     "\x00", // a lone trailing byte at EOF
+		sizeDst: 100,
+		want:    "\uFFFD",
+		nSrc:    1,
+		t:       utf16BEUB.NewDecoder(),
+	}, {
+		desc:    "utf-16 dec: Fail on incorrect UTF-16: short source odd bytes",
+		src:     "\x00",
+		notEOF:  true, // same lone byte, but more input may follow
+		sizeDst: 100,
+		t:       utf16BEUB.NewDecoder(),
+		err:     transform.ErrShortSrc,
+	}, {
 		// This is an error according to RFC 2781. But errors in RFC 2781 are
 		// open to interpretations, so I guess this is fine.
 		desc:    "utf-16le dec: incorrect BOM is an error (RFC 2781:4.1)",
@@ -273,16 +300,23 @@
 		t:       utf16LEUB.NewDecoder(),
 	}}
 	for i, tc := range testCases {
-		b := make([]byte, tc.sizeDst)
-		nDst, nSrc, err := tc.t.Transform(b, []byte(tc.src), !tc.notEOF)
-		if err != tc.err {
-			t.Errorf("%d:%s: error was %v; want %v", i, tc.desc, err, tc.err)
-		}
-		if got := string(b[:nDst]); got != tc.want {
-			t.Errorf("%d:%s: result was %q: want %q", i, tc.desc, got, tc.want)
-		}
-		if nSrc != tc.nSrc {
-			t.Errorf("%d:%s: nSrc was %d; want %d", i, tc.desc, nSrc, tc.nSrc)
+		for j := 0; j < 2; j++ { // run 0, plus run 1 only if run 0 returned an error
+			b := make([]byte, tc.sizeDst)
+			nDst, nSrc, err := tc.t.Transform(b, []byte(tc.src), !tc.notEOF)
+			if err != tc.err {
+				t.Errorf("%d:%s (run %d): error was %v; want %v", i, tc.desc, j, err, tc.err)
+			}
+			if got := string(b[:nDst]); got != tc.want {
+				t.Errorf("%d:%s (run %d): result was %q; want %q", i, tc.desc, j, got, tc.want)
+			}
+			if nSrc != tc.nSrc {
+				t.Errorf("%d:%s (run %d): nSrc was %d; want %d", i, tc.desc, j, nSrc, tc.nSrc)
+			}
+			// Since Transform is stateful, run failures again
+			// to ensure that the same error occurs a second time.
+			if err == nil {
+				break
+			}
 		}
 	}
 }
diff --git a/transform/transform.go b/transform/transform.go
index 520b9ad..48ec64b 100644
--- a/transform/transform.go
+++ b/transform/transform.go
@@ -648,7 +648,8 @@
 	// Transform the remaining input, growing dst and src buffers as necessary.
 	for {
 		n := copy(src, s[pSrc:])
-		nDst, nSrc, err := t.Transform(dst[pDst:], src[:n], pSrc+n == len(s))
+		atEOF := pSrc+n == len(s) // src[:n] holds all of the remaining input
+		nDst, nSrc, err := t.Transform(dst[pDst:], src[:n], atEOF)
 		pDst += nDst
 		pSrc += nSrc

@@ -659,6 +660,9 @@
 				dst = grow(dst, pDst)
 			}
 		} else if err == ErrShortSrc {
+			if atEOF { // no more input can arrive, so retrying would never make progress
+				return string(dst[:pDst]), pSrc, err
+			}
 			if nSrc == 0 {
 				src = grow(src, 0)
 			}
diff --git a/transform/transform_test.go b/transform/transform_test.go
index 771633d..273abfa 100644
--- a/transform/transform_test.go
+++ b/transform/transform_test.go
@@ -1315,3 +1315,35 @@
 	aaa = strings.Repeat("a", 4096)
 	AAA = strings.Repeat("A", 4096)
 )
+
+// badTransformer is a misbehaving Transformer: it never consumes or
+// produces bytes and returns ErrShortSrc unconditionally, even when atEOF
+// is true and no further input can arrive.
+type badTransformer struct{}
+
+// Transform always reports ErrShortSrc without touching dst or src.
+func (badTransformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
+	return 0, 0, ErrShortSrc
+}
+
+// Reset is a no-op; badTransformer holds no state.
+func (badTransformer) Reset() {}
+
+// TestBadTransformer checks that String, Bytes and the Reader returned by
+// NewReader all return ErrShortSrc when the Transformer keeps asking for
+// more input at EOF. Before this change (golang/go#39491), String looped
+// forever in this situation.
+func TestBadTransformer(t *testing.T) {
+	bt := badTransformer{}
+	if _, _, err := String(bt, "aaa"); err != ErrShortSrc {
+		t.Errorf("String: got error %v; want %v", err, ErrShortSrc)
+	}
+	if _, _, err := Bytes(bt, []byte("aaa")); err != ErrShortSrc {
+		t.Errorf("Bytes: got error %v; want %v", err, ErrShortSrc)
+	}
+	r := NewReader(bytes.NewReader([]byte("aaa")), bt)
+	buf := make([]byte, 8)
+	if _, err := r.Read(buf); err != ErrShortSrc {
+		t.Errorf("Reader.Read: got error %v; want %v", err, ErrShortSrc)
+	}
+}