tiff: limit work when decoding malicious images

Fix two paths by which a malicious image could cause unreasonable
amounts of CPU consumption while decoding.

Avoid iterating over every horizontal pixel when decoding
a 0-height tiled image.

Limit the amount of data that will be decompressed per tile.

Thanks to Philippe Antoine (Catena cyber) for reporting this issue.

Fixes CVE-2023-29407
Fixes CVE-2023-29408
Fixes golang/go#61581
Fixes golang/go#61582

Change-Id: I8cbb26fa06843c6fe9fa99810cb1315431fa7d1d
Reviewed-on: https://go-review.googlesource.com/c/image/+/514897
Reviewed-by: Roland Shoemaker <roland@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
diff --git a/tiff/reader.go b/tiff/reader.go
index 45cc056..f31569b 100644
--- a/tiff/reader.go
+++ b/tiff/reader.go
@@ -8,13 +8,13 @@
 package tiff // import "golang.org/x/image/tiff"
 
 import (
+	"bytes"
 	"compress/zlib"
 	"encoding/binary"
 	"fmt"
 	"image"
 	"image/color"
 	"io"
-	"io/ioutil"
 	"math"
 
 	"golang.org/x/image/ccitt"
@@ -579,6 +579,11 @@
 	default:
 		return nil, UnsupportedError("color model")
 	}
+	if d.firstVal(tPhotometricInterpretation) != pRGB {
+		if len(d.features[tBitsPerSample]) != 1 {
+			return nil, UnsupportedError("extra samples")
+		}
+	}
 
 	return d, nil
 }
@@ -629,6 +634,13 @@
 		blockWidth = int(d.firstVal(tTileWidth))
 		blockHeight = int(d.firstVal(tTileLength))
 
+		// The specification says that tile widths and lengths must be a multiple of 16.
+		// We currently permit invalid sizes, but reject anything too small to limit the
+		// amount of work a malicious input can force us to perform.
+		if blockWidth < 8 || blockHeight < 8 {
+			return nil, FormatError("tile size is too small")
+		}
+
 		if blockWidth != 0 {
 			blocksAcross = (d.config.Width + blockWidth - 1) / blockWidth
 		}
@@ -681,6 +693,11 @@
 		}
 	}
 
