blake2b,blake2s: implement BinaryMarshaler, BinaryUnmarshaler
The marshal method allows the hash's internal state to be serialized and
unmarshaled at a later time, without having the re-write the entire stream
of data that was already written to the hash.
Fixes golang/go#24548
Change-Id: I82358c34181fc815f85d5d1509fb2fe0e62e40bd
Reviewed-on: https://go-review.googlesource.com/103241
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
diff --git a/blake2b/blake2b.go b/blake2b/blake2b.go
index 6dedb89..58ea875 100644
--- a/blake2b/blake2b.go
+++ b/blake2b/blake2b.go
@@ -92,6 +92,8 @@
// values equal or greater than:
// - 32 if BLAKE2b is used as a hash function (The key is zero bytes long).
// - 16 if BLAKE2b is used as a MAC function (The key is at least 16 bytes long).
+// When the key is nil, the returned hash.Hash implements BinaryMarshaler
+// and BinaryUnmarshaler for state (de)serialization as documented by hash.Hash.
func New(size int, key []byte) (hash.Hash, error) { return newDigest(size, key) }
func newDigest(hashSize int, key []byte) (*digest, error) {
@@ -150,6 +152,50 @@
keyLen int
}
+const (
+ magic = "b2b"
+ marshaledSize = len(magic) + 8*8 + 2*8 + 1 + BlockSize + 1
+)
+
+func (d *digest) MarshalBinary() ([]byte, error) {
+ if d.keyLen != 0 {
+ return nil, errors.New("crypto/blake2b: cannot marshal MACs")
+ }
+ b := make([]byte, 0, marshaledSize)
+ b = append(b, magic...)
+ for i := 0; i < 8; i++ {
+ b = appendUint64(b, d.h[i])
+ }
+ b = appendUint64(b, d.c[0])
+ b = appendUint64(b, d.c[1])
+ // Maximum value for size is 64
+ b = append(b, byte(d.size))
+ b = append(b, d.block[:]...)
+ b = append(b, byte(d.offset))
+ return b, nil
+}
+
+func (d *digest) UnmarshalBinary(b []byte) error {
+ if len(b) < len(magic) || string(b[:len(magic)]) != magic {
+ return errors.New("crypto/blake2b: invalid hash state identifier")
+ }
+ if len(b) != marshaledSize {
+ return errors.New("crypto/blake2b: invalid hash state size")
+ }
+ b = b[len(magic):]
+ for i := 0; i < 8; i++ {
+ b, d.h[i] = consumeUint64(b)
+ }
+ b, d.c[0] = consumeUint64(b)
+ b, d.c[1] = consumeUint64(b)
+ d.size = int(b[0])
+ b = b[1:]
+ copy(d.block[:], b[:BlockSize])
+ b = b[BlockSize:]
+ d.offset = int(b[0])
+ return nil
+}
+
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Size() int { return d.size }
@@ -219,3 +265,25 @@
binary.LittleEndian.PutUint64(hash[8*i:], v)
}
}
+
+func appendUint64(b []byte, x uint64) []byte {
+ var a [8]byte
+ binary.BigEndian.PutUint64(a[:], x)
+ return append(b, a[:]...)
+}
+
+func appendUint32(b []byte, x uint32) []byte {
+ var a [4]byte
+ binary.BigEndian.PutUint32(a[:], x)
+ return append(b, a[:]...)
+}
+
+func consumeUint64(b []byte) ([]byte, uint64) {
+ x := binary.BigEndian.Uint64(b)
+ return b[8:], x
+}
+
+func consumeUint32(b []byte) ([]byte, uint32) {
+ x := binary.BigEndian.Uint32(b)
+ return b[4:], x
+}
diff --git a/blake2b/blake2b_test.go b/blake2b/blake2b_test.go
index 5d68bbf..723327a 100644
--- a/blake2b/blake2b_test.go
+++ b/blake2b/blake2b_test.go
@@ -6,6 +6,7 @@
import (
"bytes"
+ "encoding"
"encoding/hex"
"fmt"
"hash"
@@ -69,6 +70,54 @@
testHashes2X(t)
}
+func TestMarshal(t *testing.T) {
+ input := make([]byte, 255)
+ for i := range input {
+ input[i] = byte(i)
+ }
+ for _, size := range []int{Size, Size256, Size384, 12, 25, 63} {
+ for i := 0; i < 256; i++ {
+ h, err := New(size, nil)
+ if err != nil {
+ t.Fatalf("size=%d, len(input)=%d: error from New(%v, nil): %v", size, i, size, err)
+ }
+ h2, err := New(size, nil)
+ if err != nil {
+ t.Fatalf("size=%d, len(input)=%d: error from New(%v, nil): %v", size, i, size, err)
+ }
+
+ h.Write(input[:i/2])
+ halfstate, err := h.(encoding.BinaryMarshaler).MarshalBinary()
+ if err != nil {
+ t.Fatalf("size=%d, len(input)=%d: could not marshal: %v", size, i, err)
+ }
+ err = h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(halfstate)
+ if err != nil {
+ t.Fatalf("size=%d, len(input)=%d: could not unmarshal: %v", size, i, err)
+ }
+
+ h.Write(input[i/2 : i])
+ sum := h.Sum(nil)
+ h2.Write(input[i/2 : i])
+ sum2 := h2.Sum(nil)
+
+ if !bytes.Equal(sum, sum2) {
+ t.Fatalf("size=%d, len(input)=%d: results do not match; sum = %v, sum2 = %v", size, i, sum, sum2)
+ }
+
+ h3, err := New(size, nil)
+ if err != nil {
+ t.Fatalf("size=%d, len(input)=%d: error from New(%v, nil): %v", size, i, size, err)
+ }
+ h3.Write(input[:i])
+ sum3 := h3.Sum(nil)
+ if !bytes.Equal(sum, sum3) {
+ t.Fatalf("size=%d, len(input)=%d: sum = %v, want %v", size, i, sum, sum3)
+ }
+ }
+ }
+}
+
func testHashes(t *testing.T) {
key, _ := hex.DecodeString("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f")
diff --git a/blake2s/blake2s.go b/blake2s/blake2s.go
index ae0dc92..5fb4a9e 100644
--- a/blake2s/blake2s.go
+++ b/blake2s/blake2s.go
@@ -49,6 +49,8 @@
// New256 returns a new hash.Hash computing the BLAKE2s-256 checksum. A non-nil
// key turns the hash into a MAC. The key must between zero and 32 bytes long.
+// When the key is nil, the returned hash.Hash implements BinaryMarshaler
+// and BinaryUnmarshaler for state (de)serialization as documented by hash.Hash.
func New256(key []byte) (hash.Hash, error) { return newDigest(Size, key) }
// New128 returns a new hash.Hash computing the BLAKE2s-128 checksum given a
@@ -120,6 +122,50 @@
keyLen int
}
+const (
+ magic = "b2s"
+ marshaledSize = len(magic) + 8*4 + 2*4 + 1 + BlockSize + 1
+)
+
+func (d *digest) MarshalBinary() ([]byte, error) {
+ if d.keyLen != 0 {
+ return nil, errors.New("crypto/blake2s: cannot marshal MACs")
+ }
+ b := make([]byte, 0, marshaledSize)
+ b = append(b, magic...)
+ for i := 0; i < 8; i++ {
+ b = appendUint32(b, d.h[i])
+ }
+ b = appendUint32(b, d.c[0])
+ b = appendUint32(b, d.c[1])
+ // Maximum value for size is 32
+ b = append(b, byte(d.size))
+ b = append(b, d.block[:]...)
+ b = append(b, byte(d.offset))
+ return b, nil
+}
+
+func (d *digest) UnmarshalBinary(b []byte) error {
+ if len(b) < len(magic) || string(b[:len(magic)]) != magic {
+ return errors.New("crypto/blake2s: invalid hash state identifier")
+ }
+ if len(b) != marshaledSize {
+ return errors.New("crypto/blake2s: invalid hash state size")
+ }
+ b = b[len(magic):]
+ for i := 0; i < 8; i++ {
+ b, d.h[i] = consumeUint32(b)
+ }
+ b, d.c[0] = consumeUint32(b)
+ b, d.c[1] = consumeUint32(b)
+ d.size = int(b[0])
+ b = b[1:]
+ copy(d.block[:], b[:BlockSize])
+ b = b[BlockSize:]
+ d.offset = int(b[0])
+ return nil
+}
+
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Size() int { return d.size }
@@ -185,3 +231,14 @@
binary.LittleEndian.PutUint32(hash[4*i:], v)
}
}
+
+func appendUint32(b []byte, x uint32) []byte {
+ var a [4]byte
+ binary.BigEndian.PutUint32(a[:], x)
+ return append(b, a[:]...)
+}
+
+func consumeUint32(b []byte) ([]byte, uint32) {
+ x := binary.BigEndian.Uint32(b)
+ return b[4:], x
+}
diff --git a/blake2s/blake2s_test.go b/blake2s/blake2s_test.go
index cfeb18b..cde79fb 100644
--- a/blake2s/blake2s_test.go
+++ b/blake2s/blake2s_test.go
@@ -5,6 +5,8 @@
package blake2s
import (
+ "bytes"
+ "encoding"
"encoding/hex"
"fmt"
"testing"
@@ -64,6 +66,52 @@
testHashes2X(t)
}
+func TestMarshal(t *testing.T) {
+ input := make([]byte, 255)
+ for i := range input {
+ input[i] = byte(i)
+ }
+ for i := 0; i < 256; i++ {
+ h, err := New256(nil)
+ if err != nil {
+ t.Fatalf("len(input)=%d: error from New256(nil): %v", i, err)
+ }
+ h2, err := New256(nil)
+ if err != nil {
+ t.Fatalf("len(input)=%d: error from New256(nil): %v", i, err)
+ }
+
+ h.Write(input[:i/2])
+ halfstate, err := h.(encoding.BinaryMarshaler).MarshalBinary()
+ if err != nil {
+ t.Fatalf("len(input)=%d: could not marshal: %v", i, err)
+ }
+ err = h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(halfstate)
+ if err != nil {
+ t.Fatalf("len(input)=%d: could not unmarshal: %v", i, err)
+ }
+
+ h.Write(input[i/2 : i])
+ sum := h.Sum(nil)
+ h2.Write(input[i/2 : i])
+ sum2 := h2.Sum(nil)
+
+ if !bytes.Equal(sum, sum2) {
+ t.Fatalf("len(input)=%d: results do not match; sum = %v, sum2 = %v", i, sum, sum2)
+ }
+
+ h3, err := New256(nil)
+ if err != nil {
+ t.Fatalf("len(input)=%d: error from New256(nil): %v", i, err)
+ }
+ h3.Write(input[:i])
+ sum3 := h3.Sum(nil)
+ if !bytes.Equal(sum, sum3) {
+ t.Fatalf("len(input)=%d: sum = %v, want %v", i, sum, sum3)
+ }
+ }
+}
+
func testHashes(t *testing.T) {
key, _ := hex.DecodeString("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f")