diff --git a/cng/hash.go b/cng/hash.go index d894c07..bebbc99 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -151,11 +151,12 @@ func NewSHA3_512() hash.Hash { type hashAlgorithm struct { handle bcrypt.ALG_HANDLE + id string size uint32 blockSize uint32 } -func loadHash(id string, flags bcrypt.AlgorithmProviderFlags) (hashAlgorithm, error) { +func loadHash(id string, flags bcrypt.AlgorithmProviderFlags) (*hashAlgorithm, error) { v, err := loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { size, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_LENGTH) if err != nil { @@ -165,89 +166,108 @@ func loadHash(id string, flags bcrypt.AlgorithmProviderFlags) (hashAlgorithm, er if err != nil { return nil, err } - return hashAlgorithm{h, size, blockSize}, nil + return &hashAlgorithm{h, id, size, blockSize}, nil }) if err != nil { - return hashAlgorithm{}, err + return nil, err + } + return v.(*hashAlgorithm), nil +} + +// hashToID converts a hash.Hash implementation from this package +// to a CNG hash ID +func hashToID(h hash.Hash) string { + hx, ok := h.(*hashX) + if !ok { + return "" } - return v.(hashAlgorithm), nil + return hx.alg.id } type hashX struct { - h bcrypt.ALG_HANDLE - ctx bcrypt.HASH_HANDLE - size int - blockSize int - buf []byte - key []byte + alg *hashAlgorithm + _ctx bcrypt.HASH_HANDLE // access it using withCtx + + buf []byte + key []byte } +// newHashX returns a new hash.Hash using the specified algorithm. func newHashX(id string, flag bcrypt.AlgorithmProviderFlags, key []byte) *hashX { - h, err := loadHash(id, flag) + alg, err := loadHash(id, flag) if err != nil { panic(err) } - hx := new(hashX) - hx.h = h.handle - hx.size = int(h.size) - hx.blockSize = int(h.blockSize) - hx.buf = make([]byte, hx.size) + h := new(hashX) + h.alg = alg if len(key) > 0 { - hx.key = make([]byte, len(key)) - copy(hx.key, key) + h.key = make([]byte, len(key)) + copy(h.key, key) } - hx.Reset() - runtime.SetFinalizer(hx, (*hashX).finalize) - return hx + // Don't allocate hx.buf nor call bcrypt.CreateHash yet, + // which would be wasteful if the caller only wants to know + // the hash type. This is a common pattern in this package, + // as some functions accept a `func() hash.Hash` parameter + // and call it just to know the hash type. + runtime.SetFinalizer(h, (*hashX).finalize) + return h } func (h *hashX) finalize() { - if h.ctx != 0 { - bcrypt.DestroyHash(h.ctx) + if h._ctx != 0 { + bcrypt.DestroyHash(h._ctx) + } +} + +func (h *hashX) withCtx(fn func(ctx bcrypt.HASH_HANDLE) error) error { + defer runtime.KeepAlive(h) + if h._ctx == 0 { + err := bcrypt.CreateHash(h.alg.handle, &h._ctx, nil, h.key, 0) + if err != nil { + panic(err) + } } + return fn(h._ctx) } func (h *hashX) Clone() (hash.Hash, error) { h2 := &hashX{ - h: h.h, - size: h.size, - blockSize: h.blockSize, - buf: make([]byte, len(h.buf)), - key: make([]byte, len(h.key)), + alg: h.alg, + } + if h.key != nil { + h2.key = make([]byte, len(h.key)) + copy(h2.key, h.key) } - copy(h2.key, h.key) - err := bcrypt.DuplicateHash(h.ctx, &h2.ctx, nil, 0) + err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { + return bcrypt.DuplicateHash(ctx, &h2._ctx, nil, 0) + }) if err != nil { return nil, err } runtime.SetFinalizer(h2, (*hashX).finalize) - runtime.KeepAlive(h) return h2, nil } func (h *hashX) Reset() { - if h.ctx != 0 { - bcrypt.DestroyHash(h.ctx) - h.ctx = 0 + if h._ctx != 0 { + bcrypt.DestroyHash(h._ctx) + h._ctx = 0 } - err := bcrypt.CreateHash(h.h, &h.ctx, nil, h.key, 0) - if err != nil { - panic(err) - } - runtime.KeepAlive(h) } func (h *hashX) Write(p []byte) (n int, err error) { - for n < len(p) && err == nil { - nn := len32(p[n:]) - err = bcrypt.HashData(h.ctx, p[n:n+nn], 0) - n += nn - } + err = h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.HashData(h._ctx, p[n:n+nn], 0) + n += nn + } + return err + }) if err != nil { // hash.Hash interface mandates Write should never return an error. panic(err) } - runtime.KeepAlive(h) return len(p), nil } @@ -262,37 +282,39 @@ func (h *hashX) WriteString(s string) (int, error) { } func (h *hashX) WriteByte(c byte) error { - if err := bcrypt.HashDataRaw(h.ctx, &c, 1, 0); err != nil { + err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { + return bcrypt.HashDataRaw(h._ctx, &c, 1, 0) + }) + if err != nil { // hash.Hash interface mandates Write should never return an error. panic(err) } - runtime.KeepAlive(h) return nil } func (h *hashX) Size() int { - return h.size + return int(h.alg.size) } func (h *hashX) BlockSize() int { - return h.blockSize + return int(h.alg.blockSize) } func (h *hashX) Sum(in []byte) []byte { - h.sum(h.buf) - return append(in, h.buf...) -} - -func (h *hashX) sum(out []byte) { var ctx2 bcrypt.HASH_HANDLE - err := bcrypt.DuplicateHash(h.ctx, &ctx2, nil, 0) + err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { + return bcrypt.DuplicateHash(ctx, &ctx2, nil, 0) + }) if err != nil { panic(err) } defer bcrypt.DestroyHash(ctx2) - err = bcrypt.FinishHash(ctx2, out, 0) + if h.buf == nil { + h.buf = make([]byte, h.alg.size) + } + err = bcrypt.FinishHash(ctx2, h.buf, 0) if err != nil { panic(err) } - runtime.KeepAlive(h) + return append(in, h.buf...) } diff --git a/cng/hash_test.go b/cng/hash_test.go index 182a90f..ab66d4f 100644 --- a/cng/hash_test.go +++ b/cng/hash_test.go @@ -171,7 +171,7 @@ func TestHash_OneShot(t *testing.T) { } } -func BenchmarkHash8Bytes(b *testing.B) { +func BenchmarkSHA256_8Bytes(b *testing.B) { b.StopTimer() h := cng.NewSHA256() sum := make([]byte, h.Size()) @@ -188,7 +188,7 @@ func BenchmarkHash8Bytes(b *testing.B) { } } -func BenchmarkSHA256(b *testing.B) { +func BenchmarkSHA256_OneShot(b *testing.B) { b.StopTimer() size := 8 buf := make([]byte, size) diff --git a/cng/hmac.go b/cng/hmac.go index 76a80ce..2d9fd36 100644 --- a/cng/hmac.go +++ b/cng/hmac.go @@ -12,26 +12,6 @@ import ( "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) -// hashToID converts a hash.Hash implementation from this package -// to a CNG hash ID -func hashToID(h hash.Hash) string { - if _, ok := h.(*hashX); !ok { - return "" - } - var id string - switch h.Size() { - case 20: - id = bcrypt.SHA1_ALGORITHM - case 256 / 8: - id = bcrypt.SHA256_ALGORITHM - case 384 / 8: - id = bcrypt.SHA384_ALGORITHM - case 512 / 8: - id = bcrypt.SHA512_ALGORITHM - } - return id -} - // NewHMAC returns a new HMAC using BCrypt. // The function h must return a hash implemented by // CNG (for example, h could be cng.NewSHA256). diff --git a/cng/pbkdf2_test.go b/cng/pbkdf2_test.go index 2e2659f..5730ffa 100644 --- a/cng/pbkdf2_test.go +++ b/cng/pbkdf2_test.go @@ -8,8 +8,6 @@ package cng_test import ( "bytes" - "crypto/sha1" - "crypto/sha256" "hash" "testing" @@ -181,7 +179,7 @@ func TestPBKDF2NoSalt(t *testing.T) { var sink uint8 -func benchmark(b *testing.B, h func() hash.Hash) { +func benchmarkPBKDF2(b *testing.B, h func() hash.Hash) { password := make([]byte, h().Size()) salt := make([]byte, 8) var err error @@ -194,10 +192,10 @@ func benchmark(b *testing.B, h func() hash.Hash) { sink += password[0] } -func BenchmarkHMACSHA1(b *testing.B) { - benchmark(b, sha1.New) +func BenchmarkPBKDF2HMACSHA1(b *testing.B) { + benchmarkPBKDF2(b, cng.NewSHA1) } -func BenchmarkHMACSHA256(b *testing.B) { - benchmark(b, sha256.New) +func BenchmarkPBKDF2HMACSHA256(b *testing.B) { + benchmarkPBKDF2(b, cng.NewSHA256) }