diff --git a/cng/hkdf.go b/cng/hkdf.go index 5338fb5..20bcc79 100644 --- a/cng/hkdf.go +++ b/cng/hkdf.go @@ -10,7 +10,6 @@ import ( "encoding/binary" "errors" "hash" - "io" "runtime" "unsafe" @@ -28,99 +27,23 @@ func loadHKDF() (bcrypt.ALG_HANDLE, error) { }) } -type hkdf struct { - hkey bcrypt.KEY_HANDLE - info []byte - - hashLen int - n int // count of bytes requested from Read - // buf contains the derived data. - // len(buf) can be larger than n, as Read may derive - // more data than requested and cache it in buf. - buf []byte -} - -func (c *hkdf) finalize() { - bcrypt.DestroyKey(c.hkey) -} - -func hkdfDerive(hkey bcrypt.KEY_HANDLE, info, out []byte) (int, error) { - var params *bcrypt.BufferDesc - if len(info) > 0 { - params = &bcrypt.BufferDesc{ - Count: 1, - Buffers: &bcrypt.Buffer{ - Length: uint32(len(info)), - Type: bcrypt.KDF_HKDF_INFO, - Data: uintptr(unsafe.Pointer(&info[0])), - }, - } - defer runtime.KeepAlive(params) - } - var n uint32 - err := bcrypt.KeyDerivation(hkey, params, out, &n, 0) - return int(n), err -} - -func (c *hkdf) Read(p []byte) (int, error) { - // KeyDerivation doesn't support incremental output, each call - // derives the key from scratch and returns the requested bytes. - // To implement io.Reader, we need to ask for len(c.buf) + len(p) - // bytes and copy the last derived len(p) bytes to p. - maxDerived := 255 * c.hashLen - totalDerived := c.n + len(p) - // Check whether enough data can be derived. - if totalDerived > maxDerived { - return 0, errors.New("hkdf: entropy limit reached") - } - // Check whether c.buf already contains enough derived data, - // otherwise derive more data. - if bytesNeeded := totalDerived - len(c.buf); bytesNeeded > 0 { - // It is common to derive multiple equally sized keys from the same HKDF instance. - // Optimize this case by allocating a buffer large enough to hold - // at least 3 of such keys each time there is not enough data. - // Round up to the next multiple of hashLen. - blocks := (bytesNeeded-1)/c.hashLen + 1 - const minBlocks = 3 - if blocks < minBlocks { - blocks = minBlocks - } - alloc := blocks * c.hashLen - if len(c.buf)+alloc > maxDerived { - // The buffer can't grow beyond maxDerived. - alloc = maxDerived - len(c.buf) - } - c.buf = append(c.buf, make([]byte, alloc)...) - n, err := hkdfDerive(c.hkey, c.info, c.buf) - if err != nil { - c.buf = c.buf[:c.n] - return 0, err - } - // Adjust totalDerived to the actual number of bytes derived. - totalDerived = n - } - n := copy(p, c.buf[c.n:totalDerived]) - c.n += n - return n, nil -} - -func newHKDF(h func() hash.Hash, secret, salt []byte, info []byte) (*hkdf, error) { +func newHKDF(h func() hash.Hash, secret, salt []byte, info []byte) (bcrypt.KEY_HANDLE, error) { ch := h() hashID := hashToID(ch) if hashID == "" { - return nil, errors.New("cng: unsupported hash function") + return 0, errors.New("cng: unsupported hash function") } alg, err := loadHKDF() if err != nil { - return nil, err + return 0, err } var kh bcrypt.KEY_HANDLE if err := bcrypt.GenerateSymmetricKey(alg, &kh, nil, secret, 0); err != nil { - return nil, err + return 0, err } if err := setString(bcrypt.HANDLE(kh), bcrypt.HKDF_HASH_ALGORITHM, hashID); err != nil { bcrypt.DestroyKey(kh) - return nil, err + return 0, err } if salt != nil { // Used for Extract. @@ -131,11 +54,9 @@ func newHKDF(h func() hash.Hash, secret, salt []byte, info []byte) (*hkdf, error } if err != nil { bcrypt.DestroyKey(kh) - return nil, err + return 0, err } - k := &hkdf{kh, info, ch.Size(), 0, nil} - runtime.SetFinalizer(k, (*hkdf).finalize) - return k, nil + return kh, nil } func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) { @@ -147,11 +68,11 @@ func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) { if err != nil { return nil, err } - hdr, blob, err := exportKeyData(kh.hkey) + defer bcrypt.DestroyKey(kh) + hdr, blob, err := exportKeyData(kh) if err != nil { return nil, err } - runtime.KeepAlive(kh) if hdr.Version != bcrypt.KEY_DATA_BLOB_VERSION1 { return nil, errors.New("cng: unknown key data blob version") } @@ -171,10 +92,33 @@ func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) { return blob[cbHashName:], nil } -func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte) (io.Reader, error) { +// ExpandHKDF derives a key from the given hash, key, and optional context info. +func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte, keyLength int) ([]byte, error) { kh, err := newHKDF(h, pseudorandomKey, nil, info) if err != nil { return nil, err } - return kh, nil + defer bcrypt.DestroyKey(kh) + out := make([]byte, keyLength) + var params *bcrypt.BufferDesc + if len(info) > 0 { + params = &bcrypt.BufferDesc{ + Count: 1, + Buffers: &bcrypt.Buffer{ + Length: uint32(len(info)), + Type: bcrypt.KDF_HKDF_INFO, + Data: uintptr(unsafe.Pointer(&info[0])), + }, + } + defer runtime.KeepAlive(params) + } + var n uint32 + err = bcrypt.KeyDerivation(kh, params, out, &n, 0) + if err != nil { + return nil, err + } + if int(n) != keyLength { + return nil, errors.New("cng: key derivation returned unexpected length") + } + return out, err } diff --git a/cng/hkdf_test.go b/cng/hkdf_test.go index 19bc233..5aa0e5b 100644 --- a/cng/hkdf_test.go +++ b/cng/hkdf_test.go @@ -9,7 +9,6 @@ package cng_test import ( "bytes" "hash" - "io" "testing" "github.com/microsoft/go-crypto-winnative/cng" @@ -295,16 +294,12 @@ var hkdfTests = []hkdfTest{ }, } -func newHKDF(hash func() hash.Hash, secret, salt, info []byte) io.Reader { +func newHKDF(hash func() hash.Hash, secret, salt, info []byte, keyLength int) ([]byte, error) { prk, err := cng.ExtractHKDF(hash, secret, salt) if err != nil { - panic(err) + return nil, err } - r, err := cng.ExpandHKDF(hash, prk, info) - if err != nil { - panic(err) - } - return r + return cng.ExpandHKDF(hash, prk, info, keyLength) } func TestHKDF(t *testing.T) { @@ -320,56 +315,25 @@ func TestHKDF(t *testing.T) { t.Errorf("test %d: incorrect PRK: have %v, need %v.", i, prk, tt.prk) } - hkdf := newHKDF(tt.hash, tt.master, tt.salt, tt.info) - out := make([]byte, len(tt.out)) - - n, err := io.ReadFull(hkdf, out) - if n != len(tt.out) || err != nil { - t.Errorf("test %d: not enough output bytes: %d.", i, n) + out, err := newHKDF(tt.hash, tt.master, tt.salt, tt.info, len(tt.out)) + if err != nil { + t.Errorf("test %d: error generating HKDF: %v.", i, err) } - if !bytes.Equal(out, tt.out) { t.Errorf("test %d: incorrect output: have %v, need %v.", i, out, tt.out) } - hkdf, err = cng.ExpandHKDF(tt.hash, prk, tt.info) + out, err = cng.ExpandHKDF(tt.hash, prk, tt.info, len(tt.out)) if err != nil { t.Errorf("test %d: error expanding HKDF: %v.", i, err) } - - n, err = io.ReadFull(hkdf, out) - if n != len(tt.out) || err != nil { - t.Errorf("test %d: not enough output bytes from Expand: %d.", i, n) - } - if !bytes.Equal(out, tt.out) { t.Errorf("test %d: incorrect output from Expand: have %v, need %v.", i, out, tt.out) } } } -func TestHKDFMultiRead(t *testing.T) { - if !cng.SupportsHKDF() { - t.Skip("HKDF is not supported") - } - for i, tt := range hkdfTests { - hkdf := newHKDF(tt.hash, tt.master, tt.salt, tt.info) - out := make([]byte, len(tt.out)) - - for b := 0; b < len(tt.out); b++ { - n, err := io.ReadFull(hkdf, out[b:b+1]) - if n != 1 || err != nil { - t.Errorf("test %d.%d: not enough output bytes: have %d, need %d .", i, b, n, len(tt.out)) - } - } - - if !bytes.Equal(out, tt.out) { - t.Errorf("test %d: incorrect output: have %v, need %v.", i, out, tt.out) - } - } -} - -func TestHKDFLimit(t *testing.T) { +func TestExpandHKDFOneShotLimit(t *testing.T) { if !cng.SupportsHKDF() { t.Skip("HKDF is not supported") } @@ -377,65 +341,39 @@ func TestHKDFLimit(t *testing.T) { master := []byte{0x00, 0x01, 0x02, 0x03} info := []byte{} - hkdf := newHKDF(hash, master, nil, info) + prk, err := cng.ExtractHKDF(hash, master, nil) + if err != nil { + t.Fatalf("error extracting HKDF: %v.", err) + } limit := hash().Size() * 255 - out := make([]byte, limit) - - // The maximum output bytes should be extractable - n, err := io.ReadFull(hkdf, out) - if n != limit || err != nil { - t.Errorf("not enough output bytes: %d, %v.", n, err) + out, err := cng.ExpandHKDF(hash, prk, info, limit) + if err != nil { + t.Errorf("error expanding HKDF one-shot: %v.", err) } - - // Reading one more should fail - n, err = io.ReadFull(hkdf, make([]byte, 1)) - if n > 0 || err == nil { - t.Errorf("key expansion overflowed: n = %d, err = %v", n, err) + if len(out) != limit { + t.Errorf("incorrect output length: have %d, need %d.", len(out), limit) } -} -func BenchmarkHKDF32ByteSHA256Single(b *testing.B) { - benchmarkHKDFSingle(cng.NewSHA256, 32, b) -} - -func BenchmarkHKDF8ByteSHA256Stream(b *testing.B) { - benchmarkHKDFStream(cng.NewSHA256, 8, b) -} - -func BenchmarkHKDF32ByteSHA256Stream(b *testing.B) { - benchmarkHKDFStream(cng.NewSHA256, 32, b) -} - -func benchmarkHKDFSingle(hasher func() hash.Hash, block int, b *testing.B) { - master := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} - salt := []byte{0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17} - info := []byte{0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27} - out := make([]byte, block) - - b.SetBytes(int64(block)) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - hkdf := newHKDF(hasher, master, salt, info) - io.ReadFull(hkdf, out) + // Expanding one more byte should fail + _, err = cng.ExpandHKDF(hash, prk, info, limit+1) + if err == nil { + t.Errorf("expected error for key expansion overflow") } } -func benchmarkHKDFStream(hasher func() hash.Hash, block int, b *testing.B) { +func BenchmarkHKDF32ByteSHA256Single(b *testing.B) { master := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} salt := []byte{0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17} info := []byte{0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27} - out := make([]byte, block) + const block = 32 b.SetBytes(int64(block)) b.ResetTimer() - hkdf := newHKDF(hasher, master, salt, info) for i := 0; i < b.N; i++ { - _, err := io.ReadFull(hkdf, out) + _, err := newHKDF(cng.NewSHA256, master, salt, info, block) if err != nil { - hkdf = newHKDF(hasher, master, salt, info) - i-- + b.Error(err) } } }