bmp: Add support for writing bitmaps with alpha channels

This also fixes the writer test to actually compare the decoded images,
so as to make sure it is decoded properly.

Implements golang/go#25945.

Change-Id: I606887baa11b7664018313cf7d5800b2dc7622cf
Reviewed-on: https://go-review.googlesource.com/120095
Reviewed-by: Nigel Tao <nigeltao@golang.org>
Run-TryBot: Nigel Tao <nigeltao@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/bmp/writer.go b/bmp/writer.go
index 6947968..f07b39d 100644
--- a/bmp/writer.go
+++ b/bmp/writer.go
@@ -49,20 +49,91 @@
 	return nil
 }
 
-func encodeRGBA(w io.Writer, pix []uint8, dx, dy, stride, step int) error {
+func encodeRGBA(w io.Writer, pix []uint8, dx, dy, stride, step int, opaque bool) error {
 	buf := make([]byte, step)
-	for y := dy - 1; y >= 0; y-- {
-		min := y*stride + 0
-		max := y*stride + dx*4
-		off := 0
-		for i := min; i < max; i += 4 {
-			buf[off+2] = pix[i+0]
-			buf[off+1] = pix[i+1]
-			buf[off+0] = pix[i+2]
-			off += 3
+	if opaque {
+		for y := dy - 1; y >= 0; y-- {
+			min := y*stride + 0
+			max := y*stride + dx*4
+			off := 0
+			for i := min; i < max; i += 4 {
+				buf[off+2] = pix[i+0]
+				buf[off+1] = pix[i+1]
+				buf[off+0] = pix[i+2]
+				off += 3
+			}
+			if _, err := w.Write(buf); err != nil {
+				return err
+			}
 		}
-		if _, err := w.Write(buf); err != nil {
-			return err
+	} else {
+		for y := dy - 1; y >= 0; y-- {
+			min := y*stride + 0
+			max := y*stride + dx*4
+			off := 0
+			for i := min; i < max; i += 4 {
+				a := uint32(pix[i+3])
+				if a == 0 {
+					buf[off+2] = 0
+					buf[off+1] = 0
+					buf[off+0] = 0
+					buf[off+3] = 0
+					off += 4
+					continue
+				} else if a == 0xff {
+					buf[off+2] = pix[i+0]
+					buf[off+1] = pix[i+1]
+					buf[off+0] = pix[i+2]
+					buf[off+3] = 0xff
+					off += 4
+					continue
+				}
+				buf[off+2] = uint8(((uint32(pix[i+0]) * 0xffff) / a) >> 8)
+				buf[off+1] = uint8(((uint32(pix[i+1]) * 0xffff) / a) >> 8)
+				buf[off+0] = uint8(((uint32(pix[i+2]) * 0xffff) / a) >> 8)
+				buf[off+3] = uint8(a)
+				off += 4
+			}
+			if _, err := w.Write(buf); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func encodeNRGBA(w io.Writer, pix []uint8, dx, dy, stride, step int, opaque bool) error {
+	buf := make([]byte, step)
+	if opaque {
+		for y := dy - 1; y >= 0; y-- {
+			min := y*stride + 0
+			max := y*stride + dx*4
+			off := 0
+			for i := min; i < max; i += 4 {
+				buf[off+2] = pix[i+0]
+				buf[off+1] = pix[i+1]
+				buf[off+0] = pix[i+2]
+				off += 3
+			}
+			if _, err := w.Write(buf); err != nil {
+				return err
+			}
+		}
+	} else {
+		for y := dy - 1; y >= 0; y-- {
+			min := y*stride + 0
+			max := y*stride + dx*4
+			off := 0
+			for i := min; i < max; i += 4 {
+				buf[off+2] = pix[i+0]
+				buf[off+1] = pix[i+1]
+				buf[off+0] = pix[i+2]
+				buf[off+3] = pix[i+3]
+				off += 4
+			}
+			if _, err := w.Write(buf); err != nil {
+				return err
+			}
 		}
 	}
 	return nil
@@ -105,6 +176,7 @@
 
 	var step int
 	var palette []byte
+	var opaque bool
 	switch m := m.(type) {
 	case *image.Gray:
 		step = (d.X + 3) &^ 3
@@ -134,6 +206,28 @@
 		h.fileSize += uint32(len(palette)) + h.imageSize
 		h.pixOffset += uint32(len(palette))
 		h.bpp = 8
+	case *image.RGBA:
+		opaque = m.Opaque()
+		if opaque {
+			step = (3*d.X + 3) &^ 3
+			h.bpp = 24
+		} else {
+			step = 4 * d.X
+			h.bpp = 32
+		}
+		h.imageSize = uint32(d.Y * step)
+		h.fileSize += h.imageSize
+	case *image.NRGBA:
+		opaque = m.Opaque()
+		if opaque {
+			step = (3*d.X + 3) &^ 3
+			h.bpp = 24
+		} else {
+			step = 4 * d.X
+			h.bpp = 32
+		}
+		h.imageSize = uint32(d.Y * step)
+		h.fileSize += h.imageSize
 	default:
 		step = (3*d.X + 3) &^ 3
 		h.imageSize = uint32(d.Y * step)
@@ -160,7 +254,9 @@
 	case *image.Paletted:
 		return encodePaletted(w, m.Pix, d.X, d.Y, m.Stride, step)
 	case *image.RGBA:
-		return encodeRGBA(w, m.Pix, d.X, d.Y, m.Stride, step)
+		return encodeRGBA(w, m.Pix, d.X, d.Y, m.Stride, step, opaque)
+	case *image.NRGBA:
+		return encodeNRGBA(w, m.Pix, d.X, d.Y, m.Stride, step, opaque)
 	}
 	return encode(w, m, step)
 }
diff --git a/bmp/writer_test.go b/bmp/writer_test.go
index 9e5a327..1643e4d 100644
--- a/bmp/writer_test.go
+++ b/bmp/writer_test.go
@@ -8,6 +8,7 @@
 	"bytes"
 	"fmt"
 	"image"
+	"image/draw"
 	"io/ioutil"
 	"os"
 	"testing"
@@ -23,24 +24,75 @@
 	return Decode(f)
 }
 
+func convertToRGBA(in image.Image) image.Image {
+	b := in.Bounds()
+	out := image.NewRGBA(b)
+	draw.Draw(out, b, in, b.Min, draw.Src)
+	return out
+}
+
+func convertToNRGBA(in image.Image) image.Image {
+	b := in.Bounds()
+	out := image.NewNRGBA(b)
+	draw.Draw(out, b, in, b.Min, draw.Src)
+	return out
+}
+
 func TestEncode(t *testing.T) {
-	img0, err := openImage("video-001.bmp")
-	if err != nil {
-		t.Fatal(err)
+	testCases := []string{
+		"video-001.bmp",
+		"yellow_rose-small.bmp",
 	}
 
-	buf := new(bytes.Buffer)
-	err = Encode(buf, img0)
-	if err != nil {
-		t.Fatal(err)
-	}
+	for _, tc := range testCases {
+		img0, err := openImage(tc)
+		if err != nil {
+			t.Errorf("%s: Open BMP: %v", tc, err)
+			continue
+		}
 
-	img1, err := Decode(buf)
-	if err != nil {
-		t.Fatal(err)
-	}
+		buf := new(bytes.Buffer)
+		err = Encode(buf, img0)
+		if err != nil {
+			t.Errorf("%s: Encode BMP: %v", tc, err)
+			continue
+		}
 
-	compare(t, img0, img1)
+		img1, err := Decode(buf)
+		if err != nil {
+			t.Errorf("%s: Decode BMP: %v", tc, err)
+			continue
+		}
+
+		err = compare(t, img0, img1)
+		if err != nil {
+			t.Errorf("%s: Compare BMP: %v", tc, err)
+			continue
+		}
+
+		buf2 := new(bytes.Buffer)
+		rgba := convertToRGBA(img0)
+		err = Encode(buf2, rgba)
+		if err != nil {
+			t.Errorf("%s: Encode pre-multiplied BMP: %v", tc, err)
+			continue
+		}
+
+		img2, err := Decode(buf2)
+		if err != nil {
+			t.Errorf("%s: Decode pre-multiplied BMP: %v", tc, err)
+			continue
+		}
+
+		// We need to do another round trip to NRGBA to compare to, since
+		// the conversion process is lossy.
+		img3 := convertToNRGBA(rgba)
+
+		err = compare(t, img3, img2)
+		if err != nil {
+			t.Errorf("%s: Compare pre-multiplied BMP: %v", tc, err)
+		}
+	}
 }
 
 // TestZeroWidthVeryLargeHeight tests that encoding and decoding a degenerate