From 74f51ccd1c54283bfed2e1b46e3d7392d560a75f Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 20 Dec 2024 10:54:17 +0100 Subject: [PATCH] deduplicate code --- cng/hash.go | 142 ++++++++++++++++++++++++++++++---------------------- cng/sha3.go | 66 +++++------------------- 2 files changed, 95 insertions(+), 113 deletions(-) diff --git a/cng/hash.go b/cng/hash.go index 35a9467..de06c90 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -16,6 +16,9 @@ import ( "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) +// maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support. +const maxHashSize = 64 + // SupportsHash returns true if a hash.Hash implementation is supported for h. func SupportsHash(h crypto.Hash) bool { switch h { @@ -145,11 +148,11 @@ func hashToID(h hash.Hash) string { return hx.alg.id } +// hashX implements [hash.Hash]. type hashX struct { - alg *hashAlgorithm - _ctx bcrypt.HASH_HANDLE // access it using withCtx + alg *hashAlgorithm + ctx bcrypt.HASH_HANDLE - buf []byte key []byte } @@ -160,37 +163,34 @@ func newHashX(id string, flag bcrypt.AlgorithmProviderFlags, key []byte) *hashX panic(err) } h := &hashX{alg: alg, key: bytes.Clone(key)} - // Don't allocate hx.buf nor call bcrypt.CreateHash yet, - // which would be wasteful if the caller only wants to know - // the hash type. This is a common pattern in this package, - // as some functions accept a `func() hash.Hash` parameter - // and call it just to know the hash type. - runtime.SetFinalizer(h, (*hashX).finalize) + // Don't call bcrypt.CreateHash yet, it would be wasteful + // if the caller only wants to know the hash type. This + // is a common pattern in this package, as some functions + // accept a `func() hash.Hash` parameter and call it just + // to know the hash type. return h } func (h *hashX) finalize() { - if h._ctx != 0 { - bcrypt.DestroyHash(h._ctx) - } + bcrypt.DestroyHash(h.ctx) } -func (h *hashX) withCtx(fn func(ctx bcrypt.HASH_HANDLE) error) error { +func (h *hashX) init() { defer runtime.KeepAlive(h) - if h._ctx == 0 { - err := bcrypt.CreateHash(h.alg.handle, &h._ctx, nil, h.key, 0) - if err != nil { - panic(err) - } + if h.ctx != 0 { + return + } + err := bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, h.key, bcrypt.HASH_REUSABLE_FLAG) + if err != nil { + panic(err) } - return fn(h._ctx) + runtime.SetFinalizer(h, (*hashX).finalize) } func (h *hashX) Clone() (hash.Hash, error) { + defer runtime.KeepAlive(h) h2 := &hashX{alg: h.alg, key: bytes.Clone(h.key)} - err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { - return bcrypt.DuplicateHash(ctx, &h2._ctx, nil, 0) - }) + err := bcrypt.DuplicateHash(h.ctx, &h2.ctx, nil, 0) if err != nil { return nil, err } @@ -199,49 +199,37 @@ func (h *hashX) Clone() (hash.Hash, error) { } func (h *hashX) Reset() { - if h._ctx != 0 { - bcrypt.DestroyHash(h._ctx) - h._ctx = 0 + defer runtime.KeepAlive(h) + if h.ctx != 0 { + hashReset(h.ctx, h.Size()) } } func (h *hashX) Write(p []byte) (n int, err error) { - err = h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { - for n < len(p) && err == nil { - nn := len32(p[n:]) - err = bcrypt.HashData(h._ctx, p[n:n+nn], 0) - n += nn - } - return err - }) - if err != nil { - // hash.Hash interface mandates Write should never return an error. - panic(err) - } + defer runtime.KeepAlive(h) + h.init() + hashData(h.ctx, p) return len(p), nil } func (h *hashX) WriteString(s string) (int, error) { - // TODO: use unsafe.StringData once we drop support - // for go1.19 and earlier. - hdr := (*struct { - Data *byte - Len int - })(unsafe.Pointer(&s)) - return h.Write(unsafe.Slice(hdr.Data, len(s))) + defer runtime.KeepAlive(h) + return h.Write(unsafe.Slice(unsafe.StringData(s), len(s))) } func (h *hashX) WriteByte(c byte) error { - err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { - return bcrypt.HashDataRaw(h._ctx, &c, 1, 0) - }) - if err != nil { - // hash.Hash interface mandates Write should never return an error. - panic(err) - } + defer runtime.KeepAlive(h) + h.init() + hashByte(h.ctx, c) return nil } +func (h *hashX) Sum(in []byte) []byte { + defer runtime.KeepAlive(h) + h.init() + return hashSum(h.ctx, h.Size(), in) +} + func (h *hashX) Size() int { return int(h.alg.size) } @@ -250,21 +238,55 @@ func (h *hashX) BlockSize() int { return int(h.alg.blockSize) } -func (h *hashX) Sum(in []byte) []byte { +// hashData writes p to ctx. It panics on error. +func hashData(ctx bcrypt.HASH_HANDLE, p []byte) { + var n int + var err error + for n < len(p) && err == nil { + nn := len32(p[n:]) + err = bcrypt.HashData(ctx, p[n:n+nn], 0) + n += nn + } + if err != nil { + panic(err) + } +} + +// hashByte writes c to ctx. It panics on error. +func hashByte(ctx bcrypt.HASH_HANDLE, c byte) { + err := bcrypt.HashDataRaw(ctx, &c, 1, 0) + if err != nil { + panic(err) + } +} + +// hashSum writes the hash of ctx to in and returns the result. +// size is the size of the hash output. +// It panics on error. +func hashSum(ctx bcrypt.HASH_HANDLE, size int, in []byte) []byte { var ctx2 bcrypt.HASH_HANDLE - err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error { - return bcrypt.DuplicateHash(ctx, &ctx2, nil, 0) - }) + err := bcrypt.DuplicateHash(ctx, &ctx2, nil, 0) if err != nil { panic(err) } defer bcrypt.DestroyHash(ctx2) - if h.buf == nil { - h.buf = make([]byte, h.alg.size) - } - err = bcrypt.FinishHash(ctx2, h.buf, 0) + buf := make([]byte, size, maxHashSize) // explicit cap to allow stack allocation + err = bcrypt.FinishHash(ctx2, buf, 0) if err != nil { panic(err) } - return append(in, h.buf...) + return append(in, buf...) +} + +// hashReset resets the hash state of ctx. +// size is the size of the hash output. +// It panics on error. +func hashReset(ctx bcrypt.HASH_HANDLE, size int) { + // bcrypt.FinishHash expects the output buffer to match the hash size. + // We don't care about the output, so we just pass a stack-allocated buffer + // that is large enough to hold the largest hash size we support. + var discard [maxHashSize]byte + if err := bcrypt.FinishHash(ctx, discard[:size], 0); err != nil { + panic(err) + } } diff --git a/cng/sha3.go b/cng/sha3.go index 6136b2a..f3d4986 100644 --- a/cng/sha3.go +++ b/cng/sha3.go @@ -14,9 +14,6 @@ import ( "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) -// maxSHA3Size is the size of SHA3_512, the largest SHA3 hash we support. -const maxSHA3Size = 64 - // SumSHA3_256 returns the SHA3-256 checksum of the data. func SumSHA3_256(p []byte) (sum [32]byte) { if err := hashOneShot(bcrypt.SHA3_256_ALGORITHM, p, sum[:]); err != nil { @@ -123,28 +120,14 @@ func (h *DigestSHA3) Clone() (hash.Hash, error) { func (h *DigestSHA3) Reset() { defer runtime.KeepAlive(h) if h.ctx != 0 { - // bcrypt.FinishHash expects the output buffer to match the hash size. - // We don't care about the output, so we just pass a stack-allocated buffer - // that is large enough to hold the largest hash size we support. - var discard [maxSHA3Size]byte - if err := bcrypt.FinishHash(h.ctx, discard[:h.Size()], 0); err != nil { - panic(err) - } + hashReset(h.ctx, h.Size()) } } func (h *DigestSHA3) Write(p []byte) (n int, err error) { defer runtime.KeepAlive(h) h.init() - for n < len(p) && err == nil { - nn := len32(p[n:]) - err = bcrypt.HashData(h.ctx, p[n:n+nn], 0) - n += nn - } - if err != nil { - // hash.Hash interface mandates Write should never return an error. - panic(err) - } + hashData(h.ctx, p) return len(p), nil } @@ -156,14 +139,16 @@ func (h *DigestSHA3) WriteString(s string) (int, error) { func (h *DigestSHA3) WriteByte(c byte) error { defer runtime.KeepAlive(h) h.init() - err := bcrypt.HashDataRaw(h.ctx, &c, 1, 0) - if err != nil { - // hash.Hash interface mandates Write should never return an error. - panic(err) - } + hashByte(h.ctx, c) return nil } +func (h *DigestSHA3) Sum(in []byte) []byte { + defer runtime.KeepAlive(h) + h.init() + return hashSum(h.ctx, h.Size(), in) +} + func (h *DigestSHA3) Size() int { return int(h.alg.size) } @@ -172,23 +157,6 @@ func (h *DigestSHA3) BlockSize() int { return int(h.alg.blockSize) } -func (h *DigestSHA3) Sum(in []byte) []byte { - defer runtime.KeepAlive(h) - h.init() - var ctx2 bcrypt.HASH_HANDLE - err := bcrypt.DuplicateHash(h.ctx, &ctx2, nil, 0) - if err != nil { - panic(err) - } - defer bcrypt.DestroyHash(ctx2) - buf := make([]byte, h.alg.size, maxSHA3Size) // explicit cap to allow stack allocation - err = bcrypt.FinishHash(ctx2, buf, 0) - if err != nil { - panic(err) - } - return append(in, buf...) -} - // NewSHA3_256 returns a new SHA256 hash. func NewSHA3_256() *DigestSHA3 { return newDigestSHA3(bcrypt.SHA3_256_ALGORITHM) @@ -281,14 +249,7 @@ func (s *SHAKE) Write(p []byte) (n int, err error) { return 0, nil } defer runtime.KeepAlive(s) - for n < len(p) && err == nil { - nn := len32(p[n:]) - err = bcrypt.HashData(s.ctx, p[n:n+nn], 0) - n += nn - } - if err != nil { - panic(err) - } + hashData(s.ctx, p) return len(p), nil } @@ -314,10 +275,9 @@ func (s *SHAKE) Read(p []byte) (n int, err error) { // Reset resets the XOF to its initial state. func (s *SHAKE) Reset() { defer runtime.KeepAlive(s) - var discard [1]byte - if err := bcrypt.FinishHash(s.ctx, discard[:], 0); err != nil { - panic(err) - } + // SHAKE has a variable size, CNG doesn't change the size of the hash + // when resetting, so we can pass a small value here. + hashReset(s.ctx, 1) } // BlockSize returns the rate of the XOF.