Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental Landlock based sandboxing #597

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Iterable, List, Optional, Type
from typing import Iterable, List, Optional, Tuple, Type
from unittest import mock

import pytest
Expand All @@ -13,9 +13,11 @@
from unblob.processing import (
DEFAULT_DEPTH,
DEFAULT_PROCESS_NUM,
DEFAULT_SKIP_EXTENSION,
DEFAULT_SKIP_MAGIC,
ExtractionConfig,
)
from unblob.testing import is_sandbox_available
from unblob.ui import (
NullProgressReporter,
ProgressReporter,
Expand Down Expand Up @@ -310,16 +312,16 @@ def test_keep_extracted_chunks(


@pytest.mark.parametrize(
"skip_extension, extracted_files_count",
"skip_extension, expected_skip_extensions",
[
pytest.param([], 5, id="skip-extension-empty"),
pytest.param([""], 5, id="skip-zip-extension-empty-suffix"),
pytest.param([".zip"], 0, id="skip-extension-zip"),
pytest.param([".rlib"], 5, id="skip-extension-rlib"),
pytest.param((), DEFAULT_SKIP_EXTENSION, id="skip-extension-empty"),
pytest.param(("",), ("",), id="skip-zip-extension-empty-suffix"),
pytest.param((".zip",), (".zip",), id="skip-extension-zip"),
pytest.param((".rlib",), (".rlib",), id="skip-extension-rlib"),
],
)
def test_skip_extension(
skip_extension: List[str], extracted_files_count: int, tmp_path: Path
skip_extension: List[str], expected_skip_extensions: Tuple[str, ...], tmp_path: Path
):
runner = CliRunner()
in_path = (
Expand All @@ -335,8 +337,12 @@ def test_skip_extension(
for suffix in skip_extension:
args += ["--skip-extension", suffix]
params = [*args, "--extract-dir", str(tmp_path), str(in_path)]
result = runner.invoke(unblob.cli.cli, params)
assert extracted_files_count == len(list(tmp_path.rglob("*")))
process_file_mock = mock.MagicMock()
with mock.patch.object(unblob.cli, "process_file", process_file_mock):
result = runner.invoke(unblob.cli.cli, params)
assert (
process_file_mock.call_args.args[0].skip_extension == expected_skip_extensions
)
assert result.exit_code == 0


Expand Down Expand Up @@ -420,3 +426,29 @@ def test_clear_skip_magics(
assert sorted(process_file_mock.call_args.args[0].skip_magic) == sorted(
skip_magic
), fail_message


@pytest.mark.skipif(
not is_sandbox_available(), reason="Sandboxing is only available on Linux"
)
def test_sandbox_escape(tmp_path: Path):
runner = CliRunner()

in_path = tmp_path / "input"
in_path.touch()
extract_dir = tmp_path / "extract-dir"
params = ["--extract-dir", str(extract_dir), str(in_path)]

unrelated_file = tmp_path / "unrelated"

process_file_mock = mock.MagicMock(
side_effect=lambda *_args, **_kwargs: unrelated_file.write_text(
"sandbox escape"
)
)
with mock.patch.object(unblob.cli, "process_file", process_file_mock):
result = runner.invoke(unblob.cli.cli, params)

assert result.exit_code != 0
assert isinstance(result.exception, PermissionError)
process_file_mock.assert_called_once()
e3krisztian marked this conversation as resolved.
Show resolved Hide resolved
57 changes: 57 additions & 0 deletions tests/test_sandbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from pathlib import Path

import pytest

from unblob.processing import ExtractionConfig
from unblob.sandbox import Sandbox
from unblob.testing import is_sandbox_available

pytestmark = pytest.mark.skipif(
not is_sandbox_available(), reason="Sandboxing only works on Linux"
)


@pytest.fixture
def log_path(tmp_path):
return tmp_path / "unblob.log"


@pytest.fixture
def extraction_config(extraction_config, tmp_path):
extraction_config.extract_root = tmp_path / "extract" / "root"
# parent has to exist
extraction_config.extract_root.parent.mkdir()
return extraction_config


@pytest.fixture
def sandbox(extraction_config: ExtractionConfig, log_path: Path):
return Sandbox(extraction_config, log_path, None)


def test_necessary_resources_can_be_created_in_sandbox(
sandbox: Sandbox, extraction_config: ExtractionConfig, log_path: Path
):
directory_in_extract_root = extraction_config.extract_root / "path" / "to" / "dir"
file_in_extract_root = directory_in_extract_root / "file"

sandbox.run(extraction_config.extract_root.mkdir, parents=True)
sandbox.run(directory_in_extract_root.mkdir, parents=True)

sandbox.run(file_in_extract_root.touch)
sandbox.run(file_in_extract_root.write_text, "file content")

# log-file is already opened
log_path.touch()
sandbox.run(log_path.write_text, "log line")
e3krisztian marked this conversation as resolved.
Show resolved Hide resolved


def test_access_outside_sandbox_is_not_possible(sandbox: Sandbox, tmp_path: Path):
unrelated_dir = tmp_path / "unrelated" / "path"
unrelated_file = tmp_path / "unrelated-file"

with pytest.raises(PermissionError):
sandbox.run(unrelated_dir.mkdir, parents=True)

with pytest.raises(PermissionError):
sandbox.run(unrelated_file.touch)
e3krisztian marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 3 additions & 1 deletion unblob/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ExtractionConfig,
process_file,
)
from .sandbox import Sandbox
from .ui import NullProgressReporter, RichConsoleProgressReporter

logger = get_logger()
Expand Down Expand Up @@ -321,7 +322,8 @@ def cli(
)

logger.info("Start processing file", file=file)
process_results = process_file(config, file, report_file)
sandbox = Sandbox(config, log_path, report_file)
process_results = sandbox.run(process_file, config, file, report_file)
if verbose == 0:
if skip_extraction:
print_scan_report(process_results)
Expand Down
84 changes: 70 additions & 14 deletions unblob/pool.py
vlaci marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
import abc
import contextlib
import multiprocessing as mp
import os
import queue
import signal
import sys
import threading
from multiprocessing.queues import JoinableQueue
from typing import Any, Callable, Union
from typing import Any, Callable, Set, Union
vlaci marked this conversation as resolved.
Show resolved Hide resolved

from .logging import multiprocessing_breakpoint

mp.set_start_method("fork")


class PoolBase(abc.ABC):
def __init__(self):
with pools_lock:
pools.add(self)

@abc.abstractmethod
def submit(self, args):
pass
Expand All @@ -24,15 +30,20 @@ def process_until_done(self):
def start(self):
pass

def close(self):
pass
def close(self, *, immediate=False): # noqa: ARG002
with pools_lock:
pools.remove(self)

def __enter__(self):
self.start()
return self

def __exit__(self, *args):
self.close()
def __exit__(self, exc_type, _exc_value, _tb):
self.close(immediate=exc_type is not None)


pools_lock = threading.Lock()
pools: Set[PoolBase] = set()


class Queue(JoinableQueue):
Expand All @@ -53,9 +64,15 @@ class _Sentinel:


def _worker_process(handler, input_, output):
# Creates a new process group, making sure no signals are propagated from the main process to the worker processes.
# Creates a new process group, making sure no signals are
# propagated from the main process to the worker processes.
os.setpgrp()

# Restore default signal handlers, otherwise workers would inherit
# them from main process
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, signal.SIG_DFL)

sys.breakpointhook = multiprocessing_breakpoint
while (args := input_.get()) is not _SENTINEL:
result = handler(args)
Expand All @@ -71,11 +88,14 @@ def __init__(
*,
result_callback: Callable[["MultiPool", Any], Any],
):
super().__init__()
if process_num <= 0:
raise ValueError("At process_num must be greater than 0")

self._running = False
self._result_callback = result_callback
self._input = Queue(ctx=mp.get_context())
self._input.cancel_join_thread()
self._output = mp.SimpleQueue()
self._procs = [
mp.Process(
Expand All @@ -87,14 +107,32 @@ def __init__(
self._tid = threading.get_native_id()

def start(self):
self._running = True
for p in self._procs:
p.start()

def close(self):
self._clear_input_queue()
self._request_workers_to_quit()
self._clear_output_queue()
def close(self, *, immediate=False):
if not self._running:
return
self._running = False

if immediate:
self._terminate_workers()
else:
self._clear_input_queue()
self._request_workers_to_quit()
self._clear_output_queue()

self._wait_for_workers_to_quit()
super().close(immediate=immediate)

def _terminate_workers(self):
for proc in self._procs:
proc.terminate()

self._input.close()
if sys.version_info >= (3, 9):
vlaci marked this conversation as resolved.
Show resolved Hide resolved
self._output.close()

def _clear_input_queue(self):
try:
Expand Down Expand Up @@ -129,14 +167,16 @@ def submit(self, args):
self._input.put(args)

def process_until_done(self):
while not self._input.is_empty():
result = self._output.get()
self._result_callback(self, result)
self._input.task_done()
with contextlib.suppress(EOFError):
qkaiser marked this conversation as resolved.
Show resolved Hide resolved
while not self._input.is_empty():
result = self._output.get()
self._result_callback(self, result)
self._input.task_done()


class SinglePool(PoolBase):
def __init__(self, handler, *, result_callback):
super().__init__()
self._handler = handler
self._result_callback = result_callback

Expand All @@ -157,3 +197,19 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP
handler=handler,
result_callback=result_callback,
)


orig_signal_handlers = {}


def _on_terminate(signum, frame):
pools_snapshot = list(pools)
for pool in pools_snapshot:
pool.close(immediate=True)

if callable(orig_signal_handlers[signum]):
Fixed Show fixed Hide fixed
orig_signal_handlers[signum](signum, frame)


orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate)
orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate)
2 changes: 0 additions & 2 deletions unblob/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
StatReport,
UnknownError,
)
from .signals import terminate_gracefully
from .ui import NullProgressReporter, ProgressReporter

logger = get_logger()
Expand Down Expand Up @@ -118,7 +117,6 @@ def get_carve_dir_for(self, path: Path) -> Path:
return self._get_output_path(path.with_name(path.name + self.carve_suffix))


@terminate_gracefully
e3krisztian marked this conversation as resolved.
Show resolved Hide resolved
def process_file(
config: ExtractionConfig, input_path: Path, report_file: Optional[Path] = None
) -> ProcessResult:
Expand Down
Loading
Loading