+	if blocksAcross == 0 || blocksDown == 0 {
+		return
+	}
+	// Maximum data per pixel is 8 bytes (RGBA64).
+	blockMaxDataSize := int64(blockWidth) * int64(blockHeight) * 8
 	for i := 0; i < blocksAcross; i++ {
 		blkW := blockWidth
 		if !blockPadding && i == blocksAcross-1 && d.config.Width%blockWidth != 0 {
@@ -708,15 +725,15 @@
 				inv := d.firstVal(tPhotometricInterpretation) == pWhiteIsZero
 				order := ccittFillOrder(d.firstVal(tFillOrder))
 				r := ccitt.NewReader(io.NewSectionReader(d.r, offset, n), order, ccitt.Group3, blkW, blkH, &ccitt.Options{Invert: inv, Align: false})
-				d.buf, err = ioutil.ReadAll(r)
+				d.buf, err = readBuf(r, d.buf, blockMaxDataSize)
 			case cG4:
 				inv := d.firstVal(tPhotometricInterpretation) == pWhiteIsZero
 				order := ccittFillOrder(d.firstVal(tFillOrder))
 				r := ccitt.NewReader(io.NewSectionReader(d.r, offset, n), order, ccitt.Group4, blkW, blkH, &ccitt.Options{Invert: inv, Align: false})
-				d.buf, err = ioutil.ReadAll(r)
+				d.buf, err = readBuf(r, d.buf, blockMaxDataSize)
 			case cLZW:
 				r := lzw.NewReader(io.NewSectionReader(d.r, offset, n), lzw.MSB, 8)
-				d.buf, err = ioutil.ReadAll(r)
+				d.buf, err = readBuf(r, d.buf, blockMaxDataSize)
 				r.Close()
 			case cDeflate, cDeflateOld:
 				var r io.ReadCloser
@@ -724,7 +741,7 @@
 				if err != nil {
 					return nil, err
 				}
-				d.buf, err = ioutil.ReadAll(r)
+				d.buf, err = readBuf(r, d.buf, blockMaxDataSize)
 				r.Close()
 			case cPackBits:
 				d.buf, err = unpackBits(io.NewSectionReader(d.r, offset, n))
@@ -748,6 +765,12 @@
 	return
 }
 
+func readBuf(r io.Reader, buf []byte, lim int64) ([]byte, error) {
+	b := bytes.NewBuffer(buf[:0])
+	_, err := b.ReadFrom(io.LimitReader(r, lim))
+	return b.Bytes(), err
+}
+
 func init() {
 	image.RegisterFormat("tiff", leHeader, Decode, DecodeConfig)
 	image.RegisterFormat("tiff", beHeader, Decode, DecodeConfig)
diff --git a/tiff/reader_test.go b/tiff/reader_test.go
index f91fd94..4777fd2 100644
--- a/tiff/reader_test.go
+++ b/tiff/reader_test.go
@@ -6,13 +6,16 @@
 
 import (
 	"bytes"
+	"compress/zlib"
 	"encoding/binary"
 	"encoding/hex"
 	"errors"
+	"fmt"
 	"image"
 	"io"
 	"io/ioutil"
 	"os"
+	"sort"
 	"strings"
 	"testing"
 
@@ -414,13 +417,17 @@
 // benchmarkDecode benchmarks the decoding of an image.
 func benchmarkDecode(b *testing.B, filename string) {
 	b.Helper()
-	b.StopTimer()
 	contents, err := ioutil.ReadFile(testdataDir + filename)
 	if err != nil {
 		b.Fatal(err)
 	}
-	r := &buffer{buf: contents}
-	b.StartTimer()
+	benchmarkDecodeData(b, contents)
+}
+
+func benchmarkDecodeData(b *testing.B, data []byte) {
+	b.Helper()
+	r := &buffer{buf: data}
+	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
 		_, err := Decode(r)
 		if err != nil {
@@ -431,3 +438,148 @@
 
 func BenchmarkDecodeCompressed(b *testing.B)   { benchmarkDecode(b, "video-001.tiff") }
 func BenchmarkDecodeUncompressed(b *testing.B) { benchmarkDecode(b, "video-001-uncompressed.tiff") }
+
+func BenchmarkZeroHeightTile(b *testing.B) {
+	enc := binary.BigEndian
+	data := newTIFF(enc)
+	data = appendIFD(data, enc, map[uint16]interface{}{
+		tImageWidth:  uint32(4294967295),
+		tImageLength: uint32(0),
+		tTileWidth:   uint32(1),
+		tTileLength:  uint32(0),
+	})
+	benchmarkDecodeData(b, data)
+}
+
+func BenchmarkRepeatedOversizedTileData(b *testing.B) {
+	const (
+		imageWidth  = 256
+		imageHeight = 256
+		tileWidth   = 8
+		tileLength  = 8
+		numTiles    = (imageWidth * imageHeight) / (tileWidth * tileLength)
+	)
+
+	// Create a chunk of tile data that decompresses to a large size.
+	zdata := func() []byte {
+		var zbuf bytes.Buffer
+		zw := zlib.NewWriter(&zbuf)
+		zeros := make([]byte, 1024)
+		for i := 0; i < 1<<16; i++ {
+			zw.Write(zeros)
+		}
+		zw.Close()
+		return zbuf.Bytes()
+	}()
+
+	enc := binary.BigEndian
+	data := newTIFF(enc)
+
+	zoff := len(data)
+	data = append(data, zdata...)
+
+	// Each tile refers to the same compressed data chunk.
+	var tileoffs []uint32
+	var tilesizes []uint32
+	for i := 0; i < numTiles; i++ {
+		tileoffs = append(tileoffs, uint32(zoff))
+		tilesizes = append(tilesizes, uint32(len(zdata)))
+	}
+
+	data = appendIFD(data, enc, map[uint16]interface{}{
+		tImageWidth:                uint32(imageWidth),
+		tImageLength:               uint32(imageHeight),
+		tTileWidth:                 uint32(tileWidth),
+		tTileLength:                uint32(tileLength),
+		tTileOffsets:               tileoffs,
+		tTileByteCounts:            tilesizes,
+		tCompression:               uint16(cDeflate),
+		tBitsPerSample:             []uint16{16, 16, 16},
+		tPhotometricInterpretation: uint16(pRGB),
+	})
+	benchmarkDecodeData(b, data)
+}
+
+type byteOrder interface {
+	binary.ByteOrder
+	binary.AppendByteOrder
+}
+
+// newTIFF returns the TIFF header.
+func newTIFF(enc byteOrder) []byte {
+	b := []byte{0, 0, 0, 42, 0, 0, 0, 0}
+	switch enc.Uint16([]byte{1, 0}) {
+	case 0x1:
+		b[0], b[1] = 'I', 'I'
+	case 0x100:
+		b[0], b[1] = 'M', 'M'
+	default:
+		panic("odd byte order")
+	}
+	return b
+}
+
+// appendIFD appends an IFD to the TIFF in b,
+// updating the IFD location in the header.
+func appendIFD(b []byte, enc byteOrder, entries map[uint16]interface{}) []byte {
+	var tags []uint16
+	for tag := range entries {
+		tags = append(tags, tag)
+	}
+	sort.Slice(tags, func(i, j int) bool {
+		return tags[i] < tags[j]
+	})
+
+	var ifd []byte
+	for _, tag := range tags {
+		ifd = enc.AppendUint16(ifd, tag)
+		switch v := entries[tag].(type) {
+		case uint16:
+			ifd = enc.AppendUint16(ifd, dtShort)
+			ifd = enc.AppendUint32(ifd, 1)
+			ifd = enc.AppendUint16(ifd, v)
+			ifd = enc.AppendUint16(ifd, v)
+		case uint32:
+			ifd = enc.AppendUint16(ifd, dtLong)
+			ifd = enc.AppendUint32(ifd, 1)
+			ifd = enc.AppendUint32(ifd, v)
+		case []uint16:
+			ifd = enc.AppendUint16(ifd, dtShort)
+			ifd = enc.AppendUint32(ifd, uint32(len(v)))
+			switch len(v) {
+			case 0:
+				ifd = enc.AppendUint32(ifd, 0)
+			case 1:
+				ifd = enc.AppendUint16(ifd, v[0])
+				ifd = enc.AppendUint16(ifd, v[1])
+			default:
+				ifd = enc.AppendUint32(ifd, uint32(len(b)))
+				for _, e := range v {
+					b = enc.AppendUint16(b, e)
+				}
+			}
+		case []uint32:
+			ifd = enc.AppendUint16(ifd, dtLong)
+			ifd = enc.AppendUint32(ifd, uint32(len(v)))
+			switch len(v) {
+			case 0:
+				ifd = enc.AppendUint32(ifd, 0)
+			case 1:
+				ifd = enc.AppendUint32(ifd, v[0])
+			default:
+				ifd = enc.AppendUint32(ifd, uint32(len(b)))
+				for _, e := range v {
+					b = enc.AppendUint32(b, e)
+				}
+			}
+		default:
+			panic(fmt.Errorf("unhandled type %T", v))
+		}
+	}
+
+	enc.PutUint32(b[4:8], uint32(len(b)))
+	b = enc.AppendUint16(b, uint16(len(entries)))
+	b = append(b, ifd...)
+	b = enc.AppendUint32(b, 0)
+	return b
+}