Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update loadOrStoreAlg to accept type parameters #63

Merged
merged 4 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions cng/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,21 @@ type cipherAlgorithm struct {
}

func loadCipher(id, mode string) (cipherAlgorithm, error) {
v, err := loadOrStoreAlg(id, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(id, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (cipherAlgorithm, error) {
if mode != "" {
// Windows 8 added support to set the CipherMode value on a key,
// but Windows 7 requires that it be set on the algorithm before key creation.
err := setString(bcrypt.HANDLE(h), bcrypt.CHAINING_MODE, mode)
if err != nil {
return nil, err
return cipherAlgorithm{}, err
}
}
lengths, err := getKeyLengths(bcrypt.HANDLE(h))
if err != nil {
return nil, err
return cipherAlgorithm{}, err
}
return cipherAlgorithm{h, lengths}, nil
})
if err != nil {
return cipherAlgorithm{}, err
}
return v.(cipherAlgorithm), nil
}

func newCipherHandle(id, mode string, key []byte) (bcrypt.KEY_HANDLE, error) {
Expand Down
17 changes: 10 additions & 7 deletions cng/cng.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,35 @@ func len32(s []byte) int {

var algCache sync.Map

type newAlgEntryFn func(h bcrypt.ALG_HANDLE) (interface{}, error)

func loadOrStoreAlg(id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn newAlgEntryFn) (interface{}, error) {
// loadOrStoreAlg loads an algorithm with the given id, flags, and mode from the cache.
// If the algorithm is not in the cache, a new one is created and then initialized using fn.
// The returned algorithm handle should not be closed by the caller.
func loadOrStoreAlg[T any](id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn func(h bcrypt.ALG_HANDLE) (T, error)) (T, error) {
karianna marked this conversation as resolved.
Show resolved Hide resolved
var entryKey = struct {
id string
flags bcrypt.AlgorithmProviderFlags
mode string
}{id, flags, mode}

if v, ok := algCache.Load(entryKey); ok {
return v, nil
return v.(T), nil
}
var h bcrypt.ALG_HANDLE
err := bcrypt.OpenAlgorithmProvider(&h, utf16PtrFromString(id), nil, flags)
if err != nil {
return nil, err
var t T
return t, err
qmuntal marked this conversation as resolved.
Show resolved Hide resolved
}
v, err := fn(h)
if err != nil {
bcrypt.CloseAlgorithmProvider(h, 0)
return nil, err
var t T
return t, err
}
if existing, loaded := algCache.LoadOrStore(entryKey, v); loaded {
// We can safely use a provider that has already been cached in another concurrent goroutine.
bcrypt.CloseAlgorithmProvider(h, 0)
v = existing
v = existing.(T)
}
return v, nil
}
Expand Down
8 changes: 2 additions & 6 deletions cng/dsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,13 @@ type dsaAlgorithm struct {
}

func loadDSA() (h dsaAlgorithm, err error) {
v, err := loadOrStoreAlg(bcrypt.DSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(bcrypt.DSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (dsaAlgorithm, error) {
lengths, err := getKeyLengths(bcrypt.HANDLE(h))
if err != nil {
return nil, err
return dsaAlgorithm{}, err
}
return dsaAlgorithm{h, lengths}, nil
})
if err != nil {
return dsaAlgorithm{}, err
}
return v.(dsaAlgorithm), nil
}

// DSAParameters contains the DSA parameters.
Expand Down
61 changes: 28 additions & 33 deletions cng/ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,31 @@ var errInvalidPrivateKey = errors.New("cng: invalid private key")

type ecdhAlgorithm struct {
handle bcrypt.ALG_HANDLE
bits uint32
}

func loadECDH(curve string) (h ecdhAlgorithm, bits uint32, err error) {
var id string
switch curve {
case "P-256":
id, bits = bcrypt.ECC_CURVE_NISTP256, 256
case "P-384":
id, bits = bcrypt.ECC_CURVE_NISTP384, 384
case "P-521":
id, bits = bcrypt.ECC_CURVE_NISTP521, 521
case "X25519":
id, bits = bcrypt.ECC_CURVE_25519, 255
default:
err = errUnknownCurve
}
if err != nil {
return
}
v, err := loadOrStoreAlg(bcrypt.ECDH_ALGORITHM, bcrypt.ALG_NONE_FLAG, id, func(h bcrypt.ALG_HANDLE) (interface{}, error) {
err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, id)
func loadECDH(curve string) (ecdhAlgorithm, error) {
return loadOrStoreAlg(bcrypt.ECDH_ALGORITHM, bcrypt.ALG_NONE_FLAG, curve, func(h bcrypt.ALG_HANDLE) (ecdhAlgorithm, error) {
var name string
var bits uint32
switch curve {
case "P-256":
karianna marked this conversation as resolved.
Show resolved Hide resolved
name, bits = bcrypt.ECC_CURVE_NISTP256, 256
case "P-384":
name, bits = bcrypt.ECC_CURVE_NISTP384, 384
case "P-521":
name, bits = bcrypt.ECC_CURVE_NISTP521, 521
case "X25519":
name, bits = bcrypt.ECC_CURVE_25519, 255
default:
return ecdhAlgorithm{}, errUnknownCurve
}
err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, name)
if err != nil {
return nil, err
return ecdhAlgorithm{}, err
}
return ecdhAlgorithm{h}, nil
return ecdhAlgorithm{h, bits}, nil
})
if err != nil {
return ecdhAlgorithm{}, 0, err
}
return v.(ecdhAlgorithm), bits, nil
}

type PublicKeyECDH struct {
Expand Down Expand Up @@ -115,12 +110,12 @@ func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
}

func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
h, bits, err := loadECDH(curve)
h, err := loadECDH(curve)
if err != nil {
return nil, nil, err
}
var hkey bcrypt.KEY_HANDLE
err = bcrypt.GenerateKeyPair(h.handle, &hkey, bits, 0)
err = bcrypt.GenerateKeyPair(h.handle, &hkey, h.bits, 0)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -154,7 +149,7 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
if len(bytes) == 0 || (nist && bytes[0] != ecdhUncompressedPrefix) {
return nil, errInvalidPublicKey
}
h, bits, err := loadECDH(curve)
h, err := loadECDH(curve)
if err != nil {
return nil, err
}
Expand All @@ -169,11 +164,11 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
ncomponents = 1
keyWithoutEncoding = bytes
}
keySize := int(bits+7) / 8
keySize := int(h.bits+7) / 8
if len(keyWithoutEncoding) != keySize*ncomponents {
return nil, errInvalidPublicKey
}
hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, bits, keyWithoutEncoding[:keySize], keyWithoutEncoding[keySize:], nil)
hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, h.bits, keyWithoutEncoding[:keySize], keyWithoutEncoding[keySize:], nil)
if err != nil {
return nil, err
}
Expand All @@ -185,11 +180,11 @@ func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
func (k *PublicKeyECDH) Bytes() []byte { return k.bytes }

func NewPrivateKeyECDH(curve string, key []byte) (*PrivateKeyECDH, error) {
h, bits, err := loadECDH(curve)
h, err := loadECDH(curve)
if err != nil {
return nil, err
}
keySize := int(bits+7) / 8
keySize := int(h.bits+7) / 8
if len(key) != keySize {
return nil, errInvalidPrivateKey
}
Expand All @@ -202,7 +197,7 @@ func NewPrivateKeyECDH(curve string, key []byte) (*PrivateKeyECDH, error) {
// To trigger this behavior we pass a zeroed X/Y with keySize length.
// zero is big enough to fit P-521 curves, the largest we handle, in the stack.
var zero [(521 + 7) / 8]byte
hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, bits, zero[:keySize], zero[:keySize], key)
hkey, err := importECCKey(h.handle, bcrypt.ECDH_ALGORITHM, h.bits, zero[:keySize], zero[:keySize], key)
if err != nil {
return nil, err
}
Expand Down
58 changes: 26 additions & 32 deletions cng/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,41 @@ var errUnknownCurve = errors.New("cng: unknown elliptic curve")

type ecdsaAlgorithm struct {
handle bcrypt.ALG_HANDLE
bits uint32
}

func loadECDSA(curve string) (h ecdsaAlgorithm, bits uint32, err error) {
var id string
switch curve {
case "P-224":
id, bits = bcrypt.ECC_CURVE_NISTP224, 224
case "P-256":
id, bits = bcrypt.ECC_CURVE_NISTP256, 256
case "P-384":
id, bits = bcrypt.ECC_CURVE_NISTP384, 384
case "P-521":
id, bits = bcrypt.ECC_CURVE_NISTP521, 521
default:
err = errUnknownCurve
}
if err != nil {
return
}
v, err := loadOrStoreAlg(bcrypt.ECDSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, id, func(h bcrypt.ALG_HANDLE) (interface{}, error) {
err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, id)
func loadECDSA(curve string) (ecdsaAlgorithm, error) {
return loadOrStoreAlg(bcrypt.ECDSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, curve, func(h bcrypt.ALG_HANDLE) (ecdsaAlgorithm, error) {
var name string
var bits uint32
switch curve {
case "P-224":
name, bits = bcrypt.ECC_CURVE_NISTP224, 224
case "P-256":
name, bits = bcrypt.ECC_CURVE_NISTP256, 256
case "P-384":
name, bits = bcrypt.ECC_CURVE_NISTP384, 384
case "P-521":
name, bits = bcrypt.ECC_CURVE_NISTP521, 521
default:
return ecdsaAlgorithm{}, errUnknownCurve
}
err := setString(bcrypt.HANDLE(h), bcrypt.ECC_CURVE_NAME, name)
if err != nil {
return nil, err
return ecdsaAlgorithm{}, err
}
return ecdsaAlgorithm{h}, nil
return ecdsaAlgorithm{h, bits}, nil
})
if err != nil {
return ecdsaAlgorithm{}, 0, err
}
return v.(ecdsaAlgorithm), bits, nil
}

func GenerateKeyECDSA(curve string) (X, Y, D BigInt, err error) {
var h ecdsaAlgorithm
var bits uint32
h, bits, err = loadECDSA(curve)
h, err = loadECDSA(curve)
if err != nil {
return
}
var hkey bcrypt.KEY_HANDLE
err = bcrypt.GenerateKeyPair(h.handle, &hkey, bits, 0)
err = bcrypt.GenerateKeyPair(h.handle, &hkey, h.bits, 0)
if err != nil {
return
}
Expand Down Expand Up @@ -87,11 +81,11 @@ type PublicKeyECDSA struct {
}

func NewPublicKeyECDSA(curve string, X, Y BigInt) (*PublicKeyECDSA, error) {
h, bits, err := loadECDSA(curve)
h, err := loadECDSA(curve)
if err != nil {
return nil, err
}
hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, bits, X, Y, nil)
hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, h.bits, X, Y, nil)
if err != nil {
return nil, err
}
Expand All @@ -109,11 +103,11 @@ type PrivateKeyECDSA struct {
}

func NewPrivateKeyECDSA(curve string, X, Y, D BigInt) (*PrivateKeyECDSA, error) {
h, bits, err := loadECDSA(curve)
h, err := loadECDSA(curve)
if err != nil {
return nil, err
}
hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, bits, X, Y, D)
hkey, err := importECCKey(h.handle, bcrypt.ECDSA_ALGORITHM, h.bits, X, Y, D)
if err != nil {
return nil, err
}
Expand Down
6 changes: 1 addition & 5 deletions cng/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ type hashAlgorithm struct {
}

func loadHash(id string, flags bcrypt.AlgorithmProviderFlags) (*hashAlgorithm, error) {
v, err := loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (*hashAlgorithm, error) {
size, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_LENGTH)
if err != nil {
return nil, err
Expand All @@ -168,10 +168,6 @@ func loadHash(id string, flags bcrypt.AlgorithmProviderFlags) (*hashAlgorithm, e
}
return &hashAlgorithm{h, id, size, blockSize}, nil
})
if err != nil {
return nil, err
}
return v.(*hashAlgorithm), nil
}

// hashToID converts a hash.Hash implementation from this package
Expand Down
6 changes: 1 addition & 5 deletions cng/hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,9 @@ func SupportsHKDF() bool {
}

func loadHKDF() (bcrypt.ALG_HANDLE, error) {
h, err := loadOrStoreAlg(bcrypt.HKDF_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(bcrypt.HKDF_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) {
return h, nil
})
if err != nil {
return 0, err
}
return h.(bcrypt.ALG_HANDLE), nil
}

type hkdf struct {
Expand Down
6 changes: 1 addition & 5 deletions cng/pbkdf2.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@ import (
)

func loadPBKDF2() (bcrypt.ALG_HANDLE, error) {
h, err := loadOrStoreAlg(bcrypt.PBKDF2_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(bcrypt.PBKDF2_ALGORITHM, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) {
return h, nil
})
if err != nil {
return 0, err
}
return h.(bcrypt.ALG_HANDLE), nil
}

func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) ([]byte, error) {
Expand Down
8 changes: 2 additions & 6 deletions cng/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,13 @@ type rsaAlgorithm struct {
}

func loadRsa() (rsaAlgorithm, error) {
v, err := loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (rsaAlgorithm, error) {
lengths, err := getKeyLengths(bcrypt.HANDLE(h))
if err != nil {
return nil, err
return rsaAlgorithm{}, err
}
return rsaAlgorithm{h, lengths}, nil
})
if err != nil {
return rsaAlgorithm{}, err
}
return v.(rsaAlgorithm), nil
}

func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) {
Expand Down
6 changes: 1 addition & 5 deletions cng/tls1prf.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@ import (
)

func loadTLS1PRF(id string) (bcrypt.ALG_HANDLE, error) {
h, err := loadOrStoreAlg(id, 0, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return loadOrStoreAlg(id, 0, "", func(h bcrypt.ALG_HANDLE) (bcrypt.ALG_HANDLE, error) {
return h, nil
})
if err != nil {
return 0, err
}
return h.(bcrypt.ALG_HANDLE), nil
}

// TLS1PRF implements the TLS 1.0/1.1 pseudo-random function if h is nil,
Expand Down
Loading