From 58f9f25526688bdf97e9a578d52379608391541a Mon Sep 17 00:00:00 2001 From: Quentin Kaiser Date: Tue, 2 Jan 2024 18:34:10 +0100 Subject: [PATCH] tests(processing,cli): introduce tests for skip-extraction option. --- tests/test_cli.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_processing.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5857a9e187..e38c58486b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -333,3 +333,37 @@ def test_skip_extension( result = runner.invoke(unblob.cli.cli, params) assert extracted_files_count == len(list(tmp_path.rglob("*"))) assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "args, skip_extraction, fail_message", + [ + ([], False, "Should *NOT* have skipped extraction"), + (["-s"], True, "Should have skipped extraction"), + (["--skip-extraction"], True, "Should have skipped extraction"), + ], +) +def test_skip_extraction( + args: List[str], skip_extraction: bool, fail_message: str, tmp_path: Path +): + runner = CliRunner() + in_path = ( + Path(__file__).parent + / "integration" + / "archive" + / "zip" + / "regular" + / "__input__" + / "apple.zip" + ) + params = [*args, "--extract-dir", str(tmp_path), str(in_path)] + + process_file_mock = mock.MagicMock() + with mock.patch.object(unblob.cli, "process_file", process_file_mock): + result = runner.invoke(unblob.cli.cli, params) + + assert result.exit_code == 0 + process_file_mock.assert_called_once() + assert ( + process_file_mock.call_args.args[0].skip_extraction == skip_extraction + ), fail_message diff --git a/tests/test_processing.py b/tests/test_processing.py index ef800e3da2..40058c86bf 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -447,6 +447,37 @@ def get_all(file_name, report_type: Type[ReportType]) -> List[ReportType]: ) +@pytest.mark.parametrize( + "skip_extraction, file_count, extracted_file_count", + [ + (True, 5, 0), + (False, 5, 6), + ], +) +def test_skip_extraction( + skip_extraction: bool, + file_count: int, + extracted_file_count: int, + tmp_path: Path, + extraction_config: ExtractionConfig, +): + input_file = tmp_path / "input" + with zipfile.ZipFile(input_file, "w") as zf: + for i in range(file_count): + zf.writestr(f"file{i}", data=b"This is a test file.") + + extraction_config.extract_root = tmp_path / "output" + extraction_config.skip_extraction = skip_extraction + + process_result = process_file(extraction_config, input_file) + task_result_by_path = {r.task.path: r for r in process_result.results} + + assert len(task_result_by_path) == extracted_file_count + 1 + assert ( + len(list(extraction_config.extract_root.rglob("**/*"))) == extracted_file_count + ) + + class ConcatenateExtractor(DirectoryExtractor): def extract(self, paths: List[Path], outdir: Path): outfile = outdir / "data"