Skip to content

Commit

Permalink
update loadOrStoreAlg to accept type arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Sep 26, 2024
1 parent c5abe8b commit 1bb252f
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 111 deletions.
10 changes: 3 additions & 7 deletions cng/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,21 @@ type cipherAlgorithm struct {
}

func loadCipher(id, mode string) (cipherAlgorithm, error) {
v, err := loadOrStoreAlg(id, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(id, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (cipherAlgorithm, error) {
if mode != "" {
// Windows 8 added support to set the CipherMode value on a key,
// but Windows 7 requires that it be set on the algorithm before key creation.
err := setString(bcrypt.HANDLE(h), bcrypt.CHAINING_MODE, mode)
if err != nil {
return nil, err
return cipherAlgorithm{}, err
}
}
lengths, err := getKeyLengths(bcrypt.HANDLE(h))
if err != nil {
return nil, err
return cipherAlgorithm{}, err
}
return cipherAlgorithm{h, lengths}, nil
})
if err != nil {
return cipherAlgorithm{}, nil
}
return v.(cipherAlgorithm), nil
}

func newCipherHandle(id, mode string, key []byte) (bcrypt.KEY_HANDLE, error) {
Expand Down
17 changes: 10 additions & 7 deletions cng/cng.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,35 @@ func len32(s []byte) int {

var algCache sync.Map

type newAlgEntryFn func(h bcrypt.ALG_HANDLE) (interface{}, error)

func loadOrStoreAlg(id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn newAlgEntryFn) (interface{}, error) {
// loadOrStoreAlg loads an algorithm with the given id, flags, and mode from the cache.
// If the algorithm is not in the cache, a new one is created and then initialized using fn.
// The returned algorithm handle should not be closed by the caller.
func loadOrStoreAlg[T any](id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn func(h bcrypt.ALG_HANDLE) (T, error)) (T, error) {
var entryKey = struct {
id string
flags bcrypt.AlgorithmProviderFlags
mode string
}{id, flags, mode}

if v, ok := algCache.Load(entryKey); ok {
return v, nil
return v.(T), nil
}
var h bcrypt.ALG_HANDLE
err := bcrypt.OpenAlgorithmProvider(&h, utf16PtrFromString(id), nil, flags)
if err != nil {
return nil, err
var t T
return t, err
}
v, err := fn(h)
if err != nil {
bcrypt.CloseAlgorithmProvider(h, 0)
return nil, err
var t T
return t, err
}
if existing, loaded := algCache.LoadOrStore(entryKey, v); loaded {
// We can safely use a provider that has already been cached in another concurrent goroutine.
bcrypt.CloseAlgorithmProvider(h, 0)
v = existing
v = existing.(T)
}
return v, nil
}
Expand Down
8 changes: 2 additions & 6 deletions cng/dsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,13 @@ type dsaAlgorithm struct {
}

func loadDSA() (h dsaAlgorithm, err error) {
v, err := loadOrStoreAlg(bcrypt.DSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(bcrypt.DSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (dsaAlgorithm, error) {
lengths, err := getKeyLengths(bcrypt.HANDLE(h))
if err != nil {
return nil, err
return dsaAlgorithm{}, err
}
return dsaAlgorithm{h, lengths}, nil
})
if err != nil {
return dsaAlgorithm{}, err
}
return v.(dsaAlgorithm), nil
}

// DSAParameters contains the DSA parameters.
Expand Down
61 changes: 28 additions & 33 deletions cng/ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,31 @@ var errInvalidPrivateKey = errors.New("cng: invalid private key")

type ecdhAlgorithm struct {
handle bcrypt.ALG_HANDLE
bits uint32
}

func loadECDH(curve string) (h ecdhAlgorithm, bits uint32, err error) {
var id string
switch curve {
case "P-256":
id, bits = bcrypt.ECC_CURVE_NISTP256, 256
case "P-384":
id, bits = bcrypt.ECC_CURVE_NISTP384, 384
case "P-521":
id, bits = bcrypt.ECC_CURVE_NISTP521, 521
case "X25519":
id, bits = bcrypt.ECC_CURVE_25519, 255
default:
err = errUnknownCurve
}
if err != nil {
return
}
v, err := loadOrStoreAlg(bcrypt.ECDH_ALGORITHM, bcrypt.ALG_NONE_FLAG, id, func(h bcrypt.ALG_HANDLE) (interface{}, error) {
err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, id)
func loadECDH(curve string) (ecdhAlgorithm, error) {
return loadOrStoreAlg(bcrypt.ECDH_ALGORITHM, bcrypt.ALG_NONE_FLAG, curve, func(h bcrypt.ALG_HANDLE) (ecdhAlgorithm, error) {
var name string
var bits uint32
switch curve {
case "P-256":
name, bits = bcrypt.ECC_CURVE_NISTP256, 256
case "P-384":
name, bits = bcrypt.ECC_CURVE_NISTP384, 384
case "P-521":
name, bits = bcrypt.ECC_CURVE_NISTP521, 521
case "X25519":
name, bits = bcrypt.ECC_CURVE_25519, 255
default:
return ecdhAlgorithm{}, errUnknownCurve
}
err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, name)
if err != nil {
return nil, err
return ecdhAlgorithm{}, err
}
return ecdhAlgorithm{h}, nil
return ecdhAlgorithm{h, bits}, nil
})
if err != nil {
return ecdhAlgorithm{}, 0, err
}
return v.(ecdhAlgorithm), bits, nil
}

type PublicKeyECDH struct {
Expand Down Expand Up @@ -115,12 +110,12 @@ func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
}

func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
h, bits, err := loadECDH(curve)
h, err := loadECDH(curve)
if err != nil {
return nil, nil, err
}
var hkey bcrypt.KEY_HANDLE
err = bcrypt.GenerateKeyPair(h.handle, &hkey, bits, 0)
err = bcrypt.GenerateKeyPair(h.handle, &hkey, h.bits, 0)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -154,7 +149,7 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
if len(bytes) == 0 || (nist && bytes[0] != ecdhUncompressedPrefix) {
return nil, errInvalidPublicKey
}
h, bits, err := loadECDH(curve)
h, err := loadECDH(curve)
if err != nil {
return nil, err
}
Expand All @@ -169,11 +164,11 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
ncomponents = 1
keyWithoutEncoding = bytes
}
keySize := int(bits+7) / 8
keySize := int(h.bits+7) / 8
if len(keyWithoutEncoding) != keySize*ncomponents {
return nil, errInvalidPublicKey
}
hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, bits, keyWithoutEncoding[:keySize], keyWithoutEncoding[keySize:], nil)
hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, h.bits, keyWithoutEncoding[:keySize], keyWithoutEncoding[keySize:], nil)
if err != nil {
return nil, err
}
Expand All @@ -185,11 +180,11 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
func (k *PublicKeyECDH) Bytes() []byte { return k.bytes }

func NewPrivateKeyECDH(curve string, key []byte) (*PrivateKeyECDH, error) {
h, bits, err := loadECDH(curve)
h, err := loadECDH(curve)
if err != nil {
return nil, err
}
keySize := int(bits+7) / 8
keySize := int(h.bits+7) / 8
if len(key) != keySize {
return nil, errInvalidPrivateKey
}
Expand All @@ -202,7 +197,7 @@ func NewPrivateKeyECDH(curve string, key []byte) (*PrivateKeyECDH, error) {
// To trigger this behavior we pass a zeroed X/Y with keySize length.
// zero is big enough to fit P-521 curves, the largest we handle, in the stack.
var zero [(521 + 7) / 8]byte
hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, bits, zero[:keySize], zero[:keySize], key)
hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, h.bits, zero[:keySize], zero[:keySize], key)
if err != nil {
return nil, err
}
Expand Down
58 changes: 26 additions & 32 deletions cng/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,41 @@ var errUnknownCurve = errors.New("cng: unknown elliptic curve")

type ecdsaAlgorithm struct {
handle bcrypt.ALG_HANDLE
bits uint32
}

func loadECDSA(curve string) (h ecdsaAlgorithm, bits uint32, err error) {
var id string
switch curve {
case "P-224":
id, bits = bcrypt.ECC_CURVE_NISTP224, 224
case "P-256":
id, bits = bcrypt.ECC_CURVE_NISTP256, 256
case "P-384":
id, bits = bcrypt.ECC_CURVE_NISTP384, 384
case "P-521":
id, bits = bcrypt.ECC_CURVE_NISTP521, 521
default:
err = errUnknownCurve
}
if err != nil {
return
}
v, err := loadOrStoreAlg(bcrypt.ECDSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, id, func(h bcrypt.ALG_HANDLE) (interface{}, error) {
err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, id)
func loadECDSA(curve string) (ecdsaAlgorithm, error) {
return loadOrStoreAlg(bcrypt.ECDSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, curve, func(h bcrypt.ALG_HANDLE) (ecdsaAlgorithm, error) {
var name string
var bits uint32
switch curve {
case "P-224":
name, bits = bcrypt.ECC_CURVE_NISTP224, 224
case "P-256":
name, bits = bcrypt.ECC_CURVE_NISTP256, 256
case "P-384":
name, bits = bcrypt.ECC_CURVE_NISTP384, 384
case "P-521":
name, bits = bcrypt.ECC_CURVE_NISTP521, 521
default:
return ecdsaAlgorithm{}, errUnknownCurve
}
err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, name)
if err != nil {
return nil, err
return ecdsaAlgorithm{}, err
}
return ecdsaAlgorithm{h}, nil
return ecdsaAlgorithm{h, bits}, nil
})
if err != nil {
return ecdsaAlgorithm{}, 0, err
}
return v.(ecdsaAlgorithm), bits, nil
}

func GenerateKeyECDSA(curve string) (X, Y, D BigInt, err error) {
var h ecdsaAlgorithm
var bits uint32
h, bits, err = loadECDSA(curve)
h, err = loadECDSA(curve)
if err != nil {
return
}
var hkey bcrypt.KEY_HANDLE
err = bcrypt.GenerateKeyPair(h.handle, &hkey, bits, 0)
err = bcrypt.GenerateKeyPair(h.handle, &hkey, h.bits, 0)
if err != nil {
return
}
Expand Down Expand Up @@ -87,11 +81,11 @@ type PublicKeyECDSA struct {
}

func NewPublicKeyECDSA(curve string, X, Y BigInt) (*PublicKeyECDSA, error) {
h, bits, err := loadECDSA(curve)
h, err := loadECDSA(curve)
if err != nil {
return nil, err
}
hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, bits, X, Y, nil)
hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, h.bits, X, Y, nil)
if err != nil {
return nil, err
}
Expand All @@ -109,11 +103,11 @@ type PrivateKeyECDSA struct {
}

func NewPrivateKeyECDSA(curve string, X, Y, D BigInt) (*PrivateKeyECDSA, error) {
h, bits, err := loadECDSA(curve)
h, err := loadECDSA(curve)
if err != nil {
return nil, err
}
hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, bits, X, Y, D)
hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, h.bits, X, Y, D)
if err != nil {
return nil, err
}
Expand Down
6 changes: 1 addition & 5 deletions cng/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ type hashAlgorithm struct {
}

func loadHash(id string, flags bcrypt.AlgorithmProviderFlags) (*hashAlgorithm, error) {
v, err := loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (*hashAlgorithm, error) {
size, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_LENGTH)
if err != nil {
return nil, err
Expand All @@ -168,10 +168,6 @@ func loadHash(id string, flags bcrypt.AlgorithmProviderFlags) (*hashAlgorithm, e
}
return &hashAlgorithm{h, id, size, blockSize}, nil
})
if err != nil {
return nil, err
}
return v.(*hashAlgorithm), nil
}

// hashToID converts a hash.Hash implementation from this package
Expand Down
6 changes: 1 addition & 5 deletions cng/hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,9 @@ func SupportsHKDF() bool {
}

func loadHKDF() (bcrypt.ALG_HANDLE, error) {
h, err := loadOrStoreAlg(bcrypt.HKDF_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(bcrypt.HKDF_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) {
return h, nil
})
if err != nil {
return 0, err
}
return h.(bcrypt.ALG_HANDLE), nil
}

type hkdf struct {
Expand Down
6 changes: 1 addition & 5 deletions cng/pbkdf2.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@ import (
)

func loadPBKDF2() (bcrypt.ALG_HANDLE, error) {
h, err := loadOrStoreAlg(bcrypt.PBKDF2_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(bcrypt.PBKDF2_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) {
return h, nil
})
if err != nil {
return 0, err
}
return h.(bcrypt.ALG_HANDLE), nil
}

func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) ([]byte, error) {
Expand Down
8 changes: 2 additions & 6 deletions cng/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,13 @@ type rsaAlgorithm struct {
}

func loadRsa() (rsaAlgorithm, error) {
v, err := loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (rsaAlgorithm, error) {
lengths, err := getKeyLengths(bcrypt.HANDLE(h))
if err != nil {
return nil, err
return rsaAlgorithm{}, err
}
return rsaAlgorithm{h, lengths}, nil
})
if err != nil {
return rsaAlgorithm{}, err
}
return v.(rsaAlgorithm), nil
}

func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) {
Expand Down
6 changes: 1 addition & 5 deletions cng/tls1prf.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@ import (
)

func loadTLS1PRF(id string) (bcrypt.ALG_HANDLE, error) {
h, err := loadOrStoreAlg(id, 0, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(id, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) {
return h, nil
})
if err != nil {
return 0, err
}
return h.(bcrypt.ALG_HANDLE), nil
}

// TLS1PRF implements the TLS 1.0/1.1 pseudo-random function if h is nil,
Expand Down

0 comments on commit 1bb252f

Please sign in to comment.