diff --git a/test/test_rohmu.py b/test/test_rohmu.py new file mode 100644 index 00000000..1aa0808b --- /dev/null +++ b/test/test_rohmu.py @@ -0,0 +1,163 @@ +import hashlib +import logging +import os +from tempfile import NamedTemporaryFile + +import pytest +from rohmu import get_transfer, rohmufile +from rohmu.rohmufile import create_sink_pipeline + +from .base import CONSTANT_TEST_RSA_PRIVATE_KEY, CONSTANT_TEST_RSA_PUBLIC_KEY + +log = logging.getLogger(__name__) + + +@pytest.mark.parametrize( + "compress_algorithm, file_size", + [("lzma", 0), ("snappy", 0), ("zstd", 0), ("lzma", 1), ("snappy", 1), ("zstd", 1)], + ids=[ + "test_lzma_0byte_file", "test_snappy_0byte_file", "test_zstd_0byte_file", "test_lzma_1byte_file", + "test_snappy_1byte_file", "test_zstd_1byte_file" + ], +) +def test_rohmu_with_local_storage(compress_algorithm: str, file_size: int, tmp_path): + hash_algorithm = "sha1" + compression_level = 0 + + # 0 - Prepare the file + work_dir = tmp_path + orig_file = work_dir / "hello.bin" + content = os.urandom(file_size) + with open(orig_file, "wb") as file_out: + file_out.write(content) + + with open(orig_file, "rb") as file_in: + assert file_in.read() == content + + # 1 - Compressed the file + compressed_filepath = work_dir / "compressed" / "hello_compressed" + compressed_filepath.parent.mkdir(exist_ok=True) + hasher = hashlib.new(hash_algorithm) + input_obj = open(orig_file, "rb") + output_obj = NamedTemporaryFile( + dir=os.path.dirname(compressed_filepath), prefix=os.path.basename(compressed_filepath), suffix=".tmp-compress" + ) + with input_obj, output_obj: + original_file_size, compressed_file_size = rohmufile.write_file( + data_callback=hasher.update, + input_obj=input_obj, + output_obj=output_obj, + compression_algorithm=compress_algorithm, + compression_level=compression_level, + rsa_public_key=CONSTANT_TEST_RSA_PUBLIC_KEY, + log_func=log.debug, + ) + os.link(output_obj.name, compressed_filepath) + + log.info("original_file_size: %s, compressed_file_size: %s", original_file_size, compressed_file_size) + assert original_file_size == len(content) + file_hash = hasher.hexdigest() + log.info("original_file_hash: %s", file_hash) + + # 2 - Upload the compressed file + upload_dir = work_dir / "uploaded" + upload_dir.mkdir() + storage_config = { + "directory": str(upload_dir), + "storage_type": "local", + } + metadata = { + "encryption-key-id": "No matter", + "compression-algorithm": compress_algorithm, + "compression-level": compression_level, + } + storage = get_transfer(storage_config) + + metadata_copy = metadata.copy() + metadata_copy["Content-Length"] = str(compressed_file_size) + file_key = "compressed/hello_compressed" + + def upload_progress_callback(n_bytes: int) -> None: + log.debug("File: '%s', uploaded %d bytes", file_key, n_bytes) + + with open(compressed_filepath, "rb") as f: + storage.store_file_object(file_key, f, metadata=metadata_copy, upload_progress_fn=upload_progress_callback) + + # 3 - Decrypt and decompress + # 3.1 Use file downloading rohmu API + decompressed_filepath = work_dir / "hello_decompressed_1" + + decompressed_size = _download_and_decompress_with_file(storage, str(decompressed_filepath), file_key, metadata) + assert len(content) == decompressed_size + # Compare content + with open(decompressed_filepath, "rb") as file_in: + content_decrypted = file_in.read() + hasher = hashlib.new(hash_algorithm) + hasher.update(content_decrypted) + assert hasher.hexdigest() == file_hash + assert content_decrypted == content + + # 3.2 Use rohmu SinkIO API + decompressed_filepath = work_dir / "hello_decompressed_2" + decompressed_size = _download_and_decompress_with_sink(storage, str(decompressed_filepath), file_key, metadata) + assert len(content) == decompressed_size + + # Compare content + hasher.hexdigest() + with open(decompressed_filepath, "rb") as file_in: + content_decrypted = file_in.read() + hasher = hashlib.new(hash_algorithm) + hasher.update(content_decrypted) + assert hasher.hexdigest() == file_hash + assert content_decrypted == content + + if file_size == 0: + empty_file_sha1 = "da39a3ee5e6b4b0d3255bfef95601890afd80709" + assert empty_file_sha1 == hasher.hexdigest() + + +def _key_lookup(key_id: str): # pylint: disable=unused-argument + return CONSTANT_TEST_RSA_PRIVATE_KEY + + +def _download_and_decompress_with_sink(storage, output_path: str, file_key: str, metadata: dict): + data, _ = storage.get_contents_to_string(file_key) + if isinstance(data, str): + data = data.encode("latin1") + file_size = len(data) + + with open(output_path, "wb") as target_file: + output = create_sink_pipeline( + output=target_file, file_size=file_size, metadata=metadata, key_lookup=_key_lookup, throttle_time=0 + ) + output.write(data) + decompressed_size = os.path.getsize(output_path) + return decompressed_size + + +def _download_and_decompress_with_file(storage, output_path: str, file_key: str, metadata: dict): + # Download the compressed file + file_download_path = output_path + ".tmp" + + def download_progress_callback(bytes_written: int, input_size: int) -> None: + log.debug("File: '%s', downloaded %d of %d bytes", file_key, bytes_written, input_size) + + with open(file_download_path, "wb") as f: + storage.get_contents_to_fileobj(file_key, f, progress_callback=download_progress_callback) + + # Decrypt and decompress + input_obj = open(file_download_path, "rb") + output_obj = open(output_path, "wb") + with input_obj, output_obj: + _, decompressed_size = rohmufile.read_file( + input_obj=input_obj, + output_obj=output_obj, + metadata=metadata, + key_lookup=_key_lookup, + log_func=log.debug, + ) + output_obj.flush() + + # Delete temporary file + os.unlink(file_download_path) + return decompressed_size