Skip to content

Commit

Permalink
redefine ExpandHKDF as a one shot function
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Nov 20, 2024
1 parent 19f07bc commit 03e984e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 177 deletions.
124 changes: 34 additions & 90 deletions cng/hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"encoding/binary"
"errors"
"hash"
"io"
"runtime"
"unsafe"

Expand All @@ -28,99 +27,23 @@ func loadHKDF() (bcrypt.ALG_HANDLE, error) {
})
}

type hkdf struct {
hkey bcrypt.KEY_HANDLE
info []byte

hashLen int
n int // count of bytes requested from Read
// buf contains the derived data.
// len(buf) can be larger than n, as Read may derive
// more data than requested and cache it in buf.
buf []byte
}

func (c *hkdf) finalize() {
bcrypt.DestroyKey(c.hkey)
}

func hkdfDerive(hkey bcrypt.KEY_HANDLE, info, out []byte) (int, error) {
var params *bcrypt.BufferDesc
if len(info) > 0 {
params = &bcrypt.BufferDesc{
Count: 1,
Buffers: &bcrypt.Buffer{
Length: uint32(len(info)),
Type: bcrypt.KDF_HKDF_INFO,
Data: uintptr(unsafe.Pointer(&info[0])),
},
}
defer runtime.KeepAlive(params)
}
var n uint32
err := bcrypt.KeyDerivation(hkey, params, out, &n, 0)
return int(n), err
}

func (c *hkdf) Read(p []byte) (int, error) {
// KeyDerivation doesn't support incremental output, each call
// derives the key from scratch and returns the requested bytes.
// To implement io.Reader, we need to ask for len(c.buf) + len(p)
// bytes and copy the last derived len(p) bytes to p.
maxDerived := 255 * c.hashLen
totalDerived := c.n + len(p)
// Check whether enough data can be derived.
if totalDerived > maxDerived {
return 0, errors.New("hkdf: entropy limit reached")
}
// Check whether c.buf already contains enough derived data,
// otherwise derive more data.
if bytesNeeded := totalDerived - len(c.buf); bytesNeeded > 0 {
// It is common to derive multiple equally sized keys from the same HKDF instance.
// Optimize this case by allocating a buffer large enough to hold
// at least 3 of such keys each time there is not enough data.
// Round up to the next multiple of hashLen.
blocks := (bytesNeeded-1)/c.hashLen + 1
const minBlocks = 3
if blocks < minBlocks {
blocks = minBlocks
}
alloc := blocks * c.hashLen
if len(c.buf)+alloc > maxDerived {
// The buffer can't grow beyond maxDerived.
alloc = maxDerived - len(c.buf)
}
c.buf = append(c.buf, make([]byte, alloc)...)
n, err := hkdfDerive(c.hkey, c.info, c.buf)
if err != nil {
c.buf = c.buf[:c.n]
return 0, err
}
// Adjust totalDerived to the actual number of bytes derived.
totalDerived = n
}
n := copy(p, c.buf[c.n:totalDerived])
c.n += n
return n, nil
}

