riff: fix some short chunk data bugs.

Fixes golang/go#16236

Change-Id: I0e524054d0702a6487ff47d86aed6bf58f4ba3f2
Reviewed-on: https://go-review.googlesource.com/24638
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/riff/riff.go b/riff/riff.go
index 9b9f71d..38dc0e5 100644
--- a/riff/riff.go
+++ b/riff/riff.go
@@ -23,6 +23,7 @@
 var (
 	errMissingPaddingByte     = errors.New("riff: missing padding byte")
 	errMissingRIFFChunkHeader = errors.New("riff: missing RIFF chunk header")
+	errListSubchunkTooLong    = errors.New("riff: list subchunk too long")
 	errShortChunkData         = errors.New("riff: short chunk data")
 	errShortChunkHeader       = errors.New("riff: short chunk header")
 	errStaleReader            = errors.New("riff: stale reader")
@@ -100,13 +101,23 @@
 
 	// Drain the rest of the previous chunk.
 	if z.chunkLen != 0 {
-		_, z.err = io.Copy(ioutil.Discard, z.chunkReader)
+		want := z.chunkLen
+		var got int64
+		got, z.err = io.Copy(ioutil.Discard, z.chunkReader)
+		if z.err == nil && uint32(got) != want {
+			z.err = errShortChunkData
+		}
 		if z.err != nil {
 			return FourCC{}, 0, nil, z.err
 		}
 	}
 	z.chunkReader = nil
 	if z.padded {
+		if z.totalLen == 0 {
+			z.err = errListSubchunkTooLong
+			return FourCC{}, 0, nil, z.err
+		}
+		z.totalLen--
 		_, z.err = io.ReadFull(z.r, z.buf[:1])
 		if z.err != nil {
 			if z.err == io.EOF {
@@ -114,7 +125,6 @@
 			}
 			return FourCC{}, 0, nil, z.err
 		}
-		z.totalLen--
 	}
 
 	// We are done if we have no more data.
@@ -129,7 +139,7 @@
 		return FourCC{}, 0, nil, z.err
 	}
 	z.totalLen -= chunkHeaderSize
-	if _, err = io.ReadFull(z.r, z.buf[:chunkHeaderSize]); err != nil {
+	if _, z.err = io.ReadFull(z.r, z.buf[:chunkHeaderSize]); z.err != nil {
 		if z.err == io.EOF || z.err == io.ErrUnexpectedEOF {
 			z.err = errShortChunkHeader
 		}
@@ -137,6 +147,10 @@
 	}
 	chunkID = FourCC{z.buf[0], z.buf[1], z.buf[2], z.buf[3]}
 	z.chunkLen = u32(z.buf[4:])
+	if z.chunkLen > z.totalLen {
+		z.err = errListSubchunkTooLong
+		return FourCC{}, 0, nil, z.err
+	}
 	z.padded = z.chunkLen&1 == 1
 	z.chunkReader = &chunkReader{z}
 	return chunkID, z.chunkLen, z.chunkReader, nil
diff --git a/riff/riff_test.go b/riff/riff_test.go
new file mode 100644
index 0000000..567e938
--- /dev/null
+++ b/riff/riff_test.go
@@ -0,0 +1,69 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package riff
+
+import (
+	"bytes"
+	"testing"
+)
+
+func encodeU32(u uint32) []byte {
+	return []byte{
+		byte(u >> 0),
+		byte(u >> 8),
+		byte(u >> 16),
+		byte(u >> 24),
+	}
+}
+
+func TestShortChunks(t *testing.T) {
+	// s is a RIFF(ABCD) with allegedly 256 bytes of data (excluding the
+	// leading 8-byte "RIFF\x00\x01\x00\x00"). The first chunk of that ABCD
+	// list is an abcd chunk of length m followed by n zeroes.
+	for _, m := range []uint32{0, 8, 15, 200, 300} {
+		for _, n := range []int{0, 1, 2, 7} {
+			s := []byte("RIFF\x00\x01\x00\x00ABCDabcd")
+			s = append(s, encodeU32(m)...)
+			s = append(s, make([]byte, n)...)
+			_, r, err := NewReader(bytes.NewReader(s))
+			if err != nil {
+				t.Errorf("m=%d, n=%d: NewReader: %v", m, n, err)
+				continue
+			}
+
+			_, _, _, err0 := r.Next()
+			// The total "ABCD" list length is 256 bytes, of which the first 12
+			// bytes are "ABCDabcd" plus the 4-byte encoding of m. If the
+			// "abcd" subchunk length (m) plus those 12 bytes is greater than
+			// the total list length, we have an invalid RIFF, and we expect an
+			// errListSubchunkTooLong error.
+			if m+12 > 256 {
+				if err0 != errListSubchunkTooLong {
+					t.Errorf("m=%d, n=%d: Next #0: got %v, want %v", m, n, err0, errListSubchunkTooLong)
+				}
+				continue
+			}
+			// Otherwise, we expect a nil error.
+			if err0 != nil {
+				t.Errorf("m=%d, n=%d: Next #0: %v", m, n, err0)
+				continue
+			}
+
+			_, _, _, err1 := r.Next()
+			// If m > 0, then m > n, so that "abcd" subchunk doesn't have m
+			// bytes of data. If m == 0, then that "abcd" subchunk is OK in
+			// that it has 0 extra bytes of data, but the next subchunk (8 byte
+			// header plus body) is missing, as we only have n < 8 more bytes.
+			want := errShortChunkData
+			if m == 0 {
+				want = errShortChunkHeader
+			}
+			if err1 != want {
+				t.Errorf("m=%d, n=%d: Next #1: got %v, want %v", m, n, err1, want)
+				continue
+			}
+		}
+	}
+}