Skip to content

Commit

Permalink
Move the scrypt scaffholding code to Rust (#11818)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex authored Oct 24, 2024
1 parent f6d9074 commit 4acdfbd
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 98 deletions.
2 changes: 1 addition & 1 deletion src/cryptography/hazmat/backends/openssl/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def scrypt_supported(self) -> bool:
if self._fips_enabled:
return False
else:
return hasattr(rust_openssl.kdf, "derive_scrypt")
return hasattr(rust_openssl.kdf.Scrypt, "derive")

def hmac_supported(self, algorithm: hashes.HashAlgorithm) -> bool:
# FIPS mode still allows SHA1 for HMAC
Expand Down
24 changes: 15 additions & 9 deletions src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

import typing

from cryptography.hazmat.primitives.hashes import HashAlgorithm

def derive_pbkdf2_hmac(
Expand All @@ -11,12 +13,16 @@ def derive_pbkdf2_hmac(
iterations: int,
length: int,
) -> bytes: ...
def derive_scrypt(
key_material: bytes,
salt: bytes,
n: int,
r: int,
p: int,
max_mem: int,
length: int,
) -> bytes: ...

class Scrypt:
def __init__(
self,
salt: bytes,
length: int,
n: int,
r: int,
p: int,
backend: typing.Any = None,
) -> None: ...
def derive(self, key_material: bytes) -> bytes: ...
def verify(self, key_material: bytes, expected_key: bytes) -> None: ...
67 changes: 2 additions & 65 deletions src/cryptography/hazmat/primitives/kdf/scrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,76 +5,13 @@
from __future__ import annotations

import sys
import typing

from cryptography import utils
from cryptography.exceptions import (
AlreadyFinalized,
InvalidKey,
UnsupportedAlgorithm,
)
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import constant_time
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction

# This is used by the scrypt tests to skip tests that require more memory
# than the MEM_LIMIT
_MEM_LIMIT = sys.maxsize // 2


class Scrypt(KeyDerivationFunction):
def __init__(
self,
salt: bytes,
length: int,
n: int,
r: int,
p: int,
backend: typing.Any = None,
):
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)

if not ossl.scrypt_supported():
raise UnsupportedAlgorithm(
"This version of OpenSSL does not support scrypt"
)
self._length = length
utils._check_bytes("salt", salt)
if n < 2 or (n & (n - 1)) != 0:
raise ValueError("n must be greater than 1 and be a power of 2.")

if r < 1:
raise ValueError("r must be greater than or equal to 1.")

if p < 1:
raise ValueError("p must be greater than or equal to 1.")

self._used = False
self._salt = salt
self._n = n
self._r = r
self._p = p

def derive(self, key_material: bytes) -> bytes:
if self._used:
raise AlreadyFinalized("Scrypt instances can only be used once.")
self._used = True

utils._check_byteslike("key_material", key_material)

return rust_openssl.kdf.derive_scrypt(
key_material,
self._salt,
self._n,
self._r,
self._p,
_MEM_LIMIT,
self._length,
)

def verify(self, key_material: bytes, expected_key: bytes) -> None:
derived_key = self.derive(key_material)
if not constant_time.bytes_eq(derived_key, expected_key):
raise InvalidKey("Keys do not match.")
Scrypt = rust_openssl.kdf.Scrypt
KeyDerivationFunction.register(Scrypt)
161 changes: 138 additions & 23 deletions src/rust/src/backend/kdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
// 2.0, and the BSD License. See the LICENSE file in the root of this repository
// for complete details.

#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
use pyo3::types::PyBytesMethods;

use crate::backend::hashes;
use crate::buf::CffiBuf;
use crate::error::CryptographyResult;
use crate::error::{CryptographyError, CryptographyResult};
use crate::exceptions;

#[pyo3::pyfunction]
pub(crate) fn derive_pbkdf2_hmac<'p>(
Expand All @@ -23,36 +27,147 @@ pub(crate) fn derive_pbkdf2_hmac<'p>(
})?)
}

#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
#[pyo3::pyfunction]
#[allow(clippy::too_many_arguments)]
fn derive_scrypt<'p>(
py: pyo3::Python<'p>,
key_material: CffiBuf<'_>,
salt: &[u8],
#[pyo3::pyclass(module = "cryptography.hazmat.primitives.kdf.scrypt")]
struct Scrypt {
#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
salt: pyo3::Py<pyo3::types::PyBytes>,
#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
length: usize,
#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
n: u64,
#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
r: u64,
#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
p: u64,
max_mem: u64,
length: usize,
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
Ok(pyo3::types::PyBytes::new_bound_with(py, length, |b| {
openssl::pkcs5::scrypt(key_material.as_bytes(), salt, n, r, p, max_mem, b).map_err(|_| {
// memory required formula explained here:
// https://blog.filippo.io/the-scrypt-parameters/
let min_memory = 128 * n * r / (1024 * 1024);
pyo3::exceptions::PyMemoryError::new_err(format!(
"Not enough memory to derive key. These parameters require {min_memory}MB of memory."
))
})
})?)

#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
used: bool,
}

#[pyo3::pymethods]
impl Scrypt {
#[new]
#[pyo3(signature = (salt, length, n, r, p, backend=None))]
fn new(
salt: pyo3::Py<pyo3::types::PyBytes>,
length: usize,
n: u64,
r: u64,
p: u64,
backend: Option<pyo3::Bound<'_, pyo3::PyAny>>,
) -> CryptographyResult<Self> {
_ = backend;

cfg_if::cfg_if! {
if #[cfg(CRYPTOGRAPHY_IS_LIBRESSL)] {
_ = salt;
_ = length;
_ = n;
_ = r;
_ = p;

Err(CryptographyError::from(
exceptions::UnsupportedAlgorithm::new_err(
"This version of OpenSSL does not support scrypt"
),
))
} else {
if cryptography_openssl::fips::is_enabled() {
return Err(CryptographyError::from(
exceptions::UnsupportedAlgorithm::new_err(
"This version of OpenSSL does not support scrypt"
),
));
}

if n < 2 || (n & (n - 1)) != 0 {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err(
"n must be greater than 1 and be a power of 2."
),
));
}
if r < 1 {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err(
"r must be greater than or equal to 1."
),
));
}
if p < 1 {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err(
"p must be greater than or equal to 1."
),
));
}

Ok(Scrypt{
salt,
length,
n,
r,
p,
used: false,
})
}
}
}

#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
fn derive<'p>(
&mut self,
py: pyo3::Python<'p>,
key_material: CffiBuf<'_>,
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
if self.used {
return Err(exceptions::already_finalized_error());
}
self.used = true;

Ok(pyo3::types::PyBytes::new_bound_with(
py,
self.length,
|b| {
openssl::pkcs5::scrypt(key_material.as_bytes(), self.salt.as_bytes(py), self.n, self.r, self.p, (usize::MAX / 2).try_into().unwrap(), b).map_err(|_| {
// memory required formula explained here:
// https://blog.filippo.io/the-scrypt-parameters/
let min_memory = 128 * self.n * self.r / (1024 * 1024);
pyo3::exceptions::PyMemoryError::new_err(format!(
"Not enough memory to derive key. These parameters require {min_memory}MB of memory."
))
})
},
)?)
}

#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
fn verify(
&mut self,
py: pyo3::Python<'_>,
key_material: CffiBuf<'_>,
expected_key: CffiBuf<'_>,
) -> CryptographyResult<()> {
let actual = self.derive(py, key_material)?;
let actual_bytes = actual.as_bytes();
let expected_bytes = expected_key.as_bytes();

if actual_bytes.len() != expected_bytes.len()
|| !openssl::memcmp::eq(actual_bytes, expected_bytes)
{
return Err(CryptographyError::from(exceptions::InvalidKey::new_err(
"Keys do not match.",
)));
}

Ok(())
}
}

#[pyo3::pymodule]
pub(crate) mod kdf {
#[pymodule_export]
use super::derive_pbkdf2_hmac;
#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))]
#[pymodule_export]
use super::derive_scrypt;
use super::Scrypt;
}
1 change: 1 addition & 0 deletions src/rust/src/exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub(crate) enum Reasons {
pyo3::import_exception_bound!(cryptography.exceptions, AlreadyUpdated);
pyo3::import_exception_bound!(cryptography.exceptions, AlreadyFinalized);
pyo3::import_exception_bound!(cryptography.exceptions, InternalError);
pyo3::import_exception_bound!(cryptography.exceptions, InvalidKey);
pyo3::import_exception_bound!(cryptography.exceptions, InvalidSignature);
pyo3::import_exception_bound!(cryptography.exceptions, InvalidTag);
pyo3::import_exception_bound!(cryptography.exceptions, NotYetFinalized);
Expand Down

0 comments on commit 4acdfbd

Please sign in to comment.