diff --git a/tests/test_cli.py b/tests/test_cli.py
index 7620247f66..bfa8278008 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -302,3 +302,43 @@ def test_keep_extracted_chunks(
         process_file_mock.call_args.args[0].keep_extracted_chunks
         == keep_extracted_chunks
     ), fail_message
+
+
+@pytest.mark.parametrize(
+    "args, skip_extraction, fail_message",
+    [
+        ([], False, "Should *NOT* have skipped extraction"),
+        (["-s"], True, "Should have skipped extraction"),
+        (["--skip-extraction"], True, "Should have skipped extraction"),
+    ],
+)
+def test_skip_extraction(
+    args: List[str], skip_extraction: bool, fail_message: str, tmp_path: Path
+):
+    """Check that -s / --skip-extraction set skip_extraction for process_file.
+
+    process_file is replaced by a mock, so only the option handling of the
+    CLI is exercised. The first positional argument of the single mocked
+    call must carry the expected skip_extraction value.
+    """
+    runner = CliRunner()
+    in_path = (
+        Path(__file__).parent
+        / "integration"
+        / "archive"
+        / "zip"
+        / "regular"
+        / "__input__"
+        / "apple.zip"
+    )
+    params = [*args, "--extract-dir", str(tmp_path), str(in_path)]
+
+    process_file_mock = mock.MagicMock()
+    with mock.patch.object(unblob.cli, "process_file", process_file_mock):
+        result = runner.invoke(unblob.cli.cli, params)
+
+    assert result.exit_code == 0
+    process_file_mock.assert_called_once()
+    assert (
+        process_file_mock.call_args.args[0].skip_extraction == skip_extraction
+    ), fail_message
diff --git a/tests/test_processing.py b/tests/test_processing.py
index ef800e3da2..40058c86bf 100644
--- a/tests/test_processing.py
+++ b/tests/test_processing.py
@@ -447,6 +447,42 @@ def get_all(file_name, report_type: Type[ReportType]) -> List[ReportType]:
     )
 
 
+@pytest.mark.parametrize(
+    "skip_extraction, file_count, extracted_file_count",
+    [
+        (True, 5, 0),
+        (False, 5, 6),
+    ],
+)
+def test_skip_extraction(
+    skip_extraction: bool,
+    file_count: int,
+    extracted_file_count: int,
+    tmp_path: Path,
+    extraction_config: ExtractionConfig,
+):
+    """Check that skip_extraction controls whether anything gets extracted.
+
+    A zip archive with file_count members is processed. The number of
+    distinct task paths must be extracted_file_count + 1, and the number of
+    entries found below extract_root must be extracted_file_count.
+    """
+    input_file = tmp_path / "input"
+    with zipfile.ZipFile(input_file, "w") as zf:
+        for i in range(file_count):
+            zf.writestr(f"file{i}", data=b"This is a test file.")
+
+    extraction_config.extract_root = tmp_path / "output"
+    extraction_config.skip_extraction = skip_extraction
+
+    process_result = process_file(extraction_config, input_file)
+    task_result_by_path = {r.task.path: r for r in process_result.results}
+
+    assert len(task_result_by_path) == extracted_file_count + 1
+    # rglob() is recursive on its own, so "*" already lists every entry below the root
+    assert len(list(extraction_config.extract_root.rglob("*"))) == extracted_file_count
+
+
 class ConcatenateExtractor(DirectoryExtractor):
     def extract(self, paths: List[Path], outdir: Path):
         outfile = outdir / "data"