diff --git a/cng/hkdf.go b/cng/hkdf.go index 725fa5c..e9b997f 100644 --- a/cng/hkdf.go +++ b/cng/hkdf.go @@ -38,44 +38,73 @@ type hkdf struct { hashLen int buf []byte + n int } func (c *hkdf) finalize() { bcrypt.DestroyKey(c.hkey) } -func (c *hkdf) Read(p []byte) (int, error) { +func hkdfDerive(hkey bcrypt.KEY_HANDLE, info, out []byte) (int, error) { var params *bcrypt.BufferDesc - if len(c.info) > 0 { + if len(info) > 0 { params = &bcrypt.BufferDesc{ Count: 1, Buffers: &bcrypt.Buffer{ - Length: uint32(len(c.info)), + Length: uint32(len(info)), Type: bcrypt.KDF_HKDF_INFO, - Data: uintptr(unsafe.Pointer(&c.info[0])), + Data: uintptr(unsafe.Pointer(&info[0])), }, } + defer runtime.KeepAlive(params) } + var n uint32 + err := bcrypt.KeyDerivation(hkey, params, out, &n, 0) + return int(n), err +} + +func (c *hkdf) Read(p []byte) (int, error) { // KeyDerivation doesn't support incremental output, each call // derives the key from scratch and returns the requested bytes. // To implement io.Reader, we need to ask for len(c.buf) + len(p) // bytes and copy the last derived len(p) bytes to p. - // We use c.buf to know how many bytes we've already derived and - // to avoid allocating the whole output buffer on each call. - prevLen := len(c.buf) - needLen := len(p) - remains := 255*c.hashLen - prevLen - // Check whether enough data can be generated. - if remains < needLen { + maxDerived := 255 * c.hashLen + totalDerived := c.n + len(p) + // Check whether enough data can be derived. + if totalDerived > maxDerived { return 0, errors.New("hkdf: entropy limit reached") } - c.buf = append(c.buf, make([]byte, needLen)...) - var size uint32 - if err := bcrypt.KeyDerivation(c.hkey, params, c.buf, &size, 0); err != nil { - return 0, err + // Check whether c.buf already contains enough derived data, + // otherwise derive more data. + if bytesNeeded := totalDerived - len(c.buf); bytesNeeded > 0 { + // It is common to derive multiple equally sized keys from the same HKDF instance. + // Optimize this case by allocating a buffer large enough to hold + // at least 5 of such keys each time there is not enough data. + blocks := bytesNeeded / c.hashLen + if bytesNeeded%c.hashLen != 0 { + // Round up to the next multiple of hashLen. + blocks += 1 + } + const minBlocks = 5 + if blocks < minBlocks { + blocks = minBlocks + } + alloc := blocks * c.hashLen + if len(c.buf)+alloc > maxDerived { + // The buffer can't grow beyond maxDerived. + alloc = maxDerived - len(c.buf) + } + c.buf = append(c.buf, make([]byte, alloc)...) + n, err := hkdfDerive(c.hkey, c.info, c.buf) + if err != nil { + c.buf = c.buf[:c.n] + return 0, err + } + // Adjust totalDerived to the actual number of bytes derived. + totalDerived = c.n + n } - runtime.KeepAlive(params) - n := copy(p, c.buf[prevLen:size]) + n := copy(p, c.buf[c.n:totalDerived]) + c.n += n return n, nil } @@ -108,7 +137,7 @@ func newHKDF(h func() hash.Hash, secret, salt []byte, info []byte) (*hkdf, error bcrypt.DestroyKey(kh) return nil, err } - k := &hkdf{kh, info, ch.Size(), nil} + k := &hkdf{kh, info, ch.Size(), nil, 0} runtime.SetFinalizer(k, (*hkdf).finalize) return k, nil } diff --git a/cng/hkdf_test.go b/cng/hkdf_test.go index c89e721..19bc233 100644 --- a/cng/hkdf_test.go +++ b/cng/hkdf_test.go @@ -394,15 +394,15 @@ func TestHKDFLimit(t *testing.T) { } } -func Benchmark32ByteSHA256Single(b *testing.B) { +func BenchmarkHKDF32ByteSHA256Single(b *testing.B) { benchmarkHKDFSingle(cng.NewSHA256, 32, b) } -func Benchmark8ByteSHA256Stream(b *testing.B) { +func BenchmarkHKDF8ByteSHA256Stream(b *testing.B) { benchmarkHKDFStream(cng.NewSHA256, 8, b) } -func Benchmark32ByteSHA256Stream(b *testing.B) { +func BenchmarkHKDF32ByteSHA256Stream(b *testing.B) { benchmarkHKDFStream(cng.NewSHA256, 32, b) }