diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index 0e14bfb4e2b18..7ff208e607577 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -12,7 +12,6 @@ from cryptography import utils, x509 from cryptography.exceptions import UnsupportedAlgorithm, _Reasons from cryptography.hazmat.backends.openssl import aead -from cryptography.hazmat.backends.openssl.ciphers import _CipherContext from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.bindings.openssl import binding from cryptography.hazmat.primitives import hashes, serialization @@ -142,12 +141,8 @@ def __repr__(self) -> str: self._binding._legacy_provider_loaded, ) - def openssl_assert( - self, - ok: bool, - errors: list[rust_openssl.OpenSSLError] | None = None, - ) -> None: - return binding._openssl_assert(ok, errors=errors) + def openssl_assert(self, ok: bool) -> None: + return binding._openssl_assert(ok) def _enable_fips(self) -> None: # This function enables FIPS mode for OpenSSL 3.0.0 on installs that @@ -310,16 +305,6 @@ def _register_default_ciphers(self) -> None: _RC2, type(None), GetCipherByName("rc2") ) - def create_symmetric_encryption_ctx( - self, cipher: CipherAlgorithm, mode: Mode - ) -> _CipherContext: - return _CipherContext(self, cipher, mode, _CipherContext._ENCRYPT) - - def create_symmetric_decryption_ctx( - self, cipher: CipherAlgorithm, mode: Mode - ) -> _CipherContext: - return _CipherContext(self, cipher, mode, _CipherContext._DECRYPT) - def pbkdf2_hmac_supported(self, algorithm: hashes.HashAlgorithm) -> bool: return self.hmac_supported(algorithm) diff --git a/src/cryptography/hazmat/backends/openssl/ciphers.py b/src/cryptography/hazmat/backends/openssl/ciphers.py deleted file mode 100644 index 3916b1a510ad6..0000000000000 --- a/src/cryptography/hazmat/backends/openssl/ciphers.py +++ /dev/null @@ -1,282 +0,0 @@ -# This file is dual licensed under the terms of the Apache License, Version -# 2.0, and the BSD License. See the LICENSE file in the root of this repository -# for complete details. - -from __future__ import annotations - -import typing - -from cryptography.exceptions import InvalidTag, UnsupportedAlgorithm, _Reasons -from cryptography.hazmat.primitives import ciphers -from cryptography.hazmat.primitives.ciphers import algorithms, modes - -if typing.TYPE_CHECKING: - from cryptography.hazmat.backends.openssl.backend import Backend - - -class _CipherContext: - _ENCRYPT = 1 - _DECRYPT = 0 - _MAX_CHUNK_SIZE = 2**29 - - def __init__(self, backend: Backend, cipher, mode, operation: int) -> None: - self._backend = backend - self._cipher = cipher - self._mode = mode - self._operation = operation - self._tag: bytes | None = None - - if isinstance(self._cipher, ciphers.BlockCipherAlgorithm): - self._block_size_bytes = self._cipher.block_size // 8 - else: - self._block_size_bytes = 1 - - ctx = self._backend._lib.EVP_CIPHER_CTX_new() - ctx = self._backend._ffi.gc( - ctx, self._backend._lib.EVP_CIPHER_CTX_free - ) - - registry = self._backend._cipher_registry - try: - adapter = registry[type(cipher), type(mode)] - except KeyError: - raise UnsupportedAlgorithm( - "cipher {} in {} mode is not supported " - "by this backend.".format( - cipher.name, mode.name if mode else mode - ), - _Reasons.UNSUPPORTED_CIPHER, - ) - - evp_cipher = adapter(self._backend, cipher, mode) - if evp_cipher == self._backend._ffi.NULL: - msg = f"cipher {cipher.name} " - if mode is not None: - msg += f"in {mode.name} mode " - msg += ( - "is not supported by this backend (Your version of OpenSSL " - "may be too old. Current version: {}.)" - ).format(self._backend.openssl_version_text()) - raise UnsupportedAlgorithm(msg, _Reasons.UNSUPPORTED_CIPHER) - - if isinstance(mode, modes.ModeWithInitializationVector): - iv_nonce = self._backend._ffi.from_buffer( - mode.initialization_vector - ) - elif isinstance(mode, modes.ModeWithTweak): - iv_nonce = self._backend._ffi.from_buffer(mode.tweak) - elif isinstance(mode, modes.ModeWithNonce): - iv_nonce = self._backend._ffi.from_buffer(mode.nonce) - elif isinstance(cipher, algorithms.ChaCha20): - iv_nonce = self._backend._ffi.from_buffer(cipher.nonce) - else: - iv_nonce = self._backend._ffi.NULL - # begin init with cipher and operation type - res = self._backend._lib.EVP_CipherInit_ex( - ctx, - evp_cipher, - self._backend._ffi.NULL, - self._backend._ffi.NULL, - self._backend._ffi.NULL, - operation, - ) - self._backend.openssl_assert(res != 0) - # set the key length to handle variable key ciphers - res = self._backend._lib.EVP_CIPHER_CTX_set_key_length( - ctx, len(cipher.key) - ) - self._backend.openssl_assert(res != 0) - if isinstance(mode, modes.GCM): - res = self._backend._lib.EVP_CIPHER_CTX_ctrl( - ctx, - self._backend._lib.EVP_CTRL_AEAD_SET_IVLEN, - len(iv_nonce), - self._backend._ffi.NULL, - ) - self._backend.openssl_assert(res != 0) - if mode.tag is not None: - res = self._backend._lib.EVP_CIPHER_CTX_ctrl( - ctx, - self._backend._lib.EVP_CTRL_AEAD_SET_TAG, - len(mode.tag), - mode.tag, - ) - self._backend.openssl_assert(res != 0) - self._tag = mode.tag - - # pass key/iv - res = self._backend._lib.EVP_CipherInit_ex( - ctx, - self._backend._ffi.NULL, - self._backend._ffi.NULL, - self._backend._ffi.from_buffer(cipher.key), - iv_nonce, - operation, - ) - - # Check for XTS mode duplicate keys error - errors = self._backend._consume_errors() - lib = self._backend._lib - if res == 0 and ( - ( - not lib.CRYPTOGRAPHY_IS_LIBRESSL - and errors[0]._lib_reason_match( - lib.ERR_LIB_EVP, lib.EVP_R_XTS_DUPLICATED_KEYS - ) - ) - or ( - lib.Cryptography_HAS_PROVIDERS - and errors[0]._lib_reason_match( - lib.ERR_LIB_PROV, lib.PROV_R_XTS_DUPLICATED_KEYS - ) - ) - ): - raise ValueError("In XTS mode duplicated keys are not allowed") - - self._backend.openssl_assert(res != 0, errors=errors) - - # We purposely disable padding here as it's handled higher up in the - # API. - self._backend._lib.EVP_CIPHER_CTX_set_padding(ctx, 0) - self._ctx = ctx - - def update(self, data: bytes) -> bytes: - buf = bytearray(len(data) + self._block_size_bytes - 1) - n = self.update_into(data, buf) - return bytes(buf[:n]) - - def update_into(self, data: bytes, buf: bytes) -> int: - total_data_len = len(data) - if len(buf) < (total_data_len + self._block_size_bytes - 1): - raise ValueError( - "buffer must be at least {} bytes for this payload".format( - len(data) + self._block_size_bytes - 1 - ) - ) - - data_processed = 0 - total_out = 0 - outlen = self._backend._ffi.new("int *") - baseoutbuf = self._backend._ffi.from_buffer(buf, require_writable=True) - baseinbuf = self._backend._ffi.from_buffer(data) - - while data_processed != total_data_len: - outbuf = baseoutbuf + total_out - inbuf = baseinbuf + data_processed - inlen = min(self._MAX_CHUNK_SIZE, total_data_len - data_processed) - - res = self._backend._lib.EVP_CipherUpdate( - self._ctx, outbuf, outlen, inbuf, inlen - ) - if res == 0 and isinstance(self._mode, modes.XTS): - self._backend._consume_errors() - raise ValueError( - "In XTS mode you must supply at least a full block in the " - "first update call. For AES this is 16 bytes." - ) - else: - self._backend.openssl_assert(res != 0) - data_processed += inlen - total_out += outlen[0] - - return total_out - - def finalize(self) -> bytes: - if ( - self._operation == self._DECRYPT - and isinstance(self._mode, modes.ModeWithAuthenticationTag) - and self.tag is None - ): - raise ValueError( - "Authentication tag must be provided when decrypting." - ) - - buf = self._backend._ffi.new("unsigned char[]", self._block_size_bytes) - outlen = self._backend._ffi.new("int *") - res = self._backend._lib.EVP_CipherFinal_ex(self._ctx, buf, outlen) - if res == 0: - errors = self._backend._consume_errors() - - if not errors and isinstance(self._mode, modes.GCM): - raise InvalidTag - - lib = self._backend._lib - self._backend.openssl_assert( - errors[0]._lib_reason_match( - lib.ERR_LIB_EVP, - lib.EVP_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH, - ) - or ( - lib.Cryptography_HAS_PROVIDERS - and errors[0]._lib_reason_match( - lib.ERR_LIB_PROV, - lib.PROV_R_WRONG_FINAL_BLOCK_LENGTH, - ) - ) - or ( - lib.CRYPTOGRAPHY_IS_BORINGSSL - and errors[0].reason - == lib.CIPHER_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH - ), - errors=errors, - ) - raise ValueError( - "The length of the provided data is not a multiple of " - "the block length." - ) - - if ( - isinstance(self._mode, modes.GCM) - and self._operation == self._ENCRYPT - ): - tag_buf = self._backend._ffi.new( - "unsigned char[]", self._block_size_bytes - ) - res = self._backend._lib.EVP_CIPHER_CTX_ctrl( - self._ctx, - self._backend._lib.EVP_CTRL_AEAD_GET_TAG, - self._block_size_bytes, - tag_buf, - ) - self._backend.openssl_assert(res != 0) - self._tag = self._backend._ffi.buffer(tag_buf)[:] - - res = self._backend._lib.EVP_CIPHER_CTX_reset(self._ctx) - self._backend.openssl_assert(res == 1) - return self._backend._ffi.buffer(buf)[: outlen[0]] - - def finalize_with_tag(self, tag: bytes) -> bytes: - tag_len = len(tag) - if tag_len < self._mode._min_tag_length: - raise ValueError( - "Authentication tag must be {} bytes or longer.".format( - self._mode._min_tag_length - ) - ) - elif tag_len > self._block_size_bytes: - raise ValueError( - "Authentication tag cannot be more than {} bytes.".format( - self._block_size_bytes - ) - ) - res = self._backend._lib.EVP_CIPHER_CTX_ctrl( - self._ctx, self._backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag - ) - self._backend.openssl_assert(res != 0) - self._tag = tag - return self.finalize() - - def authenticate_additional_data(self, data: bytes) -> None: - outlen = self._backend._ffi.new("int *") - res = self._backend._lib.EVP_CipherUpdate( - self._ctx, - self._backend._ffi.NULL, - outlen, - self._backend._ffi.from_buffer(data), - len(data), - ) - self._backend.openssl_assert(res != 0) - - @property - def tag(self) -> bytes | None: - return self._tag diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi index 9cdb4d6a5c6e4..5180d3422979c 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi @@ -6,6 +6,7 @@ import typing from cryptography.hazmat.bindings._rust.openssl import ( aead, + ciphers, cmac, dh, dsa, @@ -26,6 +27,7 @@ __all__ = [ "openssl_version", "raise_openssl_error", "aead", + "ciphers", "cmac", "dh", "dsa", diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/ciphers.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/ciphers.pyi new file mode 100644 index 0000000000000..a64d4c755abb6 --- /dev/null +++ b/src/cryptography/hazmat/bindings/_rust/openssl/ciphers.pyi @@ -0,0 +1,35 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives import ciphers +from cryptography.hazmat.primitives.ciphers import modes + +@typing.overload +def create_encryption_ctx( + algorithm: ciphers.CipherAlgorithm, mode: modes.ModeWithAuthenticationTag +) -> ciphers.AEADEncryptionContext: ... +@typing.overload +def create_encryption_ctx( + algorithm: ciphers.CipherAlgorithm, mode: modes.Mode +) -> ciphers.CipherContext: ... +@typing.overload +def create_decryption_ctx( + algorithm: ciphers.CipherAlgorithm, mode: modes.ModeWithAuthenticationTag +) -> ciphers.AEADDecryptionContext: ... +@typing.overload +def create_decryption_ctx( + algorithm: ciphers.CipherAlgorithm, mode: modes.Mode +) -> ciphers.CipherContext: ... +def _advance( + ctx: ciphers.AEADEncryptionContext | ciphers.AEADDecryptionContext, n: int +) -> None: ... +def _advance_aad( + ctx: ciphers.AEADEncryptionContext | ciphers.AEADDecryptionContext, n: int +) -> None: ... + +class CipherContext: ... +class AEADEncryptionContext: ... +class AEADDecryptionContext: ... diff --git a/src/cryptography/hazmat/bindings/openssl/binding.py b/src/cryptography/hazmat/bindings/openssl/binding.py index 40814f2a58a08..93c3acc833d18 100644 --- a/src/cryptography/hazmat/bindings/openssl/binding.py +++ b/src/cryptography/hazmat/bindings/openssl/binding.py @@ -17,13 +17,9 @@ from cryptography.hazmat.bindings.openssl._conditional import CONDITIONAL_NAMES -def _openssl_assert( - ok: bool, - errors: list[openssl.OpenSSLError] | None = None, -) -> None: +def _openssl_assert(ok: bool) -> None: if not ok: - if errors is None: - errors = openssl.capture_error_stack() + errors = openssl.capture_error_stack() raise InternalError( "Unknown OpenSSL error. This error is commonly encountered when " diff --git a/src/cryptography/hazmat/primitives/ciphers/base.py b/src/cryptography/hazmat/primitives/ciphers/base.py index 2082df669a23b..7c32cbec693e4 100644 --- a/src/cryptography/hazmat/primitives/ciphers/base.py +++ b/src/cryptography/hazmat/primitives/ciphers/base.py @@ -7,19 +7,10 @@ import abc import typing -from cryptography.exceptions import ( - AlreadyFinalized, - AlreadyUpdated, - NotYetFinalized, -) +from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.primitives._cipheralgorithm import CipherAlgorithm from cryptography.hazmat.primitives.ciphers import modes -if typing.TYPE_CHECKING: - from cryptography.hazmat.backends.openssl.ciphers import ( - _CipherContext as _BackendCipherContext, - ) - class CipherContext(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -112,12 +103,10 @@ def encryptor(self): raise ValueError( "Authentication tag must be None when encrypting." ) - from cryptography.hazmat.backends.openssl.backend import backend - ctx = backend.create_symmetric_encryption_ctx( + return rust_openssl.ciphers.create_encryption_ctx( self.algorithm, self.mode ) - return self._wrap_ctx(ctx, encrypt=True) @typing.overload def decryptor( @@ -132,23 +121,9 @@ def decryptor( ... def decryptor(self): - from cryptography.hazmat.backends.openssl.backend import backend - - ctx = backend.create_symmetric_decryption_ctx( + return rust_openssl.ciphers.create_decryption_ctx( self.algorithm, self.mode ) - return self._wrap_ctx(ctx, encrypt=False) - - def _wrap_ctx( - self, ctx: _BackendCipherContext, encrypt: bool - ) -> AEADEncryptionContext | AEADDecryptionContext | CipherContext: - if isinstance(self.mode, modes.ModeWithAuthenticationTag): - if encrypt: - return _AEADEncryptionContext(ctx) - else: - return _AEADDecryptionContext(ctx) - else: - return _CipherContext(ctx) _CIPHER_TYPE = Cipher[ @@ -161,112 +136,6 @@ def _wrap_ctx( ] ] - -class _CipherContext(CipherContext): - _ctx: _BackendCipherContext | None - - def __init__(self, ctx: _BackendCipherContext) -> None: - self._ctx = ctx - - def update(self, data: bytes) -> bytes: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - return self._ctx.update(data) - - def update_into(self, data: bytes, buf: bytes) -> int: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - return self._ctx.update_into(data, buf) - - def finalize(self) -> bytes: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - data = self._ctx.finalize() - self._ctx = None - return data - - -class _AEADCipherContext(AEADCipherContext): - _ctx: _BackendCipherContext | None - _tag: bytes | None - - def __init__(self, ctx: _BackendCipherContext) -> None: - self._ctx = ctx - self._bytes_processed = 0 - self._aad_bytes_processed = 0 - self._tag = None - self._updated = False - - def _check_limit(self, data_size: int) -> None: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - self._updated = True - self._bytes_processed += data_size - if self._bytes_processed > self._ctx._mode._MAX_ENCRYPTED_BYTES: - raise ValueError( - "{} has a maximum encrypted byte limit of {}".format( - self._ctx._mode.name, self._ctx._mode._MAX_ENCRYPTED_BYTES - ) - ) - - def update(self, data: bytes) -> bytes: - self._check_limit(len(data)) - # mypy needs this assert even though _check_limit already checked - assert self._ctx is not None - return self._ctx.update(data) - - def update_into(self, data: bytes, buf: bytes) -> int: - self._check_limit(len(data)) - # mypy needs this assert even though _check_limit already checked - assert self._ctx is not None - return self._ctx.update_into(data, buf) - - def finalize(self) -> bytes: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - data = self._ctx.finalize() - self._tag = self._ctx.tag - self._ctx = None - return data - - def authenticate_additional_data(self, data: bytes) -> None: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - if self._updated: - raise AlreadyUpdated("Update has been called on this context.") - - self._aad_bytes_processed += len(data) - if self._aad_bytes_processed > self._ctx._mode._MAX_AAD_BYTES: - raise ValueError( - "{} has a maximum AAD byte limit of {}".format( - self._ctx._mode.name, self._ctx._mode._MAX_AAD_BYTES - ) - ) - - self._ctx.authenticate_additional_data(data) - - -class _AEADDecryptionContext(_AEADCipherContext, AEADDecryptionContext): - def finalize_with_tag(self, tag: bytes) -> bytes: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - if self._ctx._tag is not None: - raise ValueError( - "tag provided both in mode and in call with finalize_with_tag:" - " tag should only be provided once" - ) - data = self._ctx.finalize_with_tag(tag) - self._tag = self._ctx.tag - self._ctx = None - return data - - -class _AEADEncryptionContext(_AEADCipherContext, AEADEncryptionContext): - @property - def tag(self) -> bytes: - if self._ctx is not None: - raise NotYetFinalized( - "You must finalize encryption before " "getting the tag." - ) - assert self._tag is not None - return self._tag +CipherContext.register(rust_openssl.ciphers.CipherContext) +AEADEncryptionContext.register(rust_openssl.ciphers.AEADEncryptionContext) +AEADDecryptionContext.register(rust_openssl.ciphers.AEADDecryptionContext) diff --git a/src/cryptography/utils.py b/src/cryptography/utils.py index a0ec7a3cd76d9..b837387a4c17a 100644 --- a/src/cryptography/utils.py +++ b/src/cryptography/utils.py @@ -52,6 +52,13 @@ def _extract_buffer_length(obj: typing.Any) -> tuple[typing.Any, int]: return buf, int(_openssl.ffi.cast("uintptr_t", buf)) +def _extract_mut_buffer_length(obj: typing.Any) -> tuple[typing.Any, int]: + from cryptography.hazmat.bindings._rust import _openssl + + buf = _openssl.ffi.from_buffer(obj, require_writable=True) + return buf, int(_openssl.ffi.cast("uintptr_t", buf)) + + class InterfaceNotImplemented(Exception): pass diff --git a/src/rust/Cargo.lock b/src/rust/Cargo.lock index e8e0517dcfbbf..ceb786df349db 100644 --- a/src/rust/Cargo.lock +++ b/src/rust/Cargo.lock @@ -178,8 +178,7 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl" version = "0.10.62" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cde4d2d9200ad5909f8dac647e29482e07c3a35de8a13fce7c9c7747ad9f671" +source = "git+https://github.com/sfackler/rust-openssl#1ea720c6073818007330042472ad2f6d5621a152" dependencies = [ "bitflags 2.4.0", "cfg-if", @@ -193,8 +192,7 @@ dependencies = [ [[package]] name = "openssl-macros" version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +source = "git+https://github.com/sfackler/rust-openssl#1ea720c6073818007330042472ad2f6d5621a152" dependencies = [ "proc-macro2", "quote", @@ -204,8 +202,7 @@ dependencies = [ [[package]] name = "openssl-sys" version = "0.9.98" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1665caf8ab2dc9aef43d1c0023bd904633a6a05cb30b0ad59bec2ae986e57a7" +source = "git+https://github.com/sfackler/rust-openssl#1ea720c6073818007330042472ad2f6d5621a152" dependencies = [ "cc", "libc", diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index 13e35e298a30d..743e04443ca37 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -43,3 +43,7 @@ members = [ "cryptography-x509", "cryptography-x509-verification", ] + +[patch.crates-io] +openssl = { git = "https://github.com/sfackler/rust-openssl" } +openssl-sys = { git = "https://github.com/sfackler/rust-openssl" } \ No newline at end of file diff --git a/src/rust/src/backend/cipher_registry.rs b/src/rust/src/backend/cipher_registry.rs index 5c62ff8c0f739..d6704136b6ed0 100644 --- a/src/rust/src/backend/cipher_registry.rs +++ b/src/rust/src/backend/cipher_registry.rs @@ -52,9 +52,26 @@ impl std::hash::Hash for RegistryKey { } } +enum RegistryCipher { + Ref(&'static openssl::cipher::CipherRef), + Owned(Cipher), +} + +impl From<&'static openssl::cipher::CipherRef> for RegistryCipher { + fn from(c: &'static openssl::cipher::CipherRef) -> RegistryCipher { + RegistryCipher::Ref(c) + } +} + +impl From for RegistryCipher { + fn from(c: Cipher) -> RegistryCipher { + RegistryCipher::Owned(c) + } +} + struct RegisteryBuilder<'p> { py: pyo3::Python<'p>, - m: HashMap, + m: HashMap, } impl<'p> RegisteryBuilder<'p> { @@ -70,27 +87,26 @@ impl<'p> RegisteryBuilder<'p> { algorithm: &pyo3::PyAny, mode: &pyo3::PyAny, key_size: Option, - cipher: &'static openssl::cipher::CipherRef, + cipher: impl Into, ) -> CryptographyResult<()> { self.m.insert( RegistryKey::new(self.py, algorithm.into(), mode.into(), key_size)?, - cipher, + cipher.into(), ); Ok(()) } - fn build(self) -> HashMap { + fn build(self) -> HashMap { self.m } } fn get_cipher_registry( py: pyo3::Python<'_>, -) -> CryptographyResult<&HashMap> { - static REGISTRY: pyo3::sync::GILOnceCell< - HashMap, - > = pyo3::sync::GILOnceCell::new(); +) -> CryptographyResult<&HashMap> { + static REGISTRY: pyo3::sync::GILOnceCell> = + pyo3::sync::GILOnceCell::new(); REGISTRY.get_or_try_init(py, || { let mut m = RegisteryBuilder::new(py); @@ -111,49 +127,169 @@ fn get_cipher_registry( let sm4 = types::SM4.get(py)?; #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_SEED"))] let seed = types::SEED.get(py)?; + let arc4 = types::ARC4.get(py)?; + let chacha20 = types::CHACHA20.get(py)?; let cbc = types::CBC.get(py)?; + let cfb = types::CFB.get(py)?; + let cfb8 = types::CFB8.get(py)?; + let ofb = types::OFB.get(py)?; + let ecb = types::ECB.get(py)?; + let ctr = types::CTR.get(py)?; + let gcm = types::GCM.get(py)?; + let xts = types::XTS.get(py)?; + + let none = py.None(); + let none_type = none.as_ref(py).get_type(); m.add(aes, cbc, Some(128), Cipher::aes_128_cbc())?; m.add(aes, cbc, Some(192), Cipher::aes_192_cbc())?; m.add(aes, cbc, Some(256), Cipher::aes_256_cbc())?; + m.add(aes, ofb, Some(128), Cipher::aes_128_ofb())?; + m.add(aes, ofb, Some(192), Cipher::aes_192_ofb())?; + m.add(aes, ofb, Some(256), Cipher::aes_256_ofb())?; + + m.add(aes, gcm, Some(128), Cipher::aes_128_gcm())?; + m.add(aes, gcm, Some(192), Cipher::aes_192_gcm())?; + m.add(aes, gcm, Some(256), Cipher::aes_256_gcm())?; + + m.add(aes, ctr, Some(128), Cipher::aes_128_ctr())?; + m.add(aes, ctr, Some(192), Cipher::aes_192_ctr())?; + m.add(aes, ctr, Some(256), Cipher::aes_256_ctr())?; + + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + { + m.add(aes, cfb8, Some(128), Cipher::aes_128_cfb8())?; + m.add(aes, cfb8, Some(192), Cipher::aes_192_cfb8())?; + m.add(aes, cfb8, Some(256), Cipher::aes_256_cfb8())?; + + m.add(aes, cfb, Some(128), Cipher::aes_128_cfb128())?; + m.add(aes, cfb, Some(192), Cipher::aes_192_cfb128())?; + m.add(aes, cfb, Some(256), Cipher::aes_256_cfb128())?; + } + + m.add(aes, ecb, Some(128), Cipher::aes_128_ecb())?; + m.add(aes, ecb, Some(192), Cipher::aes_192_ecb())?; + m.add(aes, ecb, Some(256), Cipher::aes_256_ecb())?; + + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + { + m.add(aes, xts, Some(256), Cipher::aes_128_xts())?; + m.add(aes, xts, Some(512), Cipher::aes_256_xts())?; + } + m.add(aes128, cbc, Some(128), Cipher::aes_128_cbc())?; m.add(aes256, cbc, Some(256), Cipher::aes_256_cbc())?; + m.add(aes128, ofb, Some(128), Cipher::aes_128_ofb())?; + m.add(aes256, ofb, Some(256), Cipher::aes_256_ofb())?; + + m.add(aes128, gcm, Some(128), Cipher::aes_128_gcm())?; + m.add(aes256, gcm, Some(256), Cipher::aes_256_gcm())?; + + m.add(aes128, ctr, Some(128), Cipher::aes_128_ctr())?; + m.add(aes256, ctr, Some(256), Cipher::aes_256_ctr())?; + + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + { + m.add(aes128, cfb8, Some(128), Cipher::aes_128_cfb8())?; + m.add(aes256, cfb8, Some(256), Cipher::aes_256_cfb8())?; + + m.add(aes128, cfb, Some(128), Cipher::aes_128_cfb128())?; + m.add(aes256, cfb, Some(256), Cipher::aes_256_cfb128())?; + } + + m.add(aes128, ecb, Some(128), Cipher::aes_128_ecb())?; + m.add(aes256, ecb, Some(256), Cipher::aes_256_ecb())?; + m.add(triple_des, cbc, Some(192), Cipher::des_ede3_cbc())?; + m.add(triple_des, ecb, Some(192), Cipher::des_ede3_ecb())?; + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + { + m.add(triple_des, cfb8, Some(192), Cipher::des_ede3_cfb8())?; + m.add(triple_des, cfb, Some(192), Cipher::des_ede3_cfb64())?; + m.add(triple_des, ofb, Some(192), Cipher::des_ede3_ofb())?; + } #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAMELLIA"))] - m.add(camellia, cbc, Some(128), Cipher::camellia128_cbc())?; - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAMELLIA"))] - m.add(camellia, cbc, Some(192), Cipher::camellia192_cbc())?; - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAMELLIA"))] - m.add(camellia, cbc, Some(256), Cipher::camellia256_cbc())?; + { + m.add(camellia, cbc, Some(128), Cipher::camellia128_cbc())?; + m.add(camellia, cbc, Some(192), Cipher::camellia192_cbc())?; + m.add(camellia, cbc, Some(256), Cipher::camellia256_cbc())?; + + m.add(camellia, ecb, Some(128), Cipher::camellia128_ecb())?; + m.add(camellia, ecb, Some(192), Cipher::camellia192_ecb())?; + m.add(camellia, ecb, Some(256), Cipher::camellia256_ecb())?; + + m.add(camellia, ofb, Some(128), Cipher::camellia128_ofb())?; + m.add(camellia, ofb, Some(192), Cipher::camellia192_ofb())?; + m.add(camellia, ofb, Some(256), Cipher::camellia256_ofb())?; + + m.add(camellia, cfb, Some(128), Cipher::camellia128_cfb128())?; + m.add(camellia, cfb, Some(192), Cipher::camellia192_cfb128())?; + m.add(camellia, cfb, Some(256), Cipher::camellia256_cfb128())?; + } #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_SM4"))] - m.add(sm4, cbc, Some(128), Cipher::sm4_cbc())?; + { + m.add(sm4, cbc, Some(128), Cipher::sm4_cbc())?; + m.add(sm4, ctr, Some(128), Cipher::sm4_ctr())?; + m.add(sm4, cfb, Some(128), Cipher::sm4_cfb128())?; + m.add(sm4, ofb, Some(128), Cipher::sm4_ofb())?; + m.add(sm4, ecb, Some(128), Cipher::sm4_ecb())?; + + if let Ok(c) = Cipher::fetch(None, "sm4-gcm", None) { + m.add(sm4, gcm, Some(128), c)?; + } + } #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_SEED"))] - m.add(seed, cbc, Some(128), Cipher::seed_cbc())?; + { + m.add(seed, cbc, Some(128), Cipher::seed_cbc())?; + m.add(seed, cfb, Some(128), Cipher::seed_cfb128())?; + m.add(seed, ofb, Some(128), Cipher::seed_ofb())?; + m.add(seed, ecb, Some(128), Cipher::seed_ecb())?; + } #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_BF"))] - m.add(blowfish, cbc, None, Cipher::bf_cbc())?; + { + m.add(blowfish, cbc, None, Cipher::bf_cbc())?; + m.add(blowfish, cfb, None, Cipher::bf_cfb64())?; + m.add(blowfish, ofb, None, Cipher::bf_ofb())?; + m.add(blowfish, ecb, None, Cipher::bf_ecb())?; + } #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAST"))] - m.add(cast5, cbc, None, Cipher::cast5_cbc())?; + { + m.add(cast5, cbc, None, Cipher::cast5_cbc())?; + m.add(cast5, ecb, None, Cipher::cast5_ecb())?; + m.add(cast5, ofb, None, Cipher::cast5_ofb())?; + m.add(cast5, cfb, None, Cipher::cast5_cfb64())?; + } #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_IDEA"))] - m.add(idea, cbc, Some(128), Cipher::idea_cbc())?; + { + m.add(idea, cbc, Some(128), Cipher::idea_cbc())?; + m.add(idea, ecb, Some(128), Cipher::idea_ecb())?; + m.add(idea, ofb, Some(128), Cipher::idea_ofb())?; + m.add(idea, cfb, Some(128), Cipher::idea_cfb64())?; + } + + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + m.add(chacha20, none_type, None, Cipher::chacha20())?; + + m.add(arc4, none_type, None, Cipher::rc4())?; Ok(m.build()) }) } -pub(crate) fn get_cipher<'a>( - py: pyo3::Python<'_>, +pub(crate) fn get_cipher<'py>( + py: pyo3::Python<'py>, algorithm: &pyo3::PyAny, mode_cls: &pyo3::PyAny, -) -> CryptographyResult> { +) -> CryptographyResult> { let registry = get_cipher_registry(py)?; let key_size = algorithm @@ -161,5 +297,9 @@ pub(crate) fn get_cipher<'a>( .extract()?; let key = RegistryKey::new(py, algorithm.get_type().into(), mode_cls.into(), key_size)?; - Ok(registry.get(&key).cloned()) + match registry.get(&key) { + Some(RegistryCipher::Ref(c)) => Ok(Some(c)), + Some(RegistryCipher::Owned(c)) => Ok(Some(c)), + None => Ok(None), + } } diff --git a/src/rust/src/backend/ciphers.rs b/src/rust/src/backend/ciphers.rs new file mode 100644 index 0000000000000..e55b5256fd3bc --- /dev/null +++ b/src/rust/src/backend/ciphers.rs @@ -0,0 +1,557 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use crate::backend::cipher_registry; +use crate::buf::{CffiBuf, CffiMutBuf}; +use crate::error::{CryptographyError, CryptographyResult}; +use crate::exceptions; +use crate::types; +use pyo3::IntoPy; + +struct CipherContext { + ctx: openssl::cipher_ctx::CipherCtx, + py_mode: pyo3::PyObject, +} + +impl CipherContext { + fn new( + py: pyo3::Python<'_>, + algorithm: &pyo3::PyAny, + mode: &pyo3::PyAny, + side: openssl::symm::Mode, + ) -> CryptographyResult { + let cipher = match cipher_registry::get_cipher(py, algorithm, mode.get_type())? { + Some(c) => c, + None => { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + format!( + "cipher {} in {} mode is not supported ", + algorithm.getattr(pyo3::intern!(py, "name"))?, + if mode.is_true()? { + mode.getattr(pyo3::intern!(py, "name"))? + } else { + mode + } + ), + exceptions::Reasons::UNSUPPORTED_CIPHER, + )), + )) + } + }; + + let iv_nonce = if mode.is_instance(types::MODE_WITH_INITIALIZATION_VECTOR.get(py)?)? { + Some( + mode.getattr(pyo3::intern!(py, "initialization_vector"))? + .extract::>()?, + ) + } else if mode.is_instance(types::MODE_WITH_TWEAK.get(py)?)? { + Some( + mode.getattr(pyo3::intern!(py, "tweak"))? + .extract::>()?, + ) + } else if mode.is_instance(types::MODE_WITH_NONCE.get(py)?)? { + Some( + mode.getattr(pyo3::intern!(py, "nonce"))? + .extract::>()?, + ) + } else if algorithm.is_instance(types::CHACHA20.get(py)?)? { + Some( + algorithm + .getattr(pyo3::intern!(py, "nonce"))? + .extract::>()?, + ) + } else { + None + }; + + let key = algorithm + .getattr(pyo3::intern!(py, "key"))? + .extract::>()?; + + let init_op = match side { + openssl::symm::Mode::Encrypt => openssl::cipher_ctx::CipherCtxRef::encrypt_init, + openssl::symm::Mode::Decrypt => openssl::cipher_ctx::CipherCtxRef::decrypt_init, + }; + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + init_op(&mut ctx, Some(cipher), None, None)?; + ctx.set_key_length(key.as_bytes().len())?; + + if let Some(iv) = iv_nonce.as_ref() { + if cipher.iv_length() != 0 { + ctx.set_iv_length(iv.as_bytes().len())?; + } + } + + if mode.is_instance(types::XTS.get(py)?)? { + init_op( + &mut ctx, + None, + Some(key.as_bytes()), + iv_nonce.as_ref().map(|b| b.as_bytes()), + ) + .map_err(|_| { + pyo3::exceptions::PyValueError::new_err( + "In XTS mode duplicated keys are not allowed", + ) + })?; + } else { + init_op( + &mut ctx, + None, + Some(key.as_bytes()), + iv_nonce.as_ref().map(|b| b.as_bytes()), + )?; + }; + + ctx.set_padding(false); + + Ok(CipherContext { + ctx, + py_mode: mode.into(), + }) + } + + fn update<'p>( + &mut self, + py: pyo3::Python<'p>, + buf: &[u8], + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let mut out_buf = vec![0; buf.len() + self.ctx.block_size()]; + let n = self.update_into(py, buf, &mut out_buf)?; + Ok(pyo3::types::PyBytes::new(py, &out_buf[..n])) + } + + fn update_into( + &mut self, + py: pyo3::Python<'_>, + buf: &[u8], + out_buf: &mut [u8], + ) -> CryptographyResult { + if out_buf.len() < (buf.len() + self.ctx.block_size() - 1) { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "buffer must be at least {} bytes for this payload", + buf.len() + self.ctx.block_size() - 1 + )), + )); + } + + let mut total_written = 0; + for chunk in buf.chunks(1 << 29) { + // SAFETY: We ensure that outbuf is sufficiently large above. + unsafe { + let n = if self.py_mode.as_ref(py).is_instance(types::XTS.get(py)?)? { + self.ctx.cipher_update_unchecked(chunk, Some(&mut out_buf[total_written..])).map_err(|_| { + pyo3::exceptions::PyValueError::new_err( + "In XTS mode you must supply at least a full block in the first update call. For AES this is 16 bytes." + ) + })? + } else { + self.ctx + .cipher_update_unchecked(chunk, Some(&mut out_buf[total_written..]))? + }; + total_written += n; + } + } + + Ok(total_written) + } + + fn authenticate_additional_data(&mut self, buf: &[u8]) -> CryptographyResult<()> { + self.ctx.cipher_update(buf, None)?; + Ok(()) + } + + fn finalize<'p>( + &mut self, + py: pyo3::Python<'p>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let mut out_buf = vec![0; self.ctx.block_size()]; + let n = self.ctx.cipher_final(&mut out_buf).or_else(|e| { + if e.errors().is_empty() + && self + .py_mode + .as_ref(py) + .is_instance(types::MODE_WITH_AUTHENTICATION_TAG.get(py)?)? + { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "The length of the provided data is not a multiple of the block length.", + ), + )) + })?; + Ok(pyo3::types::PyBytes::new(py, &out_buf[..n])) + } +} + +#[pyo3::prelude::pyclass( + module = "cryptography.hazmat.bindings._rust.openssl.ciphers", + name = "CipherContext" +)] +struct PyCipherContext { + ctx: Option, +} + +#[pyo3::prelude::pyclass( + module = "cryptography.hazmat.bindings._rust.openssl.ciphers", + name = "AEADEncryptionContext" +)] +struct PyAEADEncryptionContext { + ctx: Option, + tag: Option>, + updated: bool, + bytes_remaining: u64, + aad_bytes_remaining: u64, +} + +#[pyo3::prelude::pyclass( + module = "cryptography.hazmat.bindings._rust.openssl.ciphers", + name = "AEADDecryptionContext" +)] +struct PyAEADDecryptionContext { + ctx: Option, + updated: bool, + bytes_remaining: u64, + aad_bytes_remaining: u64, +} + +fn get_mut_ctx(ctx: Option<&mut CipherContext>) -> pyo3::PyResult<&mut CipherContext> { + ctx.ok_or_else(|| exceptions::AlreadyFinalized::new_err("Context was already finalized.")) +} + +#[pyo3::prelude::pymethods] +impl PyCipherContext { + fn update<'p>( + &mut self, + py: pyo3::Python<'p>, + buf: CffiBuf<'_>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + get_mut_ctx(self.ctx.as_mut())?.update(py, buf.as_bytes()) + } + + fn update_into( + &mut self, + py: pyo3::Python<'_>, + buf: CffiBuf<'_>, + mut out_buf: CffiMutBuf<'_>, + ) -> CryptographyResult { + get_mut_ctx(self.ctx.as_mut())?.update_into(py, buf.as_bytes(), out_buf.as_mut_bytes()) + } + + fn finalize<'p>( + &mut self, + py: pyo3::Python<'p>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let result = get_mut_ctx(self.ctx.as_mut())?.finalize(py)?; + self.ctx = None; + Ok(result) + } +} + +#[pyo3::prelude::pymethods] +impl PyAEADEncryptionContext { + fn update<'p>( + &mut self, + py: pyo3::Python<'p>, + buf: CffiBuf<'_>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let data = buf.as_bytes(); + + self.updated = true; + self.bytes_remaining = self + .bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum encrypted byte limit") + })?; + get_mut_ctx(self.ctx.as_mut())?.update(py, data) + } + + fn update_into( + &mut self, + py: pyo3::Python<'_>, + buf: CffiBuf<'_>, + mut out_buf: CffiMutBuf<'_>, + ) -> CryptographyResult { + let data = buf.as_bytes(); + + self.updated = true; + self.bytes_remaining = self + .bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum encrypted byte limit") + })?; + get_mut_ctx(self.ctx.as_mut())?.update_into(py, data, out_buf.as_mut_bytes()) + } + + fn authenticate_additional_data(&mut self, buf: CffiBuf<'_>) -> CryptographyResult<()> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + if self.updated { + return Err(CryptographyError::from( + exceptions::AlreadyUpdated::new_err("Update has been called on this context."), + )); + } + + let data = buf.as_bytes(); + self.aad_bytes_remaining = self + .aad_bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum AAD byte limit") + })?; + ctx.authenticate_additional_data(data) + } + + fn finalize<'p>( + &mut self, + py: pyo3::Python<'p>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + let result = ctx.finalize(py)?; + + // XXX: do not hard code 16 + let tag = pyo3::types::PyBytes::new_with(py, 16, |t| { + ctx.ctx.tag(t).map_err(CryptographyError::from)?; + Ok(()) + })?; + self.tag = Some(tag.into_py(py)); + self.ctx = None; + + Ok(result) + } + + #[getter] + fn tag(&self, py: pyo3::Python<'_>) -> CryptographyResult> { + Ok(self + .tag + .as_ref() + .ok_or_else(|| { + exceptions::NotYetFinalized::new_err( + "You must finalize encryption before getting the tag.", + ) + })? + .clone_ref(py)) + } +} + +#[pyo3::prelude::pymethods] +impl PyAEADDecryptionContext { + fn update<'p>( + &mut self, + py: pyo3::Python<'p>, + buf: CffiBuf<'_>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let data = buf.as_bytes(); + + self.updated = true; + self.bytes_remaining = self + .bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum encrypted byte limit") + })?; + get_mut_ctx(self.ctx.as_mut())?.update(py, data) + } + + fn update_into( + &mut self, + py: pyo3::Python<'_>, + buf: CffiBuf<'_>, + mut out_buf: CffiMutBuf<'_>, + ) -> CryptographyResult { + let data = buf.as_bytes(); + + self.updated = true; + self.bytes_remaining = self + .bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum encrypted byte limit") + })?; + get_mut_ctx(self.ctx.as_mut())?.update_into(py, data, out_buf.as_mut_bytes()) + } + + fn authenticate_additional_data(&mut self, buf: CffiBuf<'_>) -> CryptographyResult<()> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + if self.updated { + return Err(CryptographyError::from( + exceptions::AlreadyUpdated::new_err("Update has been called on this context."), + )); + } + + let data = buf.as_bytes(); + self.aad_bytes_remaining = self + .aad_bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum AAD byte limit") + })?; + ctx.authenticate_additional_data(data) + } + + fn finalize<'p>( + &mut self, + py: pyo3::Python<'p>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + + if ctx + .py_mode + .as_ref(py) + .getattr(pyo3::intern!(py, "tag"))? + .is_none() + { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "Authentication tag must be provided when decrypting.", + ), + )); + } + + let result = ctx.finalize(py)?; + self.ctx = None; + Ok(result) + } + + fn finalize_with_tag<'p>( + &mut self, + py: pyo3::Python<'p>, + tag: &[u8], + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + + if !ctx + .py_mode + .as_ref(py) + .getattr(pyo3::intern!(py, "tag"))? + .is_none() + { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "Authentication tag must be provided only once.", + ), + )); + } + + let min_tag_length = ctx + .py_mode + .as_ref(py) + .getattr(pyo3::intern!(py, "_min_tag_length"))? + .extract()?; + // XXX: Do not hard code 16 + if tag.len() < min_tag_length { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "Authentication tag must be {} bytes or longer.", + min_tag_length + )), + )); + } else if tag.len() > 16 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "Authentication tag cannot be more than {} bytes.", + 16 + )), + )); + } + + ctx.ctx.set_tag(tag)?; + let result = ctx.finalize(py)?; + self.ctx = None; + Ok(result) + } +} + +#[pyo3::prelude::pyfunction] +fn create_encryption_ctx( + py: pyo3::Python<'_>, + algorithm: &pyo3::PyAny, + mode: &pyo3::PyAny, +) -> CryptographyResult { + let ctx = CipherContext::new(py, algorithm, mode, openssl::symm::Mode::Encrypt)?; + + if mode.is_instance(types::MODE_WITH_AUTHENTICATION_TAG.get(py)?)? { + Ok(PyAEADEncryptionContext { + ctx: Some(ctx), + tag: None, + updated: false, + bytes_remaining: mode + .getattr(pyo3::intern!(py, "_MAX_ENCRYPTED_BYTES"))? + .extract()?, + aad_bytes_remaining: mode + .getattr(pyo3::intern!(py, "_MAX_AAD_BYTES"))? + .extract()?, + } + .into_py(py)) + } else { + Ok(PyCipherContext { ctx: Some(ctx) }.into_py(py)) + } +} + +#[pyo3::prelude::pyfunction] +fn create_decryption_ctx( + py: pyo3::Python<'_>, + algorithm: &pyo3::PyAny, + mode: &pyo3::PyAny, +) -> CryptographyResult { + let mut ctx = CipherContext::new(py, algorithm, mode, openssl::symm::Mode::Decrypt)?; + + if mode.is_instance(types::MODE_WITH_AUTHENTICATION_TAG.get(py)?)? { + if let Some(tag) = mode.getattr(pyo3::intern!(py, "tag"))?.extract()? { + ctx.ctx.set_tag(tag)?; + } + + Ok(PyAEADDecryptionContext { + ctx: Some(ctx), + updated: false, + bytes_remaining: mode + .getattr(pyo3::intern!(py, "_MAX_ENCRYPTED_BYTES"))? + .extract()?, + aad_bytes_remaining: mode + .getattr(pyo3::intern!(py, "_MAX_AAD_BYTES"))? + .extract()?, + } + .into_py(py)) + } else { + Ok(PyCipherContext { ctx: Some(ctx) }.into_py(py)) + } +} + +#[pyo3::prelude::pyfunction] +fn _advance(ctx: &pyo3::PyAny, n: u64) { + if let Ok(c) = ctx.downcast::>() { + c.borrow_mut().bytes_remaining -= n; + } else if let Ok(c) = ctx.downcast::>() { + c.borrow_mut().bytes_remaining -= n; + } +} + +#[pyo3::prelude::pyfunction] +fn _advance_aad(ctx: &pyo3::PyAny, n: u64) { + if let Ok(c) = ctx.downcast::>() { + c.borrow_mut().aad_bytes_remaining -= n; + } else if let Ok(c) = ctx.downcast::>() { + c.borrow_mut().aad_bytes_remaining -= n; + } +} + +pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { + let m = pyo3::prelude::PyModule::new(py, "ciphers")?; + m.add_function(pyo3::wrap_pyfunction!(create_encryption_ctx, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(create_decryption_ctx, m)?)?; + + m.add_function(pyo3::wrap_pyfunction!(_advance, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(_advance_aad, m)?)?; + + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(m) +} diff --git a/src/rust/src/backend/mod.rs b/src/rust/src/backend/mod.rs index 7e085d623b40f..be7b2d0ac2808 100644 --- a/src/rust/src/backend/mod.rs +++ b/src/rust/src/backend/mod.rs @@ -4,6 +4,7 @@ pub(crate) mod aead; pub(crate) mod cipher_registry; +pub(crate) mod ciphers; pub(crate) mod cmac; pub(crate) mod dh; pub(crate) mod dsa; @@ -24,6 +25,7 @@ pub(crate) mod x448; pub(crate) fn add_to_module(module: &pyo3::prelude::PyModule) -> pyo3::PyResult<()> { module.add_submodule(aead::create_module(module.py())?)?; + module.add_submodule(ciphers::create_module(module.py())?)?; module.add_submodule(cmac::create_module(module.py())?)?; module.add_submodule(dh::create_module(module.py())?)?; module.add_submodule(dsa::create_module(module.py())?)?; diff --git a/src/rust/src/buf.rs b/src/rust/src/buf.rs index 0a39a80f4341c..f65e12ea1b115 100644 --- a/src/rust/src/buf.rs +++ b/src/rust/src/buf.rs @@ -48,3 +48,47 @@ impl<'a> pyo3::conversion::FromPyObject<'a> for CffiBuf<'a> { }) } } + +pub(crate) struct CffiMutBuf<'p> { + _pyobj: &'p pyo3::PyAny, + _bufobj: &'p pyo3::PyAny, + buf: &'p mut [u8], +} + +impl CffiMutBuf<'_> { + pub(crate) fn as_mut_bytes(&mut self) -> &mut [u8] { + self.buf + } +} + +impl<'a> pyo3::conversion::FromPyObject<'a> for CffiMutBuf<'a> { + fn extract(pyobj: &'a pyo3::PyAny) -> pyo3::PyResult { + let py = pyobj.py(); + + let (bufobj, ptrval): (&pyo3::PyAny, usize) = types::EXTRACT_MUT_BUFFER_LENGTH + .get(py)? + .call1((pyobj,))? + .extract()?; + + let len = bufobj.len()?; + let ptr = if len == 0 { + ptr::NonNull::dangling().as_ptr() + } else { + ptrval as *mut u8 + }; + + Ok(CffiMutBuf { + _pyobj: pyobj, + _bufobj: bufobj, + // SAFETY: _extract_buffer_length ensures that we have a valid ptr + // and length (and we ensure we meet slice's requirements for + // 0-length slices above), we're keeping pyobj alive which ensures + // the buffer is valid. But! There is no actually guarantee + // against concurrent mutation. See + // https://alexgaynor.net/2022/oct/23/buffers-on-the-edge/ + // for details. This is the same as our cffi status quo ante, so + // we're doing an unsound thing and living with it. + buf: unsafe { slice::from_raw_parts_mut(ptr, len) }, + }) + } +} diff --git a/src/rust/src/exceptions.rs b/src/rust/src/exceptions.rs index 1354d1b596b86..aae00f71c0c43 100644 --- a/src/rust/src/exceptions.rs +++ b/src/rust/src/exceptions.rs @@ -23,10 +23,12 @@ pub(crate) enum Reasons { UNSUPPORTED_MAC, } +pyo3::import_exception!(cryptography.exceptions, AlreadyUpdated); pyo3::import_exception!(cryptography.exceptions, AlreadyFinalized); pyo3::import_exception!(cryptography.exceptions, InternalError); pyo3::import_exception!(cryptography.exceptions, InvalidSignature); pyo3::import_exception!(cryptography.exceptions, InvalidTag); +pyo3::import_exception!(cryptography.exceptions, NotYetFinalized); pyo3::import_exception!(cryptography.exceptions, UnsupportedAlgorithm); pyo3::import_exception!(cryptography.x509, AttributeNotFound); pyo3::import_exception!(cryptography.x509, DuplicateExtension); diff --git a/src/rust/src/types.rs b/src/rust/src/types.rs index cf323bfd28af0..4172aa53f5efd 100644 --- a/src/rust/src/types.rs +++ b/src/rust/src/types.rs @@ -481,6 +481,8 @@ pub static DSA_PRIVATE_NUMBERS: LazyPyImport = LazyPyImport::new( pub static EXTRACT_BUFFER_LENGTH: LazyPyImport = LazyPyImport::new("cryptography.utils", &["_extract_buffer_length"]); +pub static EXTRACT_MUT_BUFFER_LENGTH: LazyPyImport = + LazyPyImport::new("cryptography.utils", &["_extract_mut_buffer_length"]); pub static BLOCK_CIPHER_ALGORITHM: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.ciphers", @@ -503,6 +505,10 @@ pub static AES256: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.ciphers.algorithms", &["AES256"], ); +pub static CHACHA20: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.algorithms", + &["ChaCha20"], +); pub static SM4: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.ciphers.algorithms", &["SM4"], @@ -528,9 +534,43 @@ pub static IDEA: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.ciphers.algorithms", &["_IDEAInternal"], ); +pub static ARC4: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.algorithms", + &["ARC4"], +); +pub static MODE_WITH_INITIALIZATION_VECTOR: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.modes", + &["ModeWithInitializationVector"], +); +pub static MODE_WITH_TWEAK: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.modes", + &["ModeWithTweak"], +); +pub static MODE_WITH_NONCE: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.modes", + &["ModeWithNonce"], +); +pub static MODE_WITH_AUTHENTICATION_TAG: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.modes", + &["ModeWithAuthenticationTag"], +); pub static CBC: LazyPyImport = LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["CBC"]); +pub static CFB: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["CFB"]); +pub static CFB8: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["CFB8"]); +pub static OFB: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["OFB"]); +pub static ECB: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["ECB"]); +pub static CTR: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["CTR"]); +pub static GCM: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["GCM"]); +pub static XTS: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["XTS"]); #[cfg(test)] mod tests { diff --git a/tests/hazmat/backends/test_openssl.py b/tests/hazmat/backends/test_openssl.py index 5b33d76ef2450..e81b8f8a2ef18 100644 --- a/tests/hazmat/backends/test_openssl.py +++ b/tests/hazmat/backends/test_openssl.py @@ -13,7 +13,6 @@ from cryptography.hazmat.backends.openssl.backend import backend from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.ciphers.modes import CBC @@ -97,22 +96,6 @@ def test_register_duplicate_cipher_adapter(self): with pytest.raises(ValueError): backend.register_cipher_adapter(AES, CBC, None) - @pytest.mark.parametrize("mode", [DummyMode(), None]) - def test_nonexistent_cipher(self, mode, backend, monkeypatch): - # We can't use register_cipher_adapter because backend is a - # global singleton and we want to revert the change after the test - monkeypatch.setitem( - backend._cipher_registry, - (DummyCipherAlgorithm, type(mode)), - lambda backend, cipher, mode: backend._ffi.NULL, - ) - cipher = Cipher( - DummyCipherAlgorithm(), - mode, - ) - with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_CIPHER): - cipher.encryptor() - def test_openssl_assert(self): backend.openssl_assert(True) with pytest.raises(InternalError): @@ -141,14 +124,6 @@ def test_evp_ciphers_registered(self): cipher = backend._lib.EVP_get_cipherbyname(b"aes-256-cbc") assert cipher != backend._ffi.NULL - def test_unknown_error_in_cipher_finalize(self): - cipher = Cipher(AES(b"\0" * 16), CBC(b"\0" * 16), backend=backend) - enc = cipher.encryptor() - enc.update(b"\0") - backend._lib.ERR_put_error(0, 0, 1, b"test_openssl.py", -1) - with pytest.raises(InternalError): - enc.finalize() - class TestOpenSSLRSA: def test_generate_rsa_parameters_supported(self): diff --git a/tests/hazmat/primitives/test_aes_gcm.py b/tests/hazmat/primitives/test_aes_gcm.py index d82e37470cae5..0543270413589 100644 --- a/tests/hazmat/primitives/test_aes_gcm.py +++ b/tests/hazmat/primitives/test_aes_gcm.py @@ -8,20 +8,13 @@ import pytest +from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.primitives.ciphers import algorithms, base, modes from ...utils import load_nist_vectors from .utils import generate_aead_test -def _advance(ctx, n): - ctx._bytes_processed += n - - -def _advance_aad(ctx, n): - ctx._aad_bytes_processed += n - - @pytest.mark.supported( only_if=lambda backend: backend.cipher_supported( algorithms.AES(b"\x00" * 16), modes.GCM(b"\x00" * 12) @@ -80,7 +73,9 @@ def test_gcm_ciphertext_limit(self, backend): backend=backend, ) encryptor = cipher.encryptor() - _advance(encryptor, modes.GCM._MAX_ENCRYPTED_BYTES - 16) + rust_openssl.ciphers._advance( + encryptor, modes.GCM._MAX_ENCRYPTED_BYTES - 16 + ) encryptor.update(b"0" * 16) with pytest.raises(ValueError): encryptor.update(b"0") @@ -88,7 +83,9 @@ def test_gcm_ciphertext_limit(self, backend): encryptor.update_into(b"0", bytearray(1)) decryptor = cipher.decryptor() - _advance(decryptor, modes.GCM._MAX_ENCRYPTED_BYTES - 16) + rust_openssl.ciphers._advance( + decryptor, modes.GCM._MAX_ENCRYPTED_BYTES - 16 + ) decryptor.update(b"0" * 16) with pytest.raises(ValueError): decryptor.update(b"0") @@ -102,45 +99,21 @@ def test_gcm_aad_limit(self, backend): backend=backend, ) encryptor = cipher.encryptor() - _advance_aad(encryptor, modes.GCM._MAX_AAD_BYTES - 16) + rust_openssl.ciphers._advance_aad( + encryptor, modes.GCM._MAX_AAD_BYTES - 16 + ) encryptor.authenticate_additional_data(b"0" * 16) with pytest.raises(ValueError): encryptor.authenticate_additional_data(b"0") decryptor = cipher.decryptor() - _advance_aad(decryptor, modes.GCM._MAX_AAD_BYTES - 16) + rust_openssl.ciphers._advance_aad( + decryptor, modes.GCM._MAX_AAD_BYTES - 16 + ) decryptor.authenticate_additional_data(b"0" * 16) with pytest.raises(ValueError): decryptor.authenticate_additional_data(b"0") - def test_gcm_ciphertext_increments(self, backend): - encryptor = base.Cipher( - algorithms.AES(b"\x00" * 16), - modes.GCM(b"\x01" * 16), - backend=backend, - ).encryptor() - encryptor.update(b"0" * 8) - assert encryptor._bytes_processed == 8 # type: ignore[attr-defined] - encryptor.update(b"0" * 7) - assert encryptor._bytes_processed == 15 # type: ignore[attr-defined] - encryptor.update(b"0" * 18) - assert encryptor._bytes_processed == 33 # type: ignore[attr-defined] - - def test_gcm_aad_increments(self, backend): - encryptor = base.Cipher( - algorithms.AES(b"\x00" * 16), - modes.GCM(b"\x01" * 16), - backend=backend, - ).encryptor() - encryptor.authenticate_additional_data(b"0" * 8) - assert ( - encryptor._aad_bytes_processed == 8 # type: ignore[attr-defined] - ) - encryptor.authenticate_additional_data(b"0" * 18) - assert ( - encryptor._aad_bytes_processed == 26 # type: ignore[attr-defined] - ) - def test_gcm_tag_decrypt_none(self, backend): key = binascii.unhexlify(b"5211242698bed4774a090620a6ca56f3") iv = binascii.unhexlify(b"b1e1349120b6e832ef976f5d")