func newHKDF(h func() hash.Hash, secret, salt []byte, info []byte) (*hkdf, error) {
func newHKDF(h func() hash.Hash, secret, salt []byte, info []byte) (bcrypt.KEY_HANDLE, error) {
ch := h()
hashID := hashToID(ch)
if hashID == "" {
return nil, errors.New("cng: unsupported hash function")
return 0, errors.New("cng: unsupported hash function")
}
alg, err := loadHKDF()
if err != nil {
return nil, err
return 0, err
}
var kh bcrypt.KEY_HANDLE
if err := bcrypt.GenerateSymmetricKey(alg, &kh, nil, secret, 0); err != nil {
return nil, err
return 0, err
}
if err := setString(bcrypt.HANDLE(kh), bcrypt.HKDF_HASH_ALGORITHM, hashID); err != nil {
bcrypt.DestroyKey(kh)
return nil, err
return 0, err
}
if salt != nil {
// Used for Extract.
Expand All @@ -131,11 +54,9 @@ func newHKDF(h func() hash.Hash, secret, salt []byte, info []byte) (*hkdf, error
}
if err != nil {
bcrypt.DestroyKey(kh)
return nil, err
return 0, err
}
k := &hkdf{kh, info, ch.Size(), 0, nil}
runtime.SetFinalizer(k, (*hkdf).finalize)
return k, nil
return kh, nil
}

func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) {
Expand All @@ -147,11 +68,11 @@ func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) {
if err != nil {
return nil, err
}
hdr, blob, err := exportKeyData(kh.hkey)
defer bcrypt.DestroyKey(kh)
hdr, blob, err := exportKeyData(kh)
if err != nil {
return nil, err
}
runtime.KeepAlive(kh)
if hdr.Version != bcrypt.KEY_DATA_BLOB_VERSION1 {
return nil, errors.New("cng: unknown key data blob version")
}
Expand All @@ -171,10 +92,33 @@ func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) {
return blob[cbHashName:], nil
}

func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte) (io.Reader, error) {
// ExpandHKDF derives a key from the given hash, key, and optional context info.
func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte, keyLength int) ([]byte, error) {
kh, err := newHKDF(h, pseudorandomKey, nil, info)
if err != nil {
return nil, err
}
return kh, nil
defer bcrypt.DestroyKey(kh)
out := make([]byte, keyLength)
var params *bcrypt.BufferDesc
if len(info) > 0 {
params = &bcrypt.BufferDesc{
Count: 1,
Buffers: &bcrypt.Buffer{
Length: uint32(len(info)),
Type: bcrypt.KDF_HKDF_INFO,
Data: uintptr(unsafe.Pointer(&info[0])),
},
}
defer runtime.KeepAlive(params)
}
var n uint32
err = bcrypt.KeyDerivation(kh, params, out, &n, 0)
if err != nil {
return nil, err
}
if int(n) != keyLength {
return nil, errors.New("cng: key derivation returned unexpected length")
}
return out, err
}
112 changes: 25 additions & 87 deletions cng/hkdf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package cng_test
import (
"bytes"
"hash"
"io"
"testing"

"github.com/microsoft/go-crypto-winnative/cng"
Expand Down Expand Up @@ -295,16 +294,12 @@ var hkdfTests = []hkdfTest{
},
}

