From 4acdfbd3e8f01ecf631d26c4fcd18b7a9f70d3b9 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Thu, 24 Oct 2024 19:18:20 -0400 Subject: [PATCH] Move the scrypt scaffholding code to Rust (#11818) --- .../hazmat/backends/openssl/backend.py | 2 +- .../hazmat/bindings/_rust/openssl/kdf.pyi | 24 ++- .../hazmat/primitives/kdf/scrypt.py | 67 +------- src/rust/src/backend/kdf.rs | 161 +++++++++++++++--- src/rust/src/exceptions.rs | 1 + 5 files changed, 157 insertions(+), 98 deletions(-) diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index d31b039add0e..9a3dc2108701 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -120,7 +120,7 @@ def scrypt_supported(self) -> bool: if self._fips_enabled: return False else: - return hasattr(rust_openssl.kdf, "derive_scrypt") + return hasattr(rust_openssl.kdf.Scrypt, "derive") def hmac_supported(self, algorithm: hashes.HashAlgorithm) -> bool: # FIPS mode still allows SHA1 for HMAC diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi index 034a8fed2e78..01f7d606e8cc 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi @@ -2,6 +2,8 @@ # 2.0, and the BSD License. See the LICENSE file in the root of this repository # for complete details. +import typing + from cryptography.hazmat.primitives.hashes import HashAlgorithm def derive_pbkdf2_hmac( @@ -11,12 +13,16 @@ def derive_pbkdf2_hmac( iterations: int, length: int, ) -> bytes: ... -def derive_scrypt( - key_material: bytes, - salt: bytes, - n: int, - r: int, - p: int, - max_mem: int, - length: int, -) -> bytes: ... + +class Scrypt: + def __init__( + self, + salt: bytes, + length: int, + n: int, + r: int, + p: int, + backend: typing.Any = None, + ) -> None: ... + def derive(self, key_material: bytes) -> bytes: ... + def verify(self, key_material: bytes, expected_key: bytes) -> None: ... diff --git a/src/cryptography/hazmat/primitives/kdf/scrypt.py b/src/cryptography/hazmat/primitives/kdf/scrypt.py index 05a4f675b6ab..43a7704d48e3 100644 --- a/src/cryptography/hazmat/primitives/kdf/scrypt.py +++ b/src/cryptography/hazmat/primitives/kdf/scrypt.py @@ -5,76 +5,13 @@ from __future__ import annotations import sys -import typing -from cryptography import utils -from cryptography.exceptions import ( - AlreadyFinalized, - InvalidKey, - UnsupportedAlgorithm, -) from cryptography.hazmat.bindings._rust import openssl as rust_openssl -from cryptography.hazmat.primitives import constant_time from cryptography.hazmat.primitives.kdf import KeyDerivationFunction # This is used by the scrypt tests to skip tests that require more memory # than the MEM_LIMIT _MEM_LIMIT = sys.maxsize // 2 - -class Scrypt(KeyDerivationFunction): - def __init__( - self, - salt: bytes, - length: int, - n: int, - r: int, - p: int, - backend: typing.Any = None, - ): - from cryptography.hazmat.backends.openssl.backend import ( - backend as ossl, - ) - - if not ossl.scrypt_supported(): - raise UnsupportedAlgorithm( - "This version of OpenSSL does not support scrypt" - ) - self._length = length - utils._check_bytes("salt", salt) - if n < 2 or (n & (n - 1)) != 0: - raise ValueError("n must be greater than 1 and be a power of 2.") - - if r < 1: - raise ValueError("r must be greater than or equal to 1.") - - if p < 1: - raise ValueError("p must be greater than or equal to 1.") - - self._used = False - self._salt = salt - self._n = n - self._r = r - self._p = p - - def derive(self, key_material: bytes) -> bytes: - if self._used: - raise AlreadyFinalized("Scrypt instances can only be used once.") - self._used = True - - utils._check_byteslike("key_material", key_material) - - return rust_openssl.kdf.derive_scrypt( - key_material, - self._salt, - self._n, - self._r, - self._p, - _MEM_LIMIT, - self._length, - ) - - def verify(self, key_material: bytes, expected_key: bytes) -> None: - derived_key = self.derive(key_material) - if not constant_time.bytes_eq(derived_key, expected_key): - raise InvalidKey("Keys do not match.") +Scrypt = rust_openssl.kdf.Scrypt +KeyDerivationFunction.register(Scrypt) diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index 8c6a151a17d0..2292c08af5e2 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -2,9 +2,13 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. +#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] +use pyo3::types::PyBytesMethods; + use crate::backend::hashes; use crate::buf::CffiBuf; -use crate::error::CryptographyResult; +use crate::error::{CryptographyError, CryptographyResult}; +use crate::exceptions; #[pyo3::pyfunction] pub(crate) fn derive_pbkdf2_hmac<'p>( @@ -23,36 +27,147 @@ pub(crate) fn derive_pbkdf2_hmac<'p>( })?) } -#[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] -#[pyo3::pyfunction] -#[allow(clippy::too_many_arguments)] -fn derive_scrypt<'p>( - py: pyo3::Python<'p>, - key_material: CffiBuf<'_>, - salt: &[u8], +#[pyo3::pyclass(module = "cryptography.hazmat.primitives.kdf.scrypt")] +struct Scrypt { + #[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] + salt: pyo3::Py, + #[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] + length: usize, + #[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] n: u64, + #[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] r: u64, + #[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] p: u64, - max_mem: u64, - length: usize, -) -> CryptographyResult> { - Ok(pyo3::types::PyBytes::new_bound_with(py, length, |b| { - openssl::pkcs5::scrypt(key_material.as_bytes(), salt, n, r, p, max_mem, b).map_err(|_| { - // memory required formula explained here: - // https://blog.filippo.io/the-scrypt-parameters/ - let min_memory = 128 * n * r / (1024 * 1024); - pyo3::exceptions::PyMemoryError::new_err(format!( - "Not enough memory to derive key. These parameters require {min_memory}MB of memory." - )) - }) - })?) + + #[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] + used: bool, +} + +#[pyo3::pymethods] +impl Scrypt { + #[new] + #[pyo3(signature = (salt, length, n, r, p, backend=None))] + fn new( + salt: pyo3::Py, + length: usize, + n: u64, + r: u64, + p: u64, + backend: Option>, + ) -> CryptographyResult { + _ = backend; + + cfg_if::cfg_if! { + if #[cfg(CRYPTOGRAPHY_IS_LIBRESSL)] { + _ = salt; + _ = length; + _ = n; + _ = r; + _ = p; + + Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err( + "This version of OpenSSL does not support scrypt" + ), + )) + } else { + if cryptography_openssl::fips::is_enabled() { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err( + "This version of OpenSSL does not support scrypt" + ), + )); + } + + if n < 2 || (n & (n - 1)) != 0 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "n must be greater than 1 and be a power of 2." + ), + )); + } + if r < 1 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "r must be greater than or equal to 1." + ), + )); + } + if p < 1 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "p must be greater than or equal to 1." + ), + )); + } + + Ok(Scrypt{ + salt, + length, + n, + r, + p, + used: false, + }) + } + } + } + + #[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] + fn derive<'p>( + &mut self, + py: pyo3::Python<'p>, + key_material: CffiBuf<'_>, + ) -> CryptographyResult> { + if self.used { + return Err(exceptions::already_finalized_error()); + } + self.used = true; + + Ok(pyo3::types::PyBytes::new_bound_with( + py, + self.length, + |b| { + openssl::pkcs5::scrypt(key_material.as_bytes(), self.salt.as_bytes(py), self.n, self.r, self.p, (usize::MAX / 2).try_into().unwrap(), b).map_err(|_| { + // memory required formula explained here: + // https://blog.filippo.io/the-scrypt-parameters/ + let min_memory = 128 * self.n * self.r / (1024 * 1024); + pyo3::exceptions::PyMemoryError::new_err(format!( + "Not enough memory to derive key. These parameters require {min_memory}MB of memory." + )) + }) + }, + )?) + } + + #[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] + fn verify( + &mut self, + py: pyo3::Python<'_>, + key_material: CffiBuf<'_>, + expected_key: CffiBuf<'_>, + ) -> CryptographyResult<()> { + let actual = self.derive(py, key_material)?; + let actual_bytes = actual.as_bytes(); + let expected_bytes = expected_key.as_bytes(); + + if actual_bytes.len() != expected_bytes.len() + || !openssl::memcmp::eq(actual_bytes, expected_bytes) + { + return Err(CryptographyError::from(exceptions::InvalidKey::new_err( + "Keys do not match.", + ))); + } + + Ok(()) + } } #[pyo3::pymodule] pub(crate) mod kdf { #[pymodule_export] use super::derive_pbkdf2_hmac; - #[cfg(not(CRYPTOGRAPHY_IS_LIBRESSL))] #[pymodule_export] - use super::derive_scrypt; + use super::Scrypt; } diff --git a/src/rust/src/exceptions.rs b/src/rust/src/exceptions.rs index 5e0a44f8cc78..cfcedd2eb474 100644 --- a/src/rust/src/exceptions.rs +++ b/src/rust/src/exceptions.rs @@ -30,6 +30,7 @@ pub(crate) enum Reasons { pyo3::import_exception_bound!(cryptography.exceptions, AlreadyUpdated); pyo3::import_exception_bound!(cryptography.exceptions, AlreadyFinalized); pyo3::import_exception_bound!(cryptography.exceptions, InternalError); +pyo3::import_exception_bound!(cryptography.exceptions, InvalidKey); pyo3::import_exception_bound!(cryptography.exceptions, InvalidSignature); pyo3::import_exception_bound!(cryptography.exceptions, InvalidTag); pyo3::import_exception_bound!(cryptography.exceptions, NotYetFinalized);