diff --git a/tests/test_cli.py b/tests/test_cli.py index aea936ec2b..c7a5d5299e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -17,6 +17,7 @@ DEFAULT_SKIP_MAGIC, ExtractionConfig, ) +from unblob.testing import is_sandbox_available from unblob.ui import ( NullProgressReporter, ProgressReporter, @@ -425,3 +426,29 @@ def test_clear_skip_magics( assert sorted(process_file_mock.call_args.args[0].skip_magic) == sorted( skip_magic ), fail_message + + +@pytest.mark.skipif( + not is_sandbox_available(), reason="Sandboxing is only available on Linux" +) +def test_sandbox_escape(tmp_path: Path): + runner = CliRunner() + + in_path = tmp_path / "input" + in_path.touch() + extract_dir = tmp_path / "extract-dir" + params = ["--extract-dir", str(extract_dir), str(in_path)] + + unrelated_file = tmp_path / "unrelated" + + process_file_mock = mock.MagicMock( + side_effect=lambda *_args, **_kwargs: unrelated_file.write_text( + "sandbox escape" + ) + ) + with mock.patch.object(unblob.cli, "process_file", process_file_mock): + result = runner.invoke(unblob.cli.cli, params) + + assert result.exit_code != 0 + assert isinstance(result.exception, PermissionError) + process_file_mock.assert_called_once() diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py new file mode 100644 index 0000000000..ee0feff16a --- /dev/null +++ b/tests/test_sandbox.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import pytest + +from unblob.processing import ExtractionConfig +from unblob.sandbox import Sandbox +from unblob.testing import is_sandbox_available + +pytestmark = pytest.mark.skipif( + not is_sandbox_available(), reason="Sandboxing only works on Linux" +) + + +@pytest.fixture +def log_path(tmp_path): + return tmp_path / "unblob.log" + + +@pytest.fixture +def extraction_config(extraction_config, tmp_path): + extraction_config.extract_root = tmp_path / "extract" / "root" + # parent has to exist + extraction_config.extract_root.parent.mkdir() + return extraction_config + + +@pytest.fixture +def sandbox(extraction_config: ExtractionConfig, log_path: Path): + return Sandbox(extraction_config, log_path, None) + + +def test_necessary_resources_can_be_created_in_sandbox( + sandbox: Sandbox, extraction_config: ExtractionConfig, log_path: Path +): + directory_in_extract_root = extraction_config.extract_root / "path" / "to" / "dir" + file_in_extract_root = directory_in_extract_root / "file" + + sandbox.run(extraction_config.extract_root.mkdir, parents=True) + sandbox.run(directory_in_extract_root.mkdir, parents=True) + + sandbox.run(file_in_extract_root.touch) + sandbox.run(file_in_extract_root.write_text, "file content") + + # log-file is already opened + log_path.touch() + sandbox.run(log_path.write_text, "log line") + + +def test_access_outside_sandbox_is_not_possible(sandbox: Sandbox, tmp_path: Path): + unrelated_dir = tmp_path / "unrelated" / "path" + unrelated_file = tmp_path / "unrelated-file" + + with pytest.raises(PermissionError): + sandbox.run(unrelated_dir.mkdir, parents=True) + + with pytest.raises(PermissionError): + sandbox.run(unrelated_file.touch) diff --git a/unblob/cli.py b/unblob/cli.py index fb48356327..93c6b206b8 100755 --- a/unblob/cli.py +++ b/unblob/cli.py @@ -33,6 +33,7 @@ ExtractionConfig, process_file, ) +from .sandbox import Sandbox from .ui import NullProgressReporter, RichConsoleProgressReporter logger = get_logger() @@ -321,7 +322,8 @@ def cli( ) logger.info("Start processing file", file=file) - process_results = process_file(config, file, report_file) + sandbox = Sandbox(config, log_path, report_file) + process_results = sandbox.run(process_file, config, file, report_file) if verbose == 0: if skip_extraction: print_scan_report(process_results) diff --git a/unblob/pool.py b/unblob/pool.py index eeee58ca7a..4b06ea3e85 100644 --- a/unblob/pool.py +++ b/unblob/pool.py @@ -203,10 +203,9 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP def _on_terminate(signum, frame): - with contextlib.suppress(StopIteration): - while True: - pool = next(iter(pools)) - pool.close(immediate=True) + pools_snapshot = list(pools) + for pool in pools_snapshot: + pool.close(immediate=True) if callable(orig_signal_handlers[signum]): orig_signal_handlers[signum](signum, frame) diff --git a/unblob/processing.py b/unblob/processing.py index a11fc2651d..c4a0a1391b 100644 --- a/unblob/processing.py +++ b/unblob/processing.py @@ -117,7 +117,6 @@ def get_carve_dir_for(self, path: Path) -> Path: return self._get_output_path(path.with_name(path.name + self.carve_suffix)) -@terminate_gracefully def process_file( config: ExtractionConfig, input_path: Path, report_file: Optional[Path] = None ) -> ProcessResult: diff --git a/unblob/sandbox.py b/unblob/sandbox.py new file mode 100644 index 0000000000..0bb4f52a7a --- /dev/null +++ b/unblob/sandbox.py @@ -0,0 +1,118 @@ +import ctypes +import sys +import threading +from pathlib import Path +from typing import Callable, Iterable, Optional, Type, TypeVar + +from structlog import get_logger +from unblob_native.sandbox import ( + AccessFS, + SandboxError, + restrict_access, +) + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +from unblob.processing import ExtractionConfig + +logger = get_logger() + +P = ParamSpec("P") +R = TypeVar("R") + + +class Sandbox: + """Configures restricted file-systems to run functions in. + + When calling ``run()``, a separate thread will be configured with + minimum required file-system permissions. All subprocesses spawned + from that thread will honor the restrictions. + """ + + def __init__( + self, + config: ExtractionConfig, + log_path: Path, + report_file: Optional[Path], + extra_passthrough: Iterable[AccessFS] = (), + ): + self.passthrough = [ + # Python, shared libraries, extractor binaries and so on + AccessFS.read("/"), + # Multiprocessing + AccessFS.read_write("/dev/shm"), # noqa: S108 + # Extracted contents + AccessFS.read_write(config.extract_root), + AccessFS.make_dir(config.extract_root.parent), + AccessFS.read_write(log_path), + *extra_passthrough, + ] + + if report_file: + self.passthrough += [ + AccessFS.read_write(report_file), + AccessFS.make_reg(report_file.parent), + ] + + def run(self, callback: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + """Run callback with restricted filesystem access.""" + exception = None + result = None + + def _run_in_thread(callback, *args, **kwargs): + nonlocal exception, result + + self._try_enter_sandbox() + try: + result = callback(*args, **kwargs) + except BaseException as e: + exception = e + + thread = threading.Thread( + target=_run_in_thread, args=(callback, *args), kwargs=kwargs + ) + thread.start() + + try: + thread.join() + except KeyboardInterrupt: + raise_in_thread(thread, KeyboardInterrupt) + thread.join() + + if exception: + raise exception # pyright: ignore[reportGeneralTypeIssues] + return result # pyright: ignore[reportReturnType] + + def _try_enter_sandbox(self): + try: + restrict_access(*self.passthrough) + except SandboxError: + logger.warning( + "Sandboxing FS access is unavailable on this system, skipping." + ) + + +def raise_in_thread(thread: threading.Thread, exctype: Type) -> None: + if thread.ident is None: + raise RuntimeError("Thread is not started") + + res = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_ulong(thread.ident), ctypes.py_object(exctype) + ) + + # success + if res == 1: + return + + # Need to revert the call to restore interpreter state + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(thread.ident), None) + + # Thread could have exited since + if res == 0: + return + + # Something bad have happened + raise RuntimeError("Could not raise exception in thread", thread.ident) diff --git a/unblob/testing.py b/unblob/testing.py index 86030c5463..6301ff9651 100644 --- a/unblob/testing.py +++ b/unblob/testing.py @@ -1,6 +1,7 @@ import binascii import glob import io +import platform import shlex import subprocess from pathlib import Path @@ -10,6 +11,7 @@ from lark.lark import Lark from lark.visitors import Discard, Transformer from pytest_cov.embed import cleanup_on_sigterm +from unblob_native.sandbox import AccessFS, SandboxError, restrict_access from unblob.finder import build_hyperscan_database from unblob.logging import configure_logger @@ -217,3 +219,17 @@ def start(self, s): rv.write(line.data) return rv.getvalue() + + +def is_sandbox_available(): + is_sandbox_available = True + + try: + restrict_access(AccessFS.read_write("/")) + except SandboxError: + is_sandbox_available = False + + if platform.architecture == "x86_64" and platform.system == "linux": + assert is_sandbox_available, "Sandboxing should work at least on Linux-x86_64" + + return is_sandbox_available