diff --git a/tests/test_cli.py b/tests/test_cli.py index 44f0ebdf79..c7a5d5299e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Iterable, List, Optional, Type +from typing import Iterable, List, Optional, Tuple, Type from unittest import mock import pytest @@ -13,9 +13,11 @@ from unblob.processing import ( DEFAULT_DEPTH, DEFAULT_PROCESS_NUM, + DEFAULT_SKIP_EXTENSION, DEFAULT_SKIP_MAGIC, ExtractionConfig, ) +from unblob.testing import is_sandbox_available from unblob.ui import ( NullProgressReporter, ProgressReporter, @@ -310,16 +312,16 @@ def test_keep_extracted_chunks( @pytest.mark.parametrize( - "skip_extension, extracted_files_count", + "skip_extension, expected_skip_extensions", [ - pytest.param([], 5, id="skip-extension-empty"), - pytest.param([""], 5, id="skip-zip-extension-empty-suffix"), - pytest.param([".zip"], 0, id="skip-extension-zip"), - pytest.param([".rlib"], 5, id="skip-extension-rlib"), + pytest.param((), DEFAULT_SKIP_EXTENSION, id="skip-extension-empty"), + pytest.param(("",), ("",), id="skip-zip-extension-empty-suffix"), + pytest.param((".zip",), (".zip",), id="skip-extension-zip"), + pytest.param((".rlib",), (".rlib",), id="skip-extension-rlib"), ], ) def test_skip_extension( - skip_extension: List[str], extracted_files_count: int, tmp_path: Path + skip_extension: List[str], expected_skip_extensions: Tuple[str, ...], tmp_path: Path ): runner = CliRunner() in_path = ( @@ -335,8 +337,12 @@ def test_skip_extension( for suffix in skip_extension: args += ["--skip-extension", suffix] params = [*args, "--extract-dir", str(tmp_path), str(in_path)] - result = runner.invoke(unblob.cli.cli, params) - assert extracted_files_count == len(list(tmp_path.rglob("*"))) + process_file_mock = mock.MagicMock() + with mock.patch.object(unblob.cli, "process_file", process_file_mock): + result = runner.invoke(unblob.cli.cli, params) + assert ( + process_file_mock.call_args.args[0].skip_extension == expected_skip_extensions + ) assert result.exit_code == 0 @@ -420,3 +426,29 @@ def test_clear_skip_magics( assert sorted(process_file_mock.call_args.args[0].skip_magic) == sorted( skip_magic ), fail_message + + +@pytest.mark.skipif( + not is_sandbox_available(), reason="Sandboxing is only available on Linux" +) +def test_sandbox_escape(tmp_path: Path): + runner = CliRunner() + + in_path = tmp_path / "input" + in_path.touch() + extract_dir = tmp_path / "extract-dir" + params = ["--extract-dir", str(extract_dir), str(in_path)] + + unrelated_file = tmp_path / "unrelated" + + process_file_mock = mock.MagicMock( + side_effect=lambda *_args, **_kwargs: unrelated_file.write_text( + "sandbox escape" + ) + ) + with mock.patch.object(unblob.cli, "process_file", process_file_mock): + result = runner.invoke(unblob.cli.cli, params) + + assert result.exit_code != 0 + assert isinstance(result.exception, PermissionError) + process_file_mock.assert_called_once() diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py new file mode 100644 index 0000000000..ee0feff16a --- /dev/null +++ b/tests/test_sandbox.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import pytest + +from unblob.processing import ExtractionConfig +from unblob.sandbox import Sandbox +from unblob.testing import is_sandbox_available + +pytestmark = pytest.mark.skipif( + not is_sandbox_available(), reason="Sandboxing only works on Linux" +) + + +@pytest.fixture +def log_path(tmp_path): + return tmp_path / "unblob.log" + + +@pytest.fixture +def extraction_config(extraction_config, tmp_path): + extraction_config.extract_root = tmp_path / "extract" / "root" + # parent has to exist + extraction_config.extract_root.parent.mkdir() + return extraction_config + + +@pytest.fixture +def sandbox(extraction_config: ExtractionConfig, log_path: Path): + return Sandbox(extraction_config, log_path, None) + + +def test_necessary_resources_can_be_created_in_sandbox( + sandbox: Sandbox, extraction_config: ExtractionConfig, log_path: Path +): + directory_in_extract_root = extraction_config.extract_root / "path" / "to" / "dir" + file_in_extract_root = directory_in_extract_root / "file" + + sandbox.run(extraction_config.extract_root.mkdir, parents=True) + sandbox.run(directory_in_extract_root.mkdir, parents=True) + + sandbox.run(file_in_extract_root.touch) + sandbox.run(file_in_extract_root.write_text, "file content") + + # log-file is already opened + log_path.touch() + sandbox.run(log_path.write_text, "log line") + + +def test_access_outside_sandbox_is_not_possible(sandbox: Sandbox, tmp_path: Path): + unrelated_dir = tmp_path / "unrelated" / "path" + unrelated_file = tmp_path / "unrelated-file" + + with pytest.raises(PermissionError): + sandbox.run(unrelated_dir.mkdir, parents=True) + + with pytest.raises(PermissionError): + sandbox.run(unrelated_file.touch) diff --git a/unblob/cli.py b/unblob/cli.py index fb48356327..93c6b206b8 100755 --- a/unblob/cli.py +++ b/unblob/cli.py @@ -33,6 +33,7 @@ ExtractionConfig, process_file, ) +from .sandbox import Sandbox from .ui import NullProgressReporter, RichConsoleProgressReporter logger = get_logger() @@ -321,7 +322,8 @@ def cli( ) logger.info("Start processing file", file=file) - process_results = process_file(config, file, report_file) + sandbox = Sandbox(config, log_path, report_file) + process_results = sandbox.run(process_file, config, file, report_file) if verbose == 0: if skip_extraction: print_scan_report(process_results) diff --git a/unblob/pool.py b/unblob/pool.py index 810011a209..4b06ea3e85 100644 --- a/unblob/pool.py +++ b/unblob/pool.py @@ -1,11 +1,13 @@ import abc +import contextlib import multiprocessing as mp import os import queue +import signal import sys import threading from multiprocessing.queues import JoinableQueue -from typing import Any, Callable, Union +from typing import Any, Callable, Set, Union from .logging import multiprocessing_breakpoint @@ -13,6 +15,10 @@ class PoolBase(abc.ABC): + def __init__(self): + with pools_lock: + pools.add(self) + @abc.abstractmethod def submit(self, args): pass @@ -24,15 +30,20 @@ def process_until_done(self): def start(self): pass - def close(self): - pass + def close(self, *, immediate=False): # noqa: ARG002 + with pools_lock: + pools.remove(self) def __enter__(self): self.start() return self - def __exit__(self, *args): - self.close() + def __exit__(self, exc_type, _exc_value, _tb): + self.close(immediate=exc_type is not None) + + +pools_lock = threading.Lock() +pools: Set[PoolBase] = set() class Queue(JoinableQueue): @@ -53,9 +64,15 @@ class _Sentinel: def _worker_process(handler, input_, output): - # Creates a new process group, making sure no signals are propagated from the main process to the worker processes. + # Creates a new process group, making sure no signals are + # propagated from the main process to the worker processes. os.setpgrp() + # Restore default signal handlers, otherwise workers would inherit + # them from main process + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGINT, signal.SIG_DFL) + sys.breakpointhook = multiprocessing_breakpoint while (args := input_.get()) is not _SENTINEL: result = handler(args) @@ -71,11 +88,14 @@ def __init__( *, result_callback: Callable[["MultiPool", Any], Any], ): + super().__init__() if process_num <= 0: raise ValueError("At process_num must be greater than 0") + self._running = False self._result_callback = result_callback self._input = Queue(ctx=mp.get_context()) + self._input.cancel_join_thread() self._output = mp.SimpleQueue() self._procs = [ mp.Process( @@ -87,14 +107,32 @@ def __init__( self._tid = threading.get_native_id() def start(self): + self._running = True for p in self._procs: p.start() - def close(self): - self._clear_input_queue() - self._request_workers_to_quit() - self._clear_output_queue() + def close(self, *, immediate=False): + if not self._running: + return + self._running = False + + if immediate: + self._terminate_workers() + else: + self._clear_input_queue() + self._request_workers_to_quit() + self._clear_output_queue() + self._wait_for_workers_to_quit() + super().close(immediate=immediate) + + def _terminate_workers(self): + for proc in self._procs: + proc.terminate() + + self._input.close() + if sys.version_info >= (3, 9): + self._output.close() def _clear_input_queue(self): try: @@ -129,14 +167,16 @@ def submit(self, args): self._input.put(args) def process_until_done(self): - while not self._input.is_empty(): - result = self._output.get() - self._result_callback(self, result) - self._input.task_done() + with contextlib.suppress(EOFError): + while not self._input.is_empty(): + result = self._output.get() + self._result_callback(self, result) + self._input.task_done() class SinglePool(PoolBase): def __init__(self, handler, *, result_callback): + super().__init__() self._handler = handler self._result_callback = result_callback @@ -157,3 +197,19 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP handler=handler, result_callback=result_callback, ) + + +orig_signal_handlers = {} + + +def _on_terminate(signum, frame): + pools_snapshot = list(pools) + for pool in pools_snapshot: + pool.close(immediate=True) + + if callable(orig_signal_handlers[signum]): + orig_signal_handlers[signum](signum, frame) + + +orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate) +orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate) diff --git a/unblob/processing.py b/unblob/processing.py index b759b55159..c4a0a1391b 100644 --- a/unblob/processing.py +++ b/unblob/processing.py @@ -45,7 +45,6 @@ StatReport, UnknownError, ) -from .signals import terminate_gracefully from .ui import NullProgressReporter, ProgressReporter logger = get_logger() @@ -118,7 +117,6 @@ def get_carve_dir_for(self, path: Path) -> Path: return self._get_output_path(path.with_name(path.name + self.carve_suffix)) -@terminate_gracefully def process_file( config: ExtractionConfig, input_path: Path, report_file: Optional[Path] = None ) -> ProcessResult: diff --git a/unblob/sandbox.py b/unblob/sandbox.py new file mode 100644 index 0000000000..0bb4f52a7a --- /dev/null +++ b/unblob/sandbox.py @@ -0,0 +1,118 @@ +import ctypes +import sys +import threading +from pathlib import Path +from typing import Callable, Iterable, Optional, Type, TypeVar + +from structlog import get_logger +from unblob_native.sandbox import ( + AccessFS, + SandboxError, + restrict_access, +) + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +from unblob.processing import ExtractionConfig + +logger = get_logger() + +P = ParamSpec("P") +R = TypeVar("R") + + +class Sandbox: + """Configures restricted file-systems to run functions in. + + When calling ``run()``, a separate thread will be configured with + minimum required file-system permissions. All subprocesses spawned + from that thread will honor the restrictions. + """ + + def __init__( + self, + config: ExtractionConfig, + log_path: Path, + report_file: Optional[Path], + extra_passthrough: Iterable[AccessFS] = (), + ): + self.passthrough = [ + # Python, shared libraries, extractor binaries and so on + AccessFS.read("/"), + # Multiprocessing + AccessFS.read_write("/dev/shm"), # noqa: S108 + # Extracted contents + AccessFS.read_write(config.extract_root), + AccessFS.make_dir(config.extract_root.parent), + AccessFS.read_write(log_path), + *extra_passthrough, + ] + + if report_file: + self.passthrough += [ + AccessFS.read_write(report_file), + AccessFS.make_reg(report_file.parent), + ] + + def run(self, callback: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + """Run callback with restricted filesystem access.""" + exception = None + result = None + + def _run_in_thread(callback, *args, **kwargs): + nonlocal exception, result + + self._try_enter_sandbox() + try: + result = callback(*args, **kwargs) + except BaseException as e: + exception = e + + thread = threading.Thread( + target=_run_in_thread, args=(callback, *args), kwargs=kwargs + ) + thread.start() + + try: + thread.join() + except KeyboardInterrupt: + raise_in_thread(thread, KeyboardInterrupt) + thread.join() + + if exception: + raise exception # pyright: ignore[reportGeneralTypeIssues] + return result # pyright: ignore[reportReturnType] + + def _try_enter_sandbox(self): + try: + restrict_access(*self.passthrough) + except SandboxError: + logger.warning( + "Sandboxing FS access is unavailable on this system, skipping." + ) + + +def raise_in_thread(thread: threading.Thread, exctype: Type) -> None: + if thread.ident is None: + raise RuntimeError("Thread is not started") + + res = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_ulong(thread.ident), ctypes.py_object(exctype) + ) + + # success + if res == 1: + return + + # Need to revert the call to restore interpreter state + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(thread.ident), None) + + # Thread could have exited since + if res == 0: + return + + # Something bad have happened + raise RuntimeError("Could not raise exception in thread", thread.ident) diff --git a/unblob/signals.py b/unblob/signals.py deleted file mode 100644 index 76b70a4dbe..0000000000 --- a/unblob/signals.py +++ /dev/null @@ -1,51 +0,0 @@ -import functools -import signal - -from structlog import get_logger - -logger = get_logger() - - -class ShutDownRequired(BaseException): - def __init__(self, signal: str): - super().__init__() - self.signal = signal - - -def terminate_gracefully(func): - @functools.wraps(func) - def decorator(*args, **kwargs): - signals_fired = [] - - def _handle_signal(signum: int, frame): - nonlocal signals_fired - signals_fired.append((signum, frame)) - raise ShutDownRequired(signal=signal.Signals(signum).name) - - original_signal_handlers = { - signal.SIGINT: signal.signal(signal.SIGINT, _handle_signal), - signal.SIGTERM: signal.signal(signal.SIGTERM, _handle_signal), - } - - logger.debug( - "Setting up signal handlers", - original_signal_handlers=original_signal_handlers, - _verbosity=2, - ) - - try: - return func(*args, **kwargs) - except ShutDownRequired as exc: - logger.warning("Shutting down", signal=exc.signal) - finally: - # Set back the original signal handlers - for sig, handler in original_signal_handlers.items(): - signal.signal(sig, handler) - - # Call the original signal handler with the fired and catched signal(s) - for sig, frame in signals_fired: - handler = original_signal_handlers.get(sig) - if callable(handler): - handler(sig, frame) - - return decorator diff --git a/unblob/testing.py b/unblob/testing.py index 86030c5463..6301ff9651 100644 --- a/unblob/testing.py +++ b/unblob/testing.py @@ -1,6 +1,7 @@ import binascii import glob import io +import platform import shlex import subprocess from pathlib import Path @@ -10,6 +11,7 @@ from lark.lark import Lark from lark.visitors import Discard, Transformer from pytest_cov.embed import cleanup_on_sigterm +from unblob_native.sandbox import AccessFS, SandboxError, restrict_access from unblob.finder import build_hyperscan_database from unblob.logging import configure_logger @@ -217,3 +219,17 @@ def start(self, s): rv.write(line.data) return rv.getvalue() + + +def is_sandbox_available(): + is_sandbox_available = True + + try: + restrict_access(AccessFS.read_write("/")) + except SandboxError: + is_sandbox_available = False + + if platform.architecture == "x86_64" and platform.system == "linux": + assert is_sandbox_available, "Sandboxing should work at least on Linux-x86_64" + + return is_sandbox_available