From 1bb252f57c966a63ae1514cc6644d92f1a587307 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 26 Sep 2024 11:26:20 +0200 Subject: [PATCH 1/3] update loadOrStoreAlg to accept type arguments --- cng/cipher.go | 10 +++------ cng/cng.go | 17 ++++++++------ cng/dsa.go | 8 ++----- cng/ecdh.go | 61 +++++++++++++++++++++++--------------------------- cng/ecdsa.go | 58 +++++++++++++++++++++-------------------------- cng/hash.go | 6 +---- cng/hkdf.go | 6 +---- cng/pbkdf2.go | 6 +---- cng/rsa.go | 8 ++----- cng/tls1prf.go | 6 +---- 10 files changed, 75 insertions(+), 111 deletions(-) diff --git a/cng/cipher.go b/cng/cipher.go index 61f5dc8..c1365f8 100644 --- a/cng/cipher.go +++ b/cng/cipher.go @@ -18,25 +18,21 @@ type cipherAlgorithm struct { } func loadCipher(id, mode string) (cipherAlgorithm, error) { - v, err := loadOrStoreAlg(id, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return loadOrStoreAlg(id, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (cipherAlgorithm, error) { if mode != "" { // Windows 8 added support to set the CipherMode value on a key, // but Windows 7 requires that it be set on the algorithm before key creation. err := setString(bcrypt.HANDLE(h), bcrypt.CHAINING_MODE, mode) if err != nil { - return nil, err + return cipherAlgorithm{}, err } } lengths, err := getKeyLengths(bcrypt.HANDLE(h)) if err != nil { - return nil, err + return cipherAlgorithm{}, err } return cipherAlgorithm{h, lengths}, nil }) - if err != nil { - return cipherAlgorithm{}, nil - } - return v.(cipherAlgorithm), nil } func newCipherHandle(id, mode string, key []byte) (bcrypt.KEY_HANDLE, error) { diff --git a/cng/cng.go b/cng/cng.go index 844c087..29255e7 100644 --- a/cng/cng.go +++ b/cng/cng.go @@ -36,9 +36,10 @@ func len32(s []byte) int { var algCache sync.Map -type newAlgEntryFn func(h bcrypt.ALG_HANDLE) (interface{}, error) - -func loadOrStoreAlg(id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn newAlgEntryFn) (interface{}, error) { +// loadOrStoreAlg loads an algorithm with the given id, flags, and mode from the cache. +// If the algorithm is not in the cache, a new one is created and then initialized using fn. +// The returned algorithm handle should not be closed by the caller. +func loadOrStoreAlg[T any](id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn func(h bcrypt.ALG_HANDLE) (T, error)) (T, error) { var entryKey = struct { id string flags bcrypt.AlgorithmProviderFlags @@ -46,22 +47,24 @@ func loadOrStoreAlg(id string, flags bcrypt.AlgorithmProviderFlags, mode string, }{id, flags, mode} if v, ok := algCache.Load(entryKey); ok { - return v, nil + return v.(T), nil } var h bcrypt.ALG_HANDLE err := bcrypt.OpenAlgorithmProvider(&h, utf16PtrFromString(id), nil, flags) if err != nil { - return nil, err + var t T + return t, err } v, err := fn(h) if err != nil { bcrypt.CloseAlgorithmProvider(h, 0) - return nil, err + var t T + return t, err } if existing, loaded := algCache.LoadOrStore(entryKey, v); loaded { // We can safely use a provider that has already been cached in another concurrent goroutine. bcrypt.CloseAlgorithmProvider(h, 0) - v = existing + v = existing.(T) } return v, nil } diff --git a/cng/dsa.go b/cng/dsa.go index bd3e19d..5d4d397 100644 --- a/cng/dsa.go +++ b/cng/dsa.go @@ -39,17 +39,13 @@ type dsaAlgorithm struct { } func loadDSA() (h dsaAlgorithm, err error) { - v, err := loadOrStoreAlg(bcrypt.DSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return loadOrStoreAlg(bcrypt.DSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (dsaAlgorithm, error) { lengths, err := getKeyLengths(bcrypt.HANDLE(h)) if err != nil { - return nil, err + return dsaAlgorithm{}, err } return dsaAlgorithm{h, lengths}, nil }) - if err != nil { - return dsaAlgorithm{}, err - } - return v.(dsaAlgorithm), nil } // DSAParameters contains the DSA parameters. diff --git a/cng/ecdh.go b/cng/ecdh.go index cd6e9a9..2738728 100644 --- a/cng/ecdh.go +++ b/cng/ecdh.go @@ -20,36 +20,31 @@ var errInvalidPrivateKey = errors.New("cng: invalid private key") type ecdhAlgorithm struct { handle bcrypt.ALG_HANDLE + bits uint32 } -func loadECDH(curve string) (h ecdhAlgorithm, bits uint32, err error) { - var id string - switch curve { - case "P-256": - id, bits = bcrypt.ECC_CURVE_NISTP256, 256 - case "P-384": - id, bits = bcrypt.ECC_CURVE_NISTP384, 384 - case "P-521": - id, bits = bcrypt.ECC_CURVE_NISTP521, 521 - case "X25519": - id, bits = bcrypt.ECC_CURVE_25519, 255 - default: - err = errUnknownCurve - } - if err != nil { - return - } - v, err := loadOrStoreAlg(bcrypt.ECDH_ALGORITHM, bcrypt.ALG_NONE_FLAG, id, func(h bcrypt.ALG_HANDLE) (interface{}, error) { - err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, id) +func loadECDH(curve string) (ecdhAlgorithm, error) { + return loadOrStoreAlg(bcrypt.ECDH_ALGORITHM, bcrypt.ALG_NONE_FLAG, curve, func(h bcrypt.ALG_HANDLE) (ecdhAlgorithm, error) { + var name string + var bits uint32 + switch curve { + case "P-256": + name, bits = bcrypt.ECC_CURVE_NISTP256, 256 + case "P-384": + name, bits = bcrypt.ECC_CURVE_NISTP384, 384 + case "P-521": + name, bits = bcrypt.ECC_CURVE_NISTP521, 521 + case "X25519": + name, bits = bcrypt.ECC_CURVE_25519, 255 + default: + return ecdhAlgorithm{}, errUnknownCurve + } + err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, name) if err != nil { - return nil, err + return ecdhAlgorithm{}, err } - return ecdhAlgorithm{h}, nil + return ecdhAlgorithm{h, bits}, nil }) - if err != nil { - return ecdhAlgorithm{}, 0, err - } - return v.(ecdhAlgorithm), bits, nil } type PublicKeyECDH struct { @@ -115,12 +110,12 @@ func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) { } func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) { - h, bits, err := loadECDH(curve) + h, err := loadECDH(curve) if err != nil { return nil, nil, err } var hkey bcrypt.KEY_HANDLE - err = bcrypt.GenerateKeyPair(h.handle, &hkey, bits, 0) + err = bcrypt.GenerateKeyPair(h.handle, &hkey, h.bits, 0) if err != nil { return nil, nil, err } @@ -154,7 +149,7 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) { if len(bytes) == 0 || (nist && bytes[0] != ecdhUncompressedPrefix) { return nil, errInvalidPublicKey } - h, bits, err := loadECDH(curve) + h, err := loadECDH(curve) if err != nil { return nil, err } @@ -169,11 +164,11 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) { ncomponents = 1 keyWithoutEncoding = bytes } - keySize := int(bits+7) / 8 + keySize := int(h.bits+7) / 8 if len(keyWithoutEncoding) != keySize*ncomponents { return nil, errInvalidPublicKey } - hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, bits, keyWithoutEncoding[:keySize], keyWithoutEncoding[keySize:], nil) + hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, h.bits, keyWithoutEncoding[:keySize], keyWithoutEncoding[keySize:], nil) if err != nil { return nil, err } @@ -185,11 +180,11 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) { func (k *PublicKeyECDH) Bytes() []byte { return k.bytes } func NewPrivateKeyECDH(curve string, key []byte) (*PrivateKeyECDH, error) { - h, bits, err := loadECDH(curve) + h, err := loadECDH(curve) if err != nil { return nil, err } - keySize := int(bits+7) / 8 + keySize := int(h.bits+7) / 8 if len(key) != keySize { return nil, errInvalidPrivateKey } @@ -202,7 +197,7 @@ func NewPrivateKeyECDH(curve string, key []byte) (*PrivateKeyECDH, error) { // To trigger this behavior we pass a zeroed X/Y with keySize length. // zero is big enough to fit P-521 curves, the largest we handle, in the stack. var zero [(521 + 7) / 8]byte - hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, bits, zero[:keySize], zero[:keySize], key) + hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, h.bits, zero[:keySize], zero[:keySize], key) if err != nil { return nil, err } diff --git a/cng/ecdsa.go b/cng/ecdsa.go index a77ff97..586e9ae 100644 --- a/cng/ecdsa.go +++ b/cng/ecdsa.go @@ -17,47 +17,41 @@ var errUnknownCurve = errors.New("cng: unknown elliptic curve") type ecdsaAlgorithm struct { handle bcrypt.ALG_HANDLE + bits uint32 } -func loadECDSA(curve string) (h ecdsaAlgorithm, bits uint32, err error) { - var id string - switch curve { - case "P-224": - id, bits = bcrypt.ECC_CURVE_NISTP224, 224 - case "P-256": - id, bits = bcrypt.ECC_CURVE_NISTP256, 256 - case "P-384": - id, bits = bcrypt.ECC_CURVE_NISTP384, 384 - case "P-521": - id, bits = bcrypt.ECC_CURVE_NISTP521, 521 - default: - err = errUnknownCurve - } - if err != nil { - return - } - v, err := loadOrStoreAlg(bcrypt.ECDSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, id, func(h bcrypt.ALG_HANDLE) (interface{}, error) { - err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, id) +func loadECDSA(curve string) (ecdsaAlgorithm, error) { + return loadOrStoreAlg(bcrypt.ECDSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, curve, func(h bcrypt.ALG_HANDLE) (ecdsaAlgorithm, error) { + var name string + var bits uint32 + switch curve { + case "P-224": + name, bits = bcrypt.ECC_CURVE_NISTP224, 224 + case "P-256": + name, bits = bcrypt.ECC_CURVE_NISTP256, 256 + case "P-384": + name, bits = bcrypt.ECC_CURVE_NISTP384, 384 + case "P-521": + name, bits = bcrypt.ECC_CURVE_NISTP521, 521 + default: + return ecdsaAlgorithm{}, errUnknownCurve + } + err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, name) if err != nil { - return nil, err + return ecdsaAlgorithm{}, err } - return ecdsaAlgorithm{h}, nil + return ecdsaAlgorithm{h, bits}, nil }) - if err != nil { - return ecdsaAlgorithm{}, 0, err - } - return v.(ecdsaAlgorithm), bits, nil } func GenerateKeyECDSA(curve string) (X, Y, D BigInt, err error) { var h ecdsaAlgorithm - var bits uint32 - h, bits, err = loadECDSA(curve) + h, err = loadECDSA(curve) if err != nil { return } var hkey bcrypt.KEY_HANDLE - err = bcrypt.GenerateKeyPair(h.handle, &hkey, bits, 0) + err = bcrypt.GenerateKeyPair(h.handle, &hkey, h.bits, 0) if err != nil { return } @@ -87,11 +81,11 @@ type PublicKeyECDSA struct { } func NewPublicKeyECDSA(curve string, X, Y BigInt) (*PublicKeyECDSA, error) { - h, bits, err := loadECDSA(curve) + h, err := loadECDSA(curve) if err != nil { return nil, err } - hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, bits, X, Y, nil) + hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, h.bits, X, Y, nil) if err != nil { return nil, err } @@ -109,11 +103,11 @@ type PrivateKeyECDSA struct { } func NewPrivateKeyECDSA(curve string, X, Y, D BigInt) (*PrivateKeyECDSA, error) { - h, bits, err := loadECDSA(curve) + h, err := loadECDSA(curve) if err != nil { return nil, err } - hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, bits, X, Y, D) + hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, h.bits, X, Y, D) if err != nil { return nil, err } diff --git a/cng/hash.go b/cng/hash.go index bebbc99..c4f01e1 100644 --- a/cng/hash.go +++ b/cng/hash.go @@ -157,7 +157,7 @@ type hashAlgorithm struct { } func loadHash(id string, flags bcrypt.AlgorithmProviderFlags) (*hashAlgorithm, error) { - v, err := loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (*hashAlgorithm, error) { size, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_LENGTH) if err != nil { return nil, err @@ -168,10 +168,6 @@ func loadHash(id string, flags bcrypt.AlgorithmProviderFlags) (*hashAlgorithm, e } return &hashAlgorithm{h, id, size, blockSize}, nil }) - if err != nil { - return nil, err - } - return v.(*hashAlgorithm), nil } // hashToID converts a hash.Hash implementation from this package diff --git a/cng/hkdf.go b/cng/hkdf.go index 6f164ce..914b669 100644 --- a/cng/hkdf.go +++ b/cng/hkdf.go @@ -23,13 +23,9 @@ func SupportsHKDF() bool { } func loadHKDF() (bcrypt.ALG_HANDLE, error) { - h, err := loadOrStoreAlg(bcrypt.HKDF_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return loadOrStoreAlg(bcrypt.HKDF_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) { return h, nil }) - if err != nil { - return 0, err - } - return h.(bcrypt.ALG_HANDLE), nil } type hkdf struct { diff --git a/cng/pbkdf2.go b/cng/pbkdf2.go index 42614c9..49f1ce2 100644 --- a/cng/pbkdf2.go +++ b/cng/pbkdf2.go @@ -15,13 +15,9 @@ import ( ) func loadPBKDF2() (bcrypt.ALG_HANDLE, error) { - h, err := loadOrStoreAlg(bcrypt.PBKDF2_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return loadOrStoreAlg(bcrypt.PBKDF2_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) { return h, nil }) - if err != nil { - return 0, err - } - return h.(bcrypt.ALG_HANDLE), nil } func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) ([]byte, error) { diff --git a/cng/rsa.go b/cng/rsa.go index 7e3f7ab..e9e2a09 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -22,17 +22,13 @@ type rsaAlgorithm struct { } func loadRsa() (rsaAlgorithm, error) { - v, err := loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (rsaAlgorithm, error) { lengths, err := getKeyLengths(bcrypt.HANDLE(h)) if err != nil { - return nil, err + return rsaAlgorithm{}, err } return rsaAlgorithm{h, lengths}, nil }) - if err != nil { - return rsaAlgorithm{}, err - } - return v.(rsaAlgorithm), nil } func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) { diff --git a/cng/tls1prf.go b/cng/tls1prf.go index 30ef224..c4e08b5 100644 --- a/cng/tls1prf.go +++ b/cng/tls1prf.go @@ -15,13 +15,9 @@ import ( ) func loadTLS1PRF(id string) (bcrypt.ALG_HANDLE, error) { - h, err := loadOrStoreAlg(id, 0, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return loadOrStoreAlg(id, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) { return h, nil }) - if err != nil { - return 0, err - } - return h.(bcrypt.ALG_HANDLE), nil } // TLS1PRF implements the TLS 1.0/1.1 pseudo-random function if h is nil, From d3b9b85c2baa6b6d168204d9d7b4a563b2b61086 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 27 Sep 2024 09:28:59 +0200 Subject: [PATCH 2/3] use oneliner to instantiate generic zero variables --- cng/cng.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cng/cng.go b/cng/cng.go index 29255e7..d1916f9 100644 --- a/cng/cng.go +++ b/cng/cng.go @@ -52,14 +52,12 @@ func loadOrStoreAlg[T any](id string, flags bcrypt.AlgorithmProviderFlags, mode var h bcrypt.ALG_HANDLE err := bcrypt.OpenAlgorithmProvider(&h, utf16PtrFromString(id), nil, flags) if err != nil { - var t T - return t, err + return *new(T), err } v, err := fn(h) if err != nil { bcrypt.CloseAlgorithmProvider(h, 0) - var t T - return t, err + return *new(T), err } if existing, loaded := algCache.LoadOrStore(entryKey, v); loaded { // We can safely use a provider that has already been cached in another concurrent goroutine. From 519bf9d027446d703d2d821bf2f39de37589ab6c Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 27 Sep 2024 09:57:12 +0200 Subject: [PATCH 3/3] always use a named constant for none flags --- cng/hkdf.go | 2 +- cng/pbkdf2.go | 2 +- cng/tls1prf.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cng/hkdf.go b/cng/hkdf.go index 914b669..655926e 100644 --- a/cng/hkdf.go +++ b/cng/hkdf.go @@ -23,7 +23,7 @@ func SupportsHKDF() bool { } func loadHKDF() (bcrypt.ALG_HANDLE, error) { - return loadOrStoreAlg(bcrypt.HKDF_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) { + return loadOrStoreAlg(bcrypt.HKDF_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) { return h, nil }) } diff --git a/cng/pbkdf2.go b/cng/pbkdf2.go index 49f1ce2..5466b18 100644 --- a/cng/pbkdf2.go +++ b/cng/pbkdf2.go @@ -15,7 +15,7 @@ import ( ) func loadPBKDF2() (bcrypt.ALG_HANDLE, error) { - return loadOrStoreAlg(bcrypt.PBKDF2_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) { + return loadOrStoreAlg(bcrypt.PBKDF2_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) { return h, nil }) } diff --git a/cng/tls1prf.go b/cng/tls1prf.go index c4e08b5..5a3fb01 100644 --- a/cng/tls1prf.go +++ b/cng/tls1prf.go @@ -15,7 +15,7 @@ import ( ) func loadTLS1PRF(id string) (bcrypt.ALG_HANDLE, error) { - return loadOrStoreAlg(id, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) { + return loadOrStoreAlg(id, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) { return h, nil }) }