blake2b,blake2s: implement BinaryMarshaler, BinaryUnmarshaler

The marshal method allows the hash's internal state to be serialized and
unmarshaled at a later time, without having the re-write the entire stream
of data that was already written to the hash.

Fixes golang/go#24548

Change-Id: I82358c34181fc815f85d5d1509fb2fe0e62e40bd
Reviewed-on: https://go-review.googlesource.com/103241
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
diff --git a/blake2b/blake2b.go b/blake2b/blake2b.go
index 6dedb89..58ea875 100644
--- a/blake2b/blake2b.go
+++ b/blake2b/blake2b.go
@@ -92,6 +92,8 @@
 // values equal or greater than:
 // - 32 if BLAKE2b is used as a hash function (The key is zero bytes long).
 // - 16 if BLAKE2b is used as a MAC function (The key is at least 16 bytes long).
+// When the key is nil, the returned hash.Hash implements BinaryMarshaler
+// and BinaryUnmarshaler for state (de)serialization as documented by hash.Hash.
 func New(size int, key []byte) (hash.Hash, error) { return newDigest(size, key) }
 
 func newDigest(hashSize int, key []byte) (*digest, error) {
@@ -150,6 +152,50 @@
 	keyLen int
 }
 
+const (
+	magic         = "b2b"
+	marshaledSize = len(magic) + 8*8 + 2*8 + 1 + BlockSize + 1
+)
+
+func (d *digest) MarshalBinary() ([]byte, error) {
+	if d.keyLen != 0 {
+		return nil, errors.New("crypto/blake2b: cannot marshal MACs")
+	}
+	b := make([]byte, 0, marshaledSize)
+	b = append(b, magic...)
+	for i := 0; i < 8; i++ {
+		b = appendUint64(b, d.h[i])
+	}
+	b = appendUint64(b, d.c[0])
+	b = appendUint64(b, d.c[1])
+	// Maximum value for size is 64
+	b = append(b, byte(d.size))
+	b = append(b, d.block[:]...)
+	b = append(b, byte(d.offset))
+	return b, nil
+}
+
+func (d *digest) UnmarshalBinary(b []byte) error {
+	if len(b) < len(magic) || string(b[:len(magic)]) != magic {
+		return errors.New("crypto/blake2b: invalid hash state identifier")
+	}
+	if len(b) != marshaledSize {
+		return errors.New("crypto/blake2b: invalid hash state size")
+	}
+	b = b[len(magic):]
+	for i := 0; i < 8; i++ {
+		b, d.h[i] = consumeUint64(b)
+	}
+	b, d.c[0] = consumeUint64(b)
+	b, d.c[1] = consumeUint64(b)
+	d.size = int(b[0])
+	b = b[1:]
+	copy(d.block[:], b[:BlockSize])
+	b = b[BlockSize:]
+	d.offset = int(b[0])
+	return nil
+}
+
 func (d *digest) BlockSize() int { return BlockSize }
 
 func (d *digest) Size() int { return d.size }
@@ -219,3 +265,25 @@
 		binary.LittleEndian.PutUint64(hash[8*i:], v)
 	}
 }
+
+func appendUint64(b []byte, x uint64) []byte {
+	var a [8]byte
+	binary.BigEndian.PutUint64(a[:], x)
+	return append(b, a[:]...)
+}
+
+func appendUint32(b []byte, x uint32) []byte {
+	var a [4]byte
+	binary.BigEndian.PutUint32(a[:], x)
+	return append(b, a[:]...)
+}
+
+func consumeUint64(b []byte) ([]byte, uint64) {
+	x := binary.BigEndian.Uint64(b)
+	return b[8:], x
+}
+
+func consumeUint32(b []byte) ([]byte, uint32) {
+	x := binary.BigEndian.Uint32(b)
+	return b[4:], x
+}
diff --git a/blake2b/blake2b_test.go b/blake2b/blake2b_test.go
index 5d68bbf..723327a 100644
--- a/blake2b/blake2b_test.go
+++ b/blake2b/blake2b_test.go
@@ -6,6 +6,7 @@
 
 import (
 	"bytes"
+	"encoding"
 	"encoding/hex"
 	"fmt"
 	"hash"
@@ -69,6 +70,54 @@
 	testHashes2X(t)
 }
 
+func TestMarshal(t *testing.T) {
+	input := make([]byte, 255)
+	for i := range input {
+		input[i] = byte(i)
+	}
+	for _, size := range []int{Size, Size256, Size384, 12, 25, 63} {
+		for i := 0; i < 256; i++ {
+			h, err := New(size, nil)
+			if err != nil {
+				t.Fatalf("size=%d, len(input)=%d: error from New(%v, nil): %v", size, i, size, err)
+			}
+			h2, err := New(size, nil)
+			if err != nil {
+				t.Fatalf("size=%d, len(input)=%d: error from New(%v, nil): %v", size, i, size, err)
+			}
+
+			h.Write(input[:i/2])
+			halfstate, err := h.(encoding.BinaryMarshaler).MarshalBinary()
+			if err != nil {
+				t.Fatalf("size=%d, len(input)=%d: could not marshal: %v", size, i, err)
+			}
+			err = h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(halfstate)
+			if err != nil {
+				t.Fatalf("size=%d, len(input)=%d: could not unmarshal: %v", size, i, err)
+			}
+
+			h.Write(input[i/2 : i])
+			sum := h.Sum(nil)
+			h2.Write(input[i/2 : i])
+			sum2 := h2.Sum(nil)
+
+			if !bytes.Equal(sum, sum2) {
+				t.Fatalf("size=%d, len(input)=%d: results do not match; sum = %v, sum2 = %v", size, i, sum, sum2)
+			}
+
+			h3, err := New(size, nil)
+			if err != nil {
+				t.Fatalf("size=%d, len(input)=%d: error from New(%v, nil): %v", size, i, size, err)
+			}
+			h3.Write(input[:i])
+			sum3 := h3.Sum(nil)
+			if !bytes.Equal(sum, sum3) {
+				t.Fatalf("size=%d, len(input)=%d: sum = %v, want %v", size, i, sum, sum3)
+			}
+		}
+	}
+}
+
 func testHashes(t *testing.T) {
 	key, _ := hex.DecodeString("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f")
 
diff --git a/blake2s/blake2s.go b/blake2s/blake2s.go
index ae0dc92..5fb4a9e 100644
--- a/blake2s/blake2s.go
+++ b/blake2s/blake2s.go
@@ -49,6 +49,8 @@
 
 // New256 returns a new hash.Hash computing the BLAKE2s-256 checksum. A non-nil
 // key turns the hash into a MAC. The key must between zero and 32 bytes long.
+// When the key is nil, the returned hash.Hash implements BinaryMarshaler
+// and BinaryUnmarshaler for state (de)serialization as documented by hash.Hash.
 func New256(key []byte) (hash.Hash, error) { return newDigest(Size, key) }
 
 // New128 returns a new hash.Hash computing the BLAKE2s-128 checksum given a
@@ -120,6 +122,50 @@
 	keyLen int
 }
 
+const (
+	magic         = "b2s"
+	marshaledSize = len(magic) + 8*4 + 2*4 + 1 + BlockSize + 1
+)
+
+func (d *digest) MarshalBinary() ([]byte, error) {
+	if d.keyLen != 0 {
+		return nil, errors.New("crypto/blake2s: cannot marshal MACs")
+	}
+	b := make([]byte, 0, marshaledSize)
+	b = append(b, magic...)
+	for i := 0; i < 8; i++ {
+		b = appendUint32(b, d.h[i])
+	}
+	b = appendUint32(b, d.c[0])
+	b = appendUint32(b, d.c[1])
+	// Maximum value for size is 32
+	b = append(b, byte(d.size))
+	b = append(b, d.block[:]...)
+	b = append(b, byte(d.offset))
+	return b, nil
+}
+
+func (d *digest) UnmarshalBinary(b []byte) error {
+	if len(b) < len(magic) || string(b[:len(magic)]) != magic {
+		return errors.New("crypto/blake2s: invalid hash state identifier")
+	}
+	if len(b) != marshaledSize {
+		return errors.New("crypto/blake2s: invalid hash state size")
+	}
+	b = b[len(magic):]
+	for i := 0; i < 8; i++ {
+		b, d.h[i] = consumeUint32(b)
+	}
+	b, d.c[0] = consumeUint32(b)
+	b, d.c[1] = consumeUint32(b)
+	d.size = int(b[0])
+	b = b[1:]
+	copy(d.block[:], b[:BlockSize])
+	b = b[BlockSize:]
+	d.offset = int(b[0])
+	return nil
+}
+
 func (d *digest) BlockSize() int { return BlockSize }
 
 func (d *digest) Size() int { return d.size }
@@ -185,3 +231,14 @@
 		binary.LittleEndian.PutUint32(hash[4*i:], v)
 	}
 }
+
+func appendUint32(b []byte, x uint32) []byte {
+	var a [4]byte
+	binary.BigEndian.PutUint32(a[:], x)
+	return append(b, a[:]...)
+}
+
+func consumeUint32(b []byte) ([]byte, uint32) {
+	x := binary.BigEndian.Uint32(b)
+	return b[4:], x
+}
diff --git a/blake2s/blake2s_test.go b/blake2s/blake2s_test.go
index cfeb18b..cde79fb 100644
--- a/blake2s/blake2s_test.go
+++ b/blake2s/blake2s_test.go
@@ -5,6 +5,8 @@
 package blake2s
 
 import (
+	"bytes"
+	"encoding"
 	"encoding/hex"
 	"fmt"
 	"testing"
@@ -64,6 +66,52 @@
 	testHashes2X(t)
 }
 
+func TestMarshal(t *testing.T) {
+	input := make([]byte, 255)
+	for i := range input {
+		input[i] = byte(i)
+	}
+	for i := 0; i < 256; i++ {
+		h, err := New256(nil)
+		if err != nil {
+			t.Fatalf("len(input)=%d: error from New256(nil): %v", i, err)
+		}
+		h2, err := New256(nil)
+		if err != nil {
+			t.Fatalf("len(input)=%d: error from New256(nil): %v", i, err)
+		}
+
+		h.Write(input[:i/2])
+		halfstate, err := h.(encoding.BinaryMarshaler).MarshalBinary()
+		if err != nil {
+			t.Fatalf("len(input)=%d: could not marshal: %v", i, err)
+		}
+		err = h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(halfstate)
+		if err != nil {
+			t.Fatalf("len(input)=%d: could not unmarshal: %v", i, err)
+		}
+
+		h.Write(input[i/2 : i])
+		sum := h.Sum(nil)
+		h2.Write(input[i/2 : i])
+		sum2 := h2.Sum(nil)
+
+		if !bytes.Equal(sum, sum2) {
+			t.Fatalf("len(input)=%d: results do not match; sum = %v, sum2 = %v", i, sum, sum2)
+		}
+
+		h3, err := New256(nil)
+		if err != nil {
+			t.Fatalf("len(input)=%d: error from New256(nil): %v", i, err)
+		}
+		h3.Write(input[:i])
+		sum3 := h3.Sum(nil)
+		if !bytes.Equal(sum, sum3) {
+			t.Fatalf("len(input)=%d: sum = %v, want %v", i, sum, sum3)
+		}
+	}
+}
+
 func testHashes(t *testing.T) {
 	key, _ := hex.DecodeString("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f")