crypto/hmac: make Sum idempotent
Fixes #978.
R=rsc
CC=golang-dev
https://golang.org/cl/1967045
diff --git a/src/pkg/crypto/hmac/hmac.go b/src/pkg/crypto/hmac/hmac.go
index 38d1373..3b5aa13 100644
--- a/src/pkg/crypto/hmac/hmac.go
+++ b/src/pkg/crypto/hmac/hmac.go
@@ -34,10 +34,9 @@
)
type hmac struct {
- size int
- key []byte
- tmp []byte
- inner hash.Hash
+ size int
+ key, tmp []byte
+ outer, inner hash.Hash
}
func (h *hmac) tmpPad(xor byte) {
@@ -50,14 +49,14 @@
}
func (h *hmac) Sum() []byte {
- h.tmpPad(0x5c)
sum := h.inner.Sum()
+ h.tmpPad(0x5c)
for i, b := range sum {
h.tmp[padSize+i] = b
}
- h.inner.Reset()
- h.inner.Write(h.tmp)
- return h.inner.Sum()
+ h.outer.Reset()
+ h.outer.Write(h.tmp)
+ return h.outer.Sum()
}
func (h *hmac) Write(p []byte) (n int, err os.Error) {
@@ -72,27 +71,26 @@
h.inner.Write(h.tmp[0:padSize])
}
-// New returns a new HMAC hash using the given hash and key.
-func New(h hash.Hash, key []byte) hash.Hash {
+// New returns a new HMAC hash using the given hash generator and key.
+func New(h func() hash.Hash, key []byte) hash.Hash {
+ hm := new(hmac)
+ hm.outer = h()
+ hm.inner = h()
+ hm.size = hm.inner.Size()
+ hm.tmp = make([]byte, padSize+hm.size)
if len(key) > padSize {
// If key is too big, hash it.
- h.Write(key)
- key = h.Sum()
+ hm.outer.Write(key)
+ key = hm.outer.Sum()
}
- hm := new(hmac)
- hm.inner = h
- hm.size = h.Size()
hm.key = make([]byte, len(key))
- for i, k := range key {
- hm.key[i] = k
- }
- hm.tmp = make([]byte, padSize+hm.size)
+ copy(hm.key, key)
hm.Reset()
return hm
}
// NewMD5 returns a new HMAC-MD5 hash using the given key.
-func NewMD5(key []byte) hash.Hash { return New(md5.New(), key) }
+func NewMD5(key []byte) hash.Hash { return New(md5.New, key) }
// NewSHA1 returns a new HMAC-SHA1 hash using the given key.
-func NewSHA1(key []byte) hash.Hash { return New(sha1.New(), key) }
+func NewSHA1(key []byte) hash.Hash { return New(sha1.New, key) }
diff --git a/src/pkg/crypto/hmac/hmac_test.go b/src/pkg/crypto/hmac/hmac_test.go
index d867c83..6934df2 100644
--- a/src/pkg/crypto/hmac/hmac_test.go
+++ b/src/pkg/crypto/hmac/hmac_test.go
@@ -84,9 +84,13 @@
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
- sum := fmt.Sprintf("%x", h.Sum())
- if sum != tt.out {
- t.Errorf("test %d.%d: have %s want %s\n", i, j, sum, tt.out)
+
+ // Repetive Sum() calls should return the same value
+ for k := 0; k < 2; k++ {
+ sum := fmt.Sprintf("%x", h.Sum())
+ if sum != tt.out {
+ t.Errorf("test %d.%d.%d: have %s want %s\n", i, j, k, sum, tt.out)
+ }
}
// Second iteration: make sure reset works.
diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go
index 4c4626c..c629920 100644
--- a/src/pkg/crypto/tls/handshake_client.go
+++ b/src/pkg/crypto/tls/handshake_client.go
@@ -8,7 +8,6 @@
"crypto/hmac"
"crypto/rc4"
"crypto/rsa"
- "crypto/sha1"
"crypto/subtle"
"crypto/x509"
"io"
@@ -226,7 +225,7 @@
cipher, _ := rc4.NewCipher(clientKey)
- c.out.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
+ c.out.prepareCipherSpec(cipher, hmac.NewSHA1(clientMAC))
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
finished := new(finishedMsg)
@@ -235,7 +234,7 @@
c.writeRecord(recordTypeHandshake, finished.marshal())
cipher2, _ := rc4.NewCipher(serverKey)
- c.in.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
+ c.in.prepareCipherSpec(cipher2, hmac.NewSHA1(serverMAC))
c.readRecord(recordTypeChangeCipherSpec)
if c.err != nil {
return c.err
diff --git a/src/pkg/crypto/tls/handshake_server.go b/src/pkg/crypto/tls/handshake_server.go
index 734c0fe..118dd43 100644
--- a/src/pkg/crypto/tls/handshake_server.go
+++ b/src/pkg/crypto/tls/handshake_server.go
@@ -16,7 +16,6 @@
"crypto/hmac"
"crypto/rc4"
"crypto/rsa"
- "crypto/sha1"
"crypto/subtle"
"crypto/x509"
"io"
@@ -227,7 +226,7 @@
keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength)
cipher, _ := rc4.NewCipher(clientKey)
- c.in.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
+ c.in.prepareCipherSpec(cipher, hmac.NewSHA1(clientMAC))
c.readRecord(recordTypeChangeCipherSpec)
if err := c.error(); err != nil {
return err
@@ -264,7 +263,7 @@
finishedHash.Write(clientFinished.marshal())
cipher2, _ := rc4.NewCipher(serverKey)
- c.out.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
+ c.out.prepareCipherSpec(cipher2, hmac.NewSHA1(serverMAC))
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
finished := new(finishedMsg)
diff --git a/src/pkg/crypto/tls/prf.go b/src/pkg/crypto/tls/prf.go
index ee6cb78..b206d26 100644
--- a/src/pkg/crypto/tls/prf.go
+++ b/src/pkg/crypto/tls/prf.go
@@ -20,7 +20,7 @@
}
// pHash implements the P_hash function, as defined in RFC 4346, section 5.
-func pHash(result, secret, seed []byte, hash hash.Hash) {
+func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h := hmac.New(hash, secret)
h.Write(seed)
a := h.Sum()
@@ -46,8 +46,8 @@
// pRF11 implements the TLS 1.1 pseudo-random function, as defined in RFC 4346, section 5.
func pRF11(result, secret, label, seed []byte) {
- hashSHA1 := sha1.New()
- hashMD5 := md5.New()
+ hashSHA1 := sha1.New
+ hashMD5 := md5.New
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)