func newHKDF(hash func() hash.Hash, secret, salt, info []byte) io.Reader {
func newHKDF(hash func() hash.Hash, secret, salt, info []byte, keyLength int) ([]byte, error) {
prk, err := cng.ExtractHKDF(hash, secret, salt)
if err != nil {
panic(err)
return nil, err
}
r, err := cng.ExpandHKDF(hash, prk, info)
if err != nil {
panic(err)
}
return r
return cng.ExpandHKDF(hash, prk, info, keyLength)
}

func TestHKDF(t *testing.T) {
Expand All @@ -320,122 +315,65 @@ func TestHKDF(t *testing.T) {
t.Errorf("test %d: incorrect PRK: have %v, need %v.", i, prk, tt.prk)
}

hkdf := newHKDF(tt.hash, tt.master, tt.salt, tt.info)
out := make([]byte, len(tt.out))

n, err := io.ReadFull(hkdf, out)
if n != len(tt.out) || err != nil {
t.Errorf("test %d: not enough output bytes: %d.", i, n)
out, err := newHKDF(tt.hash, tt.master, tt.salt, tt.info, len(tt.out))
if err != nil {
t.Errorf("test %d: error generating HKDF: %v.", i, err)
}

if !bytes.Equal(out, tt.out) {
t.Errorf("test %d: incorrect output: have %v, need %v.", i, out, tt.out)
}

hkdf, err = cng.ExpandHKDF(tt.hash, prk, tt.info)
out, err = cng.ExpandHKDF(tt.hash, prk, tt.info, len(tt.out))
if err != nil {
t.Errorf("test %d: error expanding HKDF: %v.", i, err)
}

n, err = io.ReadFull(hkdf, out)
if n != len(tt.out) || err != nil {
t.Errorf("test %d: not enough output bytes from Expand: %d.", i, n)
}

if !bytes.Equal(out, tt.out) {
t.Errorf("test %d: incorrect output from Expand: have %v, need %v.", i, out, tt.out)
}
}
}

func TestHKDFMultiRead(t *testing.T) {
if !cng.SupportsHKDF() {
t.Skip("HKDF is not supported")
}
for i, tt := range hkdfTests {
hkdf := newHKDF(tt.hash, tt.master, tt.salt, tt.info)
out := make([]byte, len(tt.out))

for b := 0; b < len(tt.out); b++ {
n, err := io.ReadFull(hkdf, out[b:b+1])
if n != 1 || err != nil {
t.Errorf("test %d.%d: not enough output bytes: have %d, need %d .", i, b, n, len(tt.out))
}
}

if !bytes.Equal(out, tt.out) {
t.Errorf("test %d: incorrect output: have %v, need %v.", i, out, tt.out)
}
}
}

func TestHKDFLimit(t *testing.T) {
func TestExpandHKDFOneShotLimit(t *testing.T) {
if !cng.SupportsHKDF() {
t.Skip("HKDF is not supported")
}
hash := cng.NewSHA1
master := []byte{0x00, 0x01, 0x02, 0x03}
info := []byte{}

hkdf := newHKDF(hash, master, nil, info)
prk, err := cng.ExtractHKDF(hash, master, nil)
if err != nil {
t.Fatalf("error extracting HKDF: %v.", err)
}
limit := hash().Size() * 255
out := make([]byte, limit)

// The maximum output bytes should be extractable
n, err := io.ReadFull(hkdf, out)
if n != limit || err != nil {
t.Errorf("not enough output bytes: %d, %v.", n, err)
out, err := cng.ExpandHKDF(hash, prk, info, limit)
if err != nil {
t.Errorf("error expanding HKDF one-shot: %v.", err)
}

// Reading one more should fail
n, err = io.ReadFull(hkdf, make([]byte, 1))
if n > 0 || err == nil {
t.Errorf("key expansion overflowed: n = %d, err = %v", n, err)
if len(out) != limit {
t.Errorf("incorrect output length: have %d, need %d.", len(out), limit)
}
}

func BenchmarkHKDF32ByteSHA256Single(b *testing.B) {
benchmarkHKDFSingle(cng.NewSHA256, 32, b)
}

func BenchmarkHKDF8ByteSHA256Stream(b *testing.B) {
benchmarkHKDFStream(cng.NewSHA256, 8, b)
}

func BenchmarkHKDF32ByteSHA256Stream(b *testing.B) {
benchmarkHKDFStream(cng.NewSHA256, 32, b)
}

func benchmarkHKDFSingle(hasher func() hash.Hash, block int, b *testing.B) {
master := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}
salt := []byte{0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17}
info := []byte{0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27}
out := make([]byte, block)

b.SetBytes(int64(block))
b.ResetTimer()

for i := 0; i < b.N; i++ {
hkdf := newHKDF(hasher, master, salt, info)
io.ReadFull(hkdf, out)
// Expanding one more byte should fail
_, err = cng.ExpandHKDF(hash, prk, info, limit+1)
if err == nil {
t.Errorf("expected error for key expansion overflow")
}
}

func benchmarkHKDFStream(hasher func() hash.Hash, block int, b *testing.B) {
func BenchmarkHKDF32ByteSHA256Single(b *testing.B) {
master := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}
salt := []byte{0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17}
info := []byte{0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27}
out := make([]byte, block)
const block = 32

b.SetBytes(int64(block))
b.ResetTimer()

hkdf := newHKDF(hasher, master, salt, info)
for i := 0; i < b.N; i++ {
_, err := io.ReadFull(hkdf, out)
_, err := newHKDF(cng.NewSHA256, master, salt, info, block)
if err != nil {
hkdf = newHKDF(hasher, master, salt, info)
i--
b.Error(err)
}
}
}

0 comments on commit 03e984e

Please sign in to comment.