hkdf: add Extract and Expand

RFC 5869, Section 3.3 suggests it might be sometimes appropriate to use
Expand without Extract, and it is reasonable to reuse (secret, salt)
with different info values, in which case the Extract can be performed
once as an optimization.

TLS 1.3 also needs direct access to both Extract and Expand.

pseudorandomKey is ugly to look at, but that's intentional, as it
signals that this should have non-obvious properties to the user. The
docs will make it clear it's not the thing you should use in most cases.

Fixes golang/go#28237

Change-Id: Ib43ae8cdde0663aa4752172c39aadfb0e1c35f10
Reviewed-on: https://go-review.googlesource.com/c/144398
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/hkdf/hkdf.go b/hkdf/hkdf.go
index 5bc2463..dda3f14 100644
--- a/hkdf/hkdf.go
+++ b/hkdf/hkdf.go
@@ -8,8 +8,6 @@
 // HKDF is a cryptographic key derivation function (KDF) with the goal of
 // expanding limited input keying material into one or more cryptographically
 // strong secret keys.
-//
-// RFC 5869: https://tools.ietf.org/html/rfc5869
 package hkdf // import "golang.org/x/crypto/hkdf"
 
 import (
@@ -19,6 +17,21 @@
 	"io"
 )
 
+// Extract generates a pseudorandom key for use with Expand from an input secret
+// and an optional independent salt.
+//
+// Only use this function if you need to reuse the extracted key with multiple
+// Expand invocations and different context values. Most common scenarios,
+// including the generation of multiple keys, should use New instead.
+func Extract(hash func() hash.Hash, secret, salt []byte) []byte {
+	if salt == nil {
+		salt = make([]byte, hash().Size())
+	}
+	extractor := hmac.New(hash, salt)
+	extractor.Write(secret)
+	return extractor.Sum(nil)
+}
+
 type hkdf struct {
 	expander hash.Hash
 	size     int
@@ -26,22 +39,22 @@
 	info    []byte
 	counter byte
 
-	prev  []byte
-	cache []byte
+	prev []byte
+	buf  []byte
 }
 
 func (f *hkdf) Read(p []byte) (int, error) {
 	// Check whether enough data can be generated
 	need := len(p)
-	remains := len(f.cache) + int(255-f.counter+1)*f.size
+	remains := len(f.buf) + int(255-f.counter+1)*f.size
 	if remains < need {
 		return 0, errors.New("hkdf: entropy limit reached")
 	}
-	// Read from the cache, if enough data is present
-	n := copy(p, f.cache)
+	// Read any leftover from the buffer
+	n := copy(p, f.buf)
 	p = p[n:]
 
-	// Fill the buffer
+	// Fill the rest of the buffer
 	for len(p) > 0 {
 		f.expander.Reset()
 		f.expander.Write(f.prev)
@@ -51,25 +64,30 @@
 		f.counter++
 
 		// Copy the new batch into p
-		f.cache = f.prev
-		n = copy(p, f.cache)
+		f.buf = f.prev
+		n = copy(p, f.buf)
 		p = p[n:]
 	}
 	// Save leftovers for next run
-	f.cache = f.cache[n:]
+	f.buf = f.buf[n:]
 
 	return need, nil
 }
 
-// New returns a new HKDF using the given hash, the secret keying material to expand
-// and optional salt and info fields.
-func New(hash func() hash.Hash, secret, salt, info []byte) io.Reader {
-	if salt == nil {
-		salt = make([]byte, hash().Size())
-	}
-	extractor := hmac.New(hash, salt)
-	extractor.Write(secret)
-	prk := extractor.Sum(nil)
+// Expand returns a Reader, from which keys can be read, using the given
+// pseudorandom key and optional context info, skipping the extraction step.
+//
+// The pseudorandomKey should have been generated by Extract, or be a uniformly
+// random or pseudorandom cryptographically strong key. See RFC 5869, Section
+// 3.3. Most common scenarios will want to use New instead.
+func Expand(hash func() hash.Hash, pseudorandomKey, info []byte) io.Reader {
+	expander := hmac.New(hash, pseudorandomKey)
+	return &hkdf{expander, expander.Size(), info, 1, nil, nil}
+}
 
-	return &hkdf{hmac.New(hash, prk), extractor.Size(), info, 1, nil, nil}
+// New returns a Reader, from which keys can be read, using the given hash,
+// secret, salt and context info. Salt and info can be nil.
+func New(hash func() hash.Hash, secret, salt, info []byte) io.Reader {
+	prk := Extract(hash, secret, salt)
+	return Expand(hash, prk, info)
 }
diff --git a/hkdf/hkdf_test.go b/hkdf/hkdf_test.go
index cee659b..ea57577 100644
--- a/hkdf/hkdf_test.go
+++ b/hkdf/hkdf_test.go
@@ -18,6 +18,7 @@
 	hash   func() hash.Hash
 	master []byte
 	salt   []byte
+	prk    []byte
 	info   []byte
 	out    []byte
 }
@@ -36,6 +37,12 @@
 			0x08, 0x09, 0x0a, 0x0b, 0x0c,
 		},
 		[]byte{
+			0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf,
+			0x0d, 0xdc, 0x3f, 0x0d, 0xc4, 0x7b, 0xba, 0x63,
+			0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31,
+			0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5,
+		},
+		[]byte{
 			0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
 			0xf8, 0xf9,
 		},
@@ -75,6 +82,12 @@
 			0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
 		},
 		[]byte{
+			0x06, 0xa6, 0xb8, 0x8c, 0x58, 0x53, 0x36, 0x1a,
+			0x06, 0x10, 0x4c, 0x9c, 0xeb, 0x35, 0xb4, 0x5c,
+			0xef, 0x76, 0x00, 0x14, 0x90, 0x46, 0x71, 0x01,
+			0x4a, 0x19, 0x3f, 0x40, 0xc1, 0x5f, 0xc2, 0x44,
+		},
+		[]byte{
 			0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
 			0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
 			0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
@@ -108,6 +121,12 @@
 			0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
 		},
 		[]byte{},
+		[]byte{
+			0x19, 0xef, 0x24, 0xa3, 0x2c, 0x71, 0x7b, 0x16,
+			0x7f, 0x33, 0xa9, 0x1d, 0x6f, 0x64, 0x8b, 0xdf,
+			0x96, 0x59, 0x67, 0x76, 0xaf, 0xdb, 0x63, 0x77,
+			0xac, 0x43, 0x4c, 0x1c, 0x29, 0x3c, 0xcb, 0x04,
+		},
 		[]byte{},
 		[]byte{
 			0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f,
@@ -119,6 +138,30 @@
 		},
 	},
 	{
+		sha256.New,
+		[]byte{
+			0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
+			0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
+			0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
+		},
+		nil,
+		[]byte{
+			0x19, 0xef, 0x24, 0xa3, 0x2c, 0x71, 0x7b, 0x16,
+			0x7f, 0x33, 0xa9, 0x1d, 0x6f, 0x64, 0x8b, 0xdf,
+			0x96, 0x59, 0x67, 0x76, 0xaf, 0xdb, 0x63, 0x77,
+			0xac, 0x43, 0x4c, 0x1c, 0x29, 0x3c, 0xcb, 0x04,
+		},
+		nil,
+		[]byte{
+			0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f,
+			0x71, 0x5f, 0x80, 0x2a, 0x06, 0x3c, 0x5a, 0x31,
+			0xb8, 0xa1, 0x1f, 0x5c, 0x5e, 0xe1, 0x87, 0x9e,
+			0xc3, 0x45, 0x4e, 0x5f, 0x3c, 0x73, 0x8d, 0x2d,
+			0x9d, 0x20, 0x13, 0x95, 0xfa, 0xa4, 0xb6, 0x1a,
+			0x96, 0xc8,
+		},
+	},
+	{
 		sha1.New,
 		[]byte{
 			0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
@@ -129,6 +172,11 @@
 			0x08, 0x09, 0x0a, 0x0b, 0x0c,
 		},
 		[]byte{
+			0x9b, 0x6c, 0x18, 0xc4, 0x32, 0xa7, 0xbf, 0x8f,
+			0x0e, 0x71, 0xc8, 0xeb, 0x88, 0xf4, 0xb3, 0x0b,
+			0xaa, 0x2b, 0xa2, 0x43,
+		},
+		[]byte{
 			0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
 			0xf8, 0xf9,
 		},
@@ -168,6 +216,11 @@
 			0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
 		},
 		[]byte{
+			0x8a, 0xda, 0xe0, 0x9a, 0x2a, 0x30, 0x70, 0x59,
+			0x47, 0x8d, 0x30, 0x9b, 0x26, 0xc4, 0x11, 0x5a,
+			0x22, 0x4c, 0xfa, 0xf6,
+		},
+		[]byte{
 			0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
 			0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
 			0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
@@ -201,6 +254,11 @@
 			0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
 		},
 		[]byte{},
+		[]byte{
+			0xda, 0x8c, 0x8a, 0x73, 0xc7, 0xfa, 0x77, 0x28,
+			0x8e, 0xc6, 0xf5, 0xe7, 0xc2, 0x97, 0x78, 0x6a,
+			0xa0, 0xd3, 0x2d, 0x01,
+		},
 		[]byte{},
 		[]byte{
 			0x0a, 0xc1, 0xaf, 0x70, 0x02, 0xb3, 0xd7, 0x61,
@@ -219,7 +277,12 @@
 			0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c,
 		},
 		nil,
-		[]byte{},
+		[]byte{
+			0x2a, 0xdc, 0xca, 0xda, 0x18, 0x77, 0x9e, 0x7c,
+			0x20, 0x77, 0xad, 0x2e, 0xb1, 0x9d, 0x3f, 0x3e,
+			0x73, 0x13, 0x85, 0xdd,
+		},
+		nil,
 		[]byte{
 			0x2c, 0x91, 0x11, 0x72, 0x04, 0xd7, 0x45, 0xf3,
 			0x50, 0x0d, 0x63, 0x6a, 0x62, 0xf6, 0x4f, 0x0a,
@@ -233,6 +296,11 @@
 
 func TestHKDF(t *testing.T) {
 	for i, tt := range hkdfTests {
+		prk := Extract(tt.hash, tt.master, tt.salt)
+		if !bytes.Equal(prk, tt.prk) {
+			t.Errorf("test %d: incorrect PRK: have %v, need %v.", i, prk, tt.prk)
+		}
+
 		hkdf := New(tt.hash, tt.master, tt.salt, tt.info)
 		out := make([]byte, len(tt.out))
 
@@ -244,6 +312,17 @@
 		if !bytes.Equal(out, tt.out) {
 			t.Errorf("test %d: incorrect output: have %v, need %v.", i, out, tt.out)
 		}
+
+		hkdf = Expand(tt.hash, prk, tt.info)
+
+		n, err = io.ReadFull(hkdf, out)
+		if n != len(tt.out) || err != nil {
+			t.Errorf("test %d: not enough output bytes from Expand: %d.", i, n)
+		}
+
+		if !bytes.Equal(out, tt.out) {
+			t.Errorf("test %d: incorrect output from Expand: have %v, need %v.", i, out, tt.out)
+		}
 	}
 }