Skip to content

Commit

Permalink
Merge pull request #597 from onekey-sec/landlock
Browse files Browse the repository at this point in the history
Experimental Landlock based sandboxing
  • Loading branch information
qkaiser authored Dec 4, 2024
2 parents 495e351 + f882f70 commit c31b05e
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 77 deletions.
50 changes: 41 additions & 9 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Iterable, List, Optional, Type
from typing import Iterable, List, Optional, Tuple, Type
from unittest import mock

import pytest
Expand All @@ -13,9 +13,11 @@
from unblob.processing import (
DEFAULT_DEPTH,
DEFAULT_PROCESS_NUM,
DEFAULT_SKIP_EXTENSION,
DEFAULT_SKIP_MAGIC,
ExtractionConfig,
)
from unblob.testing import is_sandbox_available
from unblob.ui import (
NullProgressReporter,
ProgressReporter,
Expand Down Expand Up @@ -310,16 +312,16 @@ def test_keep_extracted_chunks(


@pytest.mark.parametrize(
"skip_extension, extracted_files_count",
"skip_extension, expected_skip_extensions",
[
pytest.param([], 5, id="skip-extension-empty"),
pytest.param([""], 5, id="skip-zip-extension-empty-suffix"),
pytest.param([".zip"], 0, id="skip-extension-zip"),
pytest.param([".rlib"], 5, id="skip-extension-rlib"),
pytest.param((), DEFAULT_SKIP_EXTENSION, id="skip-extension-empty"),
pytest.param(("",), ("",), id="skip-zip-extension-empty-suffix"),
pytest.param((".zip",), (".zip",), id="skip-extension-zip"),
pytest.param((".rlib",), (".rlib",), id="skip-extension-rlib"),
],
)
def test_skip_extension(
skip_extension: List[str], extracted_files_count: int, tmp_path: Path
skip_extension: List[str], expected_skip_extensions: Tuple[str, ...], tmp_path: Path
):
runner = CliRunner()
in_path = (
Expand All @@ -335,8 +337,12 @@ def test_skip_extension(
for suffix in skip_extension:
args += ["--skip-extension", suffix]
params = [*args, "--extract-dir", str(tmp_path), str(in_path)]
result = runner.invoke(unblob.cli.cli, params)
assert extracted_files_count == len(list(tmp_path.rglob("*")))
process_file_mock = mock.MagicMock()
with mock.patch.object(unblob.cli, "process_file", process_file_mock):
result = runner.invoke(unblob.cli.cli, params)
assert (
process_file_mock.call_args.args[0].skip_extension == expected_skip_extensions
)
assert result.exit_code == 0


Expand Down Expand Up @@ -420,3 +426,29 @@ def test_clear_skip_magics(
assert sorted(process_file_mock.call_args.args[0].skip_magic) == sorted(
skip_magic
), fail_message


@pytest.mark.skipif(
not is_sandbox_available(), reason="Sandboxing is only available on Linux"
)
def test_sandbox_escape(tmp_path: Path):
runner = CliRunner()

in_path = tmp_path / "input"
in_path.touch()
extract_dir = tmp_path / "extract-dir"
params = ["--extract-dir", str(extract_dir), str(in_path)]

unrelated_file = tmp_path / "unrelated"

process_file_mock = mock.MagicMock(
side_effect=lambda *_args, **_kwargs: unrelated_file.write_text(
"sandbox escape"
)
)
with mock.patch.object(unblob.cli, "process_file", process_file_mock):
result = runner.invoke(unblob.cli.cli, params)

assert result.exit_code != 0
assert isinstance(result.exception, PermissionError)
process_file_mock.assert_called_once()
57 changes: 57 additions & 0 deletions tests/test_sandbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from pathlib import Path

import pytest

from unblob.processing import ExtractionConfig
from unblob.sandbox import Sandbox
from unblob.testing import is_sandbox_available

pytestmark = pytest.mark.skipif(
not is_sandbox_available(), reason="Sandboxing only works on Linux"
)


@pytest.fixture
def log_path(tmp_path):
return tmp_path / "unblob.log"


@pytest.fixture
def extraction_config(extraction_config, tmp_path):
extraction_config.extract_root = tmp_path / "extract" / "root"
# parent has to exist
extraction_config.extract_root.parent.mkdir()
return extraction_config


@pytest.fixture
def sandbox(extraction_config: ExtractionConfig, log_path: Path):
return Sandbox(extraction_config, log_path, None)


def test_necessary_resources_can_be_created_in_sandbox(
sandbox: Sandbox, extraction_config: ExtractionConfig, log_path: Path
):
directory_in_extract_root = extraction_config.extract_root / "path" / "to" / "dir"
file_in_extract_root = directory_in_extract_root / "file"

sandbox.run(extraction_config.extract_root.mkdir, parents=True)
sandbox.run(directory_in_extract_root.mkdir, parents=True)

sandbox.run(file_in_extract_root.touch)
sandbox.run(file_in_extract_root.write_text, "file content")

# log-file is already opened
log_path.touch()
sandbox.run(log_path.write_text, "log line")


def test_access_outside_sandbox_is_not_possible(sandbox: Sandbox, tmp_path: Path):
unrelated_dir = tmp_path / "unrelated" / "path"
unrelated_file = tmp_path / "unrelated-file"

with pytest.raises(PermissionError):
sandbox.run(unrelated_dir.mkdir, parents=True)

with pytest.raises(PermissionError):
sandbox.run(unrelated_file.touch)
4 changes: 3 additions & 1 deletion unblob/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ExtractionConfig,
process_file,
)
from .sandbox import Sandbox
from .ui import NullProgressReporter, RichConsoleProgressReporter

logger = get_logger()
Expand Down Expand Up @@ -321,7 +322,8 @@ def cli(
)

logger.info("Start processing file", file=file)
process_results = process_file(config, file, report_file)
sandbox = Sandbox(config, log_path, report_file)
process_results = sandbox.run(process_file, config, file, report_file)
if verbose == 0:
if skip_extraction:
print_scan_report(process_results)
Expand Down
84 changes: 70 additions & 14 deletions unblob/pool.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
import abc
import contextlib
import multiprocessing as mp
import os
import queue
import signal
import sys
import threading
from multiprocessing.queues import JoinableQueue
from typing import Any, Callable, Union
from typing import Any, Callable, Set, Union

from .logging import multiprocessing_breakpoint

mp.set_start_method("fork")


class PoolBase(abc.ABC):
def __init__(self):
with pools_lock:
pools.add(self)

@abc.abstractmethod
def submit(self, args):
pass
Expand All @@ -24,15 +30,20 @@ def process_until_done(self):
def start(self):
pass

def close(self):
pass
def close(self, *, immediate=False): # noqa: ARG002
with pools_lock:
pools.remove(self)

def __enter__(self):
self.start()
return self

def __exit__(self, *args):
self.close()
def __exit__(self, exc_type, _exc_value, _tb):
self.close(immediate=exc_type is not None)


pools_lock = threading.Lock()
pools: Set[PoolBase] = set()


class Queue(JoinableQueue):
Expand All @@ -53,9 +64,15 @@ class _Sentinel:


def _worker_process(handler, input_, output):
# Creates a new process group, making sure no signals are propagated from the main process to the worker processes.
# Creates a new process group, making sure no signals are
# propagated from the main process to the worker processes.
os.setpgrp()

# Restore default signal handlers, otherwise workers would inherit
# them from main process
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, signal.SIG_DFL)

sys.breakpointhook = multiprocessing_breakpoint
while (args := input_.get()) is not _SENTINEL:
result = handler(args)
Expand All @@ -71,11 +88,14 @@ def __init__(
*,
result_callback: Callable[["MultiPool", Any], Any],
):
super().__init__()
if process_num <= 0:
raise ValueError("At process_num must be greater than 0")

self._running = False
self._result_callback = result_callback
self._input = Queue(ctx=mp.get_context())
self._input.cancel_join_thread()
self._output = mp.SimpleQueue()
self._procs = [
mp.Process(
Expand All @@ -87,14 +107,32 @@ def __init__(
self._tid = threading.get_native_id()

def start(self):
self._running = True
for p in self._procs:
p.start()

def close(self):
self._clear_input_queue()
self._request_workers_to_quit()
self._clear_output_queue()
def close(self, *, immediate=False):
if not self._running:
return
self._running = False

if immediate:
self._terminate_workers()
else:
self._clear_input_queue()
self._request_workers_to_quit()
self._clear_output_queue()

self._wait_for_workers_to_quit()
super().close(immediate=immediate)

def _terminate_workers(self):
for proc in self._procs:
proc.terminate()

self._input.close()
if sys.version_info >= (3, 9):
self._output.close()

def _clear_input_queue(self):
try:
Expand Down Expand Up @@ -129,14 +167,16 @@ def submit(self, args):
self._input.put(args)

def process_until_done(self):
while not self._input.is_empty():
result = self._output.get()
self._result_callback(self, result)
self._input.task_done()
with contextlib.suppress(EOFError):
while not self._input.is_empty():
result = self._output.get()
self._result_callback(self, result)
self._input.task_done()


class SinglePool(PoolBase):
def __init__(self, handler, *, result_callback):
super().__init__()
self._handler = handler
self._result_callback = result_callback

Expand All @@ -157,3 +197,19 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP
handler=handler,
result_callback=result_callback,
)


orig_signal_handlers = {}


def _on_terminate(signum, frame):
pools_snapshot = list(pools)
for pool in pools_snapshot:
pool.close(immediate=True)

if callable(orig_signal_handlers[signum]):
orig_signal_handlers[signum](signum, frame)


orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate)
orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate)
2 changes: 0 additions & 2 deletions unblob/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
StatReport,
UnknownError,
)
from .signals import terminate_gracefully
from .ui import NullProgressReporter, ProgressReporter

logger = get_logger()
Expand Down Expand Up @@ -118,7 +117,6 @@ def get_carve_dir_for(self, path: Path) -> Path:
return self._get_output_path(path.with_name(path.name + self.carve_suffix))


@terminate_gracefully
def process_file(
config: ExtractionConfig, input_path: Path, report_file: Optional[Path] = None
) -> ProcessResult:
Expand Down
Loading

0 comments on commit c31b05e

Please sign in to comment.