diff --git a/rohmu/encryptor.py b/rohmu/encryptor.py index 23d4de9c..982ae141 100644 --- a/rohmu/encryptor.py +++ b/rohmu/encryptor.py @@ -428,15 +428,19 @@ def write(self, data: BinaryData) -> int: class Encryptor(BaseEncryptor): - def __init__(self, public_key_pem: Union[str, bytes]): - if not isinstance(public_key_pem, bytes): - public_key_pem = public_key_pem.encode("ascii") - public_key = serialization.load_pem_public_key(public_key_pem, backend=default_backend()) - if not isinstance(public_key, RSAPublicKey): - raise ValueError("Key must be RSA") + def __init__(self, public_key_pem: Union[str, bytes, RSAPublicKey]): + if isinstance(public_key_pem, RSAPublicKey): + rsa_public_key = public_key_pem + else: + if not isinstance(public_key_pem, bytes): + public_key_pem = public_key_pem.encode("ascii") + public_key = serialization.load_pem_public_key(public_key_pem, backend=default_backend()) + if not isinstance(public_key, RSAPublicKey): + raise ValueError("Key must be RSA") + rsa_public_key = public_key super().__init__() - self.rsa_public_key = public_key + self.rsa_public_key = rsa_public_key def init_cipher(self) -> bytes: cipher_key = os.urandom(16) @@ -450,27 +454,31 @@ def init_cipher(self) -> bytes: class EncryptorFile(BaseEncryptorFile): - def __init__(self, next_fp: FileLike, public_key_pem: Union[str, bytes]) -> None: + def __init__(self, next_fp: FileLike, public_key_pem: Union[str, bytes, RSAPublicKey]) -> None: super().__init__(next_fp, Encryptor(public_key_pem)) class EncryptorStream(BaseEncryptorStream): """Non-seekable stream of data that adds encryption on top of given source stream""" - def __init__(self, src_fp: HasRead, public_key_pem: Union[str, bytes]) -> None: + def __init__(self, src_fp: HasRead, public_key_pem: Union[str, bytes, RSAPublicKey]) -> None: super().__init__(src_fp, Encryptor(public_key_pem)) class Decryptor(BaseDecryptor): - def __init__(self, private_key_pem: Union[str, bytes]) -> None: - if not isinstance(private_key_pem, bytes): - private_key_pem = private_key_pem.encode("ascii") - private_key = serialization.load_pem_private_key(data=private_key_pem, password=None, backend=default_backend()) - if not isinstance(private_key, RSAPrivateKey): - raise ValueError("Key must be RSA") + def __init__(self, private_key_pem: Union[str, bytes, RSAPrivateKey]) -> None: + if isinstance(private_key_pem, RSAPrivateKey): + rsa_private_key = private_key_pem + else: + if not isinstance(private_key_pem, bytes): + private_key_pem = private_key_pem.encode("ascii") + private_key = serialization.load_pem_private_key(data=private_key_pem, password=None, backend=default_backend()) + if not isinstance(private_key, RSAPrivateKey): + raise ValueError("Key must be RSA") + rsa_private_key = private_key super().__init__() - self.rsa_private_key = private_key + self.rsa_private_key = rsa_private_key self._key_size = None self._header_size = None @@ -513,12 +521,12 @@ def process_header(self, data: bytes) -> None: class DecryptorFile(BaseDecryptorFile): - def __init__(self, next_fp: FileLike, private_key_pem: Union[bytes, str]): + def __init__(self, next_fp: FileLike, private_key_pem: Union[bytes, str, RSAPrivateKey]): super().__init__(next_fp, lambda: Decryptor(private_key_pem)) class DecryptSink(BaseDecryptSink): - def __init__(self, next_sink: HasWrite, file_size: int, private_key_pem: Union[bytes, str]): + def __init__(self, next_sink: HasWrite, file_size: int, private_key_pem: Union[bytes, str, RSAPrivateKey]): super().__init__(next_sink, file_size, Decryptor(private_key_pem)) diff --git a/rohmu/rohmufile.py b/rohmu/rohmufile.py index 98a9e2de..f433fbd2 100644 --- a/rohmu/rohmufile.py +++ b/rohmu/rohmufile.py @@ -11,6 +11,7 @@ from .filewrap import ThrottleSink from .typing import FileLike, HasRead, HasWrite, Metadata from contextlib import suppress +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey from inspect import signature from rohmu.object_storage.base import IncrementalProgressCallbackType from typing import Any, Callable, Optional, Union @@ -27,8 +28,8 @@ def _obj_name(input_obj: Any) -> str: def _get_encryption_key_data( - metadata: Optional[Metadata], key_lookup: Optional[Callable[[str], Optional[str]]] -) -> Optional[str]: + metadata: Optional[Metadata], key_lookup: Optional[Callable[[str], Optional[str | bytes | RSAPrivateKey]]] +) -> Optional[str | bytes | RSAPrivateKey]: if not metadata or not metadata.get("encryption-key-id"): return None @@ -47,7 +48,7 @@ def file_reader( *, fileobj: FileLike, metadata: Optional[Metadata] = None, - key_lookup: Optional[Callable[[str], Optional[str]]] = None, + key_lookup: Optional[Callable[[str], Optional[str | bytes | RSAPrivateKey]]] = None, ) -> FileLike: if not metadata: return fileobj @@ -68,7 +69,7 @@ def create_sink_pipeline( output: HasWrite, file_size: int = 0, metadata: Optional[Metadata] = None, - key_lookup: Optional[Callable[[str], Optional[str]]] = None, + key_lookup: Optional[Callable[[str], Optional[str | bytes | RSAPrivateKey]]] = None, throttle_time: float = 0.001, ) -> HasWrite: if throttle_time: @@ -143,7 +144,7 @@ def file_writer( compression_algorithm: Optional[str] = None, compression_level: int = 0, compression_threads: int = 0, - rsa_public_key: Union[None, str, bytes] = None, + rsa_public_key: Union[None, str, bytes, RSAPublicKey] = None, ) -> FileLike: if rsa_public_key: fileobj = EncryptorFile(fileobj, rsa_public_key) @@ -162,7 +163,7 @@ def write_file( compression_algorithm: Optional[str] = None, compression_level: int = 0, compression_threads: int = 0, - rsa_public_key: Union[None, str, bytes] = None, + rsa_public_key: Union[None, str, bytes, RSAPublicKey] = None, log_func: Optional[Callable[..., None]] = None, header_func: Optional[Callable[[bytes], None]] = None, data_callback: Optional[Callable[[bytes], None]] = None, diff --git a/test/test_encryptor.py b/test/test_encryptor.py index ee11c787..75ae7646 100644 --- a/test/test_encryptor.py +++ b/test/test_encryptor.py @@ -3,6 +3,9 @@ from __future__ import annotations +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey +from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key from pathlib import Path from rohmu.common.constants import IO_BLOCK_SIZE from rohmu.encryptor import ( @@ -66,6 +69,12 @@ -----END PRIVATE KEY-----""" ) +LOADED_RSA_PUBLIC_KEY = load_pem_public_key(RSA_PUBLIC_KEY.encode(), backend=default_backend()) +assert isinstance(LOADED_RSA_PUBLIC_KEY, RSAPublicKey) + +LOADED_RSA_PRIVATE_KEY = load_pem_private_key(RSA_PRIVATE_KEY.encode(), password=None, backend=default_backend()) +assert isinstance(LOADED_RSA_PRIVATE_KEY, RSAPrivateKey) + @pytest.mark.parametrize( ("plaintext"), @@ -81,6 +90,14 @@ lambda: Encryptor(RSA_PUBLIC_KEY), lambda: Decryptor(RSA_PRIVATE_KEY), ), + ( + lambda: Encryptor(RSA_PUBLIC_KEY.encode()), + lambda: Decryptor(RSA_PRIVATE_KEY.encode()), + ), + ( + lambda: Encryptor(LOADED_RSA_PUBLIC_KEY), + lambda: Decryptor(LOADED_RSA_PRIVATE_KEY), + ), ( lambda: SymmetricEncryptor(SYMMETRIC_KEY), lambda: SymmetricDecryptor(SYMMETRIC_KEY), @@ -115,6 +132,14 @@ def test_encryptor_decryptor( lambda x: EncryptorStream(x, RSA_PUBLIC_KEY), lambda x: DecryptorFile(x, RSA_PRIVATE_KEY), ), + ( + lambda x: EncryptorStream(x, RSA_PUBLIC_KEY.encode()), + lambda x: DecryptorFile(x, RSA_PRIVATE_KEY), + ), + ( + lambda x: EncryptorStream(x, LOADED_RSA_PUBLIC_KEY), + lambda x: DecryptorFile(x, RSA_PRIVATE_KEY), + ), ( lambda x: SymmetricEncryptorStream(x, SYMMETRIC_KEY), lambda x: SymmetricDecryptorFile(x, SYMMETRIC_KEY), @@ -160,6 +185,14 @@ def test_encryptor_stream( lambda: Encryptor(RSA_PUBLIC_KEY), lambda x: DecryptorFile(x, RSA_PRIVATE_KEY), ), + ( + lambda: Encryptor(RSA_PUBLIC_KEY), + lambda x: DecryptorFile(x, RSA_PRIVATE_KEY.encode()), + ), + ( + lambda: Encryptor(RSA_PUBLIC_KEY), + lambda x: DecryptorFile(x, LOADED_RSA_PRIVATE_KEY), + ), ( lambda: SymmetricEncryptor(SYMMETRIC_KEY), lambda x: SymmetricDecryptorFile(x, SYMMETRIC_KEY), @@ -305,6 +338,14 @@ def test_decryptorfile_for_tarfile( lambda x: EncryptorFile(x, RSA_PUBLIC_KEY), lambda x: DecryptorFile(x, RSA_PRIVATE_KEY), ), + ( + lambda x: EncryptorFile(x, RSA_PUBLIC_KEY.encode()), + lambda x: DecryptorFile(x, RSA_PRIVATE_KEY), + ), + ( + lambda x: EncryptorFile(x, LOADED_RSA_PUBLIC_KEY), + lambda x: DecryptorFile(x, RSA_PRIVATE_KEY), + ), ( lambda x: SymmetricEncryptorFile(x, SYMMETRIC_KEY), lambda x: SymmetricDecryptorFile(x, SYMMETRIC_KEY),