Skip to content

Commit

Permalink
Merge pull request #48 from microsoft/opthkdf
Browse files Browse the repository at this point in the history
Optimize multiple hkdf reads case
  • Loading branch information
qmuntal authored Oct 11, 2023
2 parents 968dd30 + 576c2f6 commit 739160e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 22 deletions.
67 changes: 48 additions & 19 deletions cng/hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,45 +37,74 @@ type hkdf struct {
info []byte

hashLen int
buf []byte
n int // count of bytes requested from Read
// buf contains the derived data.
// len(buf) can be larger than n, as Read may derive
// more data than requested and cache it in buf.
buf []byte
}

func (c *hkdf) finalize() {
bcrypt.DestroyKey(c.hkey)
}

func (c *hkdf) Read(p []byte) (int, error) {
func hkdfDerive(hkey bcrypt.KEY_HANDLE, info, out []byte) (int, error) {
var params *bcrypt.BufferDesc
if len(c.info) > 0 {
if len(info) > 0 {
params = &bcrypt.BufferDesc{
Count: 1,
Buffers: &bcrypt.Buffer{
Length: uint32(len(c.info)),
Length: uint32(len(info)),
Type: bcrypt.KDF_HKDF_INFO,
Data: uintptr(unsafe.Pointer(&c.info[0])),
Data: uintptr(unsafe.Pointer(&info[0])),
},
}
defer runtime.KeepAlive(params)
}
var n uint32
err := bcrypt.KeyDerivation(hkey, params, out, &n, 0)
return int(n), err
}

func (c *hkdf) Read(p []byte) (int, error) {
// KeyDerivation doesn't support incremental output, each call
// derives the key from scratch and returns the requested bytes.
// To implement io.Reader, we need to ask for len(c.buf) + len(p)
// bytes and copy the last derived len(p) bytes to p.
// We use c.buf to know how many bytes we've already derived and
// to avoid allocating the whole output buffer on each call.
prevLen := len(c.buf)
needLen := len(p)
remains := 255*c.hashLen - prevLen
// Check whether enough data can be generated.
if remains < needLen {
maxDerived := 255 * c.hashLen
totalDerived := c.n + len(p)
// Check whether enough data can be derived.
if totalDerived > maxDerived {
return 0, errors.New("hkdf: entropy limit reached")
}
c.buf = append(c.buf, make([]byte, needLen)...)
var size uint32
if err := bcrypt.KeyDerivation(c.hkey, params, c.buf, &size, 0); err != nil {
return 0, err
// Check whether c.buf already contains enough derived data,
// otherwise derive more data.
if bytesNeeded := totalDerived - len(c.buf); bytesNeeded > 0 {
// It is common to derive multiple equally sized keys from the same HKDF instance.
// Optimize this case by allocating a buffer large enough to hold
// at least 3 of such keys each time there is not enough data.
// Round up to the next multiple of hashLen.
blocks := (bytesNeeded-1)/c.hashLen + 1
const minBlocks = 3
if blocks < minBlocks {
blocks = minBlocks
}
alloc := blocks * c.hashLen
if len(c.buf)+alloc > maxDerived {
// The buffer can't grow beyond maxDerived.
alloc = maxDerived - len(c.buf)
}
c.buf = append(c.buf, make([]byte, alloc)...)
n, err := hkdfDerive(c.hkey, c.info, c.buf)
if err != nil {
c.buf = c.buf[:c.n]
return 0, err
}
// Adjust totalDerived to the actual number of bytes derived.
totalDerived = n
}
runtime.KeepAlive(params)
n := copy(p, c.buf[prevLen:size])
n := copy(p, c.buf[c.n:totalDerived])
c.n += n
return n, nil
}

Expand Down Expand Up @@ -108,7 +137,7 @@ func newHKDF(h func() hash.Hash, secret, salt []byte, info []byte) (*hkdf, error
bcrypt.DestroyKey(kh)
return nil, err
}
k := &hkdf{kh, info, ch.Size(), nil}
k := &hkdf{kh, info, ch.Size(), 0, nil}
runtime.SetFinalizer(k, (*hkdf).finalize)
return k, nil
}
Expand Down
6 changes: 3 additions & 3 deletions cng/hkdf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,15 +394,15 @@ func TestHKDFLimit(t *testing.T) {
}
}

func Benchmark32ByteSHA256Single(b *testing.B) {
func BenchmarkHKDF32ByteSHA256Single(b *testing.B) {
benchmarkHKDFSingle(cng.NewSHA256, 32, b)
}

func Benchmark8ByteSHA256Stream(b *testing.B) {
func BenchmarkHKDF8ByteSHA256Stream(b *testing.B) {
benchmarkHKDFStream(cng.NewSHA256, 8, b)
}

func Benchmark32ByteSHA256Stream(b *testing.B) {
func BenchmarkHKDF32ByteSHA256Stream(b *testing.B) {
benchmarkHKDFStream(cng.NewSHA256, 32, b)
}

Expand Down

0 comments on commit 739160e

Please sign in to comment.