diff --git a/astacus/common/snapshot.py b/astacus/common/snapshot.py index 2251f250..e08b8630 100644 --- a/astacus/common/snapshot.py +++ b/astacus/common/snapshot.py @@ -3,7 +3,7 @@ See LICENSE for details """ from astacus.common.magic import DEFAULT_EMBEDDED_FILE_SIZE -from typing import Sequence +from typing import Self, Sequence import dataclasses @@ -15,3 +15,6 @@ class SnapshotGroup: excluded_names: Sequence[str] = () # None means "no limit": all files matching the glob will be embedded embedded_file_size_max: int | None = DEFAULT_EMBEDDED_FILE_SIZE + + def without_excluded_names(self) -> Self: + return dataclasses.replace(self, excluded_names=()) diff --git a/astacus/common/utils.py b/astacus/common/utils.py index 49b35caf..559e44f7 100644 --- a/astacus/common/utils.py +++ b/astacus/common/utils.py @@ -28,6 +28,7 @@ import requests import tempfile import time +import wcmatch.glob logger = logging.getLogger(__name__) diff --git a/astacus/node/memory_snapshot.py b/astacus/node/memory_snapshot.py index 2a5dd80a..d7542ffa 100644 --- a/astacus/node/memory_snapshot.py +++ b/astacus/node/memory_snapshot.py @@ -11,7 +11,6 @@ from astacus.common.snapshot import SnapshotGroup from astacus.node.snapshot import Snapshot from astacus.node.snapshotter import hash_hexdigest_readable, Snapshotter -from glob import iglob from pathlib import Path from typing import Iterable, Iterator, Mapping, Sequence @@ -77,13 +76,11 @@ def get_all_digests(self) -> Iterable[SnapshotHash]: class MemorySnapshotter(Snapshotter[MemorySnapshot]): def _list_files(self, basepath: Path) -> list[FoundFile]: result_files = set() - for group in self._groups: - for p in iglob(group.root_glob, root_dir=basepath, recursive=True): + for group in self._groups.groups: + for p in group.glob(root_dir=basepath): path = basepath / p if not path.is_file() or path.is_symlink(): continue - if path.name in group.excluded_names: - continue relpath = path.relative_to(basepath) for parent in relpath.parents: if parent.name == magic.ASTACUS_TMPDIR: @@ -92,9 +89,7 @@ def _list_files(self, basepath: Path) -> list[FoundFile]: result_files.add( FoundFile( relative_path=relpath, - group=SnapshotGroup( - root_glob=group.root_glob, embedded_file_size_max=group.embedded_file_size_max - ), + group=group.group.without_excluded_names(), ) ) return sorted(result_files, key=lambda found_file: found_file.relative_path) diff --git a/astacus/node/snapshot_groups.py b/astacus/node/snapshot_groups.py new file mode 100644 index 00000000..276df591 --- /dev/null +++ b/astacus/node/snapshot_groups.py @@ -0,0 +1,59 @@ +""" + +Copyright (c) 2023 Aiven Ltd +See LICENSE for details + +Classes for working with snapshot groups. + +""" +from astacus.common.snapshot import SnapshotGroup +from pathlib import Path +from typing import Iterable, Optional, Sequence +from typing_extensions import Self +from wcmatch.glob import GLOBSTAR, iglob, translate + +import dataclasses +import os +import re + +WCMATCH_FLAGS = GLOBSTAR + + +@dataclasses.dataclass +class CompiledGroup: + group: SnapshotGroup + regex: re.Pattern + + @classmethod + def compile(cls, group: SnapshotGroup) -> Self: + return cls(group, compile(group.root_glob)) + + def matches(self, relative_path: Path) -> bool: + return bool(self.regex.match(str(relative_path))) and not relative_path.name in self.group.excluded_names + + def glob(self, root_dir: Optional[Path] = None) -> Iterable[str]: + for path in iglob(self.group.root_glob, root_dir=root_dir, flags=WCMATCH_FLAGS): + if os.path.basename(path) not in self.group.excluded_names: + yield path + + +@dataclasses.dataclass +class CompiledGroups: + groups: Sequence[CompiledGroup] + + @classmethod + def compile(cls, groups: Sequence[SnapshotGroup]) -> Self: + return cls([CompiledGroup.compile(group) for group in groups]) + + def get_matching(self, relative_path: Path) -> list[SnapshotGroup]: + return [group.group for group in self.groups if group.matches(relative_path)] + + def any_match(self, relative_path: Path) -> bool: + return any(group.matches(relative_path) for group in self.groups) + + def root_globs(self) -> list[str]: + return [group.group.root_glob for group in self.groups] + + +def compile(glob: str) -> re.Pattern: + return re.compile(translate(glob, flags=WCMATCH_FLAGS)[0][0]) diff --git a/astacus/node/snapshotter.py b/astacus/node/snapshotter.py index 2f5a81c4..d34a9035 100644 --- a/astacus/node/snapshotter.py +++ b/astacus/node/snapshotter.py @@ -10,6 +10,7 @@ from astacus.common.progress import Progress from astacus.common.snapshot import SnapshotGroup from astacus.node.snapshot import Snapshot +from astacus.node.snapshot_groups import CompiledGroups from multiprocessing import dummy from pathlib import Path from threading import Lock @@ -28,7 +29,7 @@ def __init__(self, groups: Sequence[SnapshotGroup], src: Path, dst: Path, snapsh self.snapshot = snapshot self._src = src self._dst = dst - self._groups = groups + self._groups = CompiledGroups.compile(groups) self._parallel = parallel self._dst.mkdir(parents=True, exist_ok=True) @@ -45,9 +46,7 @@ def release(self, hexdigests: Iterable[str], *, progress: Progress) -> None: ... def get_snapshot_state(self) -> SnapshotState: - return SnapshotState( - root_globs=[group.root_glob for group in self._groups], files=list(self.snapshot.get_all_files()) - ) + return SnapshotState(root_globs=self._groups.root_globs(), files=list(self.snapshot.get_all_files())) def _file_in_src(self, relative_path: Path) -> SnapshotFile: src_path = self._src / relative_path @@ -71,10 +70,7 @@ def _cb(snapshotfile: SnapshotFile) -> SnapshotFile: yield from p.imap_unordered(_cb, files) def _embedded_file_size_max_for_file(self, file: SnapshotFile) -> int | None: - groups = [] - for group in self._groups: - if file.relative_path.match(group.root_glob): - groups.append(group) + groups = self._groups.get_matching(file.relative_path) assert groups head, *tail = groups for group in tail: diff --git a/astacus/node/sqlite_snapshot.py b/astacus/node/sqlite_snapshot.py index 6e87fdca..761ccaf3 100644 --- a/astacus/node/sqlite_snapshot.py +++ b/astacus/node/sqlite_snapshot.py @@ -11,12 +11,13 @@ from astacus.node.snapshot import Snapshot from astacus.node.snapshotter import Snapshotter from contextlib import closing -from fnmatch import fnmatch from pathlib import Path from typing import Iterable, Sequence -from typing_extensions import override +from typing_extensions import override, Self +import dataclasses import os +import re import sqlite3 @@ -124,15 +125,8 @@ def _list_files_and_create_directories(self) -> Iterable[Path]: (self._dst / rel_dir).mkdir(parents=True, exist_ok=True) for f in files: rel_path = rel_dir / f - full_path = dir_path / f - if full_path.is_symlink(): - continue - for group in self._groups: - # fnmatch works strangely with paths until 3.13 so convert to string - # https://github.com/python/cpython/issues/73435 - if fnmatch(str(rel_path), group.root_glob) and f not in group.excluded_names: - yield rel_path - break + if not (dir_path / f).is_symlink() and self._groups.any_match(rel_path): + yield rel_path def _compare_current_snapshot(self, files: Iterable[Path]) -> Iterable[tuple[Path, SnapshotFile | None]]: with closing(self._con.cursor()) as cur: diff --git a/setup.cfg b/setup.cfg index 858f034b..41f57a92 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,6 +18,7 @@ install_requires = tabulate==0.9.0 typing-extensions==4.7.1 uvicorn==0.15.0 + wcmatch==8.4.1 # Pinned transitive deps pydantic==1.10.2 diff --git a/tests/unit/node/test_snapshot_groups.py b/tests/unit/node/test_snapshot_groups.py new file mode 100644 index 00000000..8f574f61 --- /dev/null +++ b/tests/unit/node/test_snapshot_groups.py @@ -0,0 +1,60 @@ +from astacus.common.snapshot import SnapshotGroup +from astacus.node.snapshot_groups import compile, CompiledGroup, CompiledGroups +from pathlib import Path + +import os + +POSITIVE_TEST_CASES: list[tuple[Path, str]] = [ + (Path("foo"), "foo"), + (Path("foo"), "*"), + (Path("foo"), "**"), + (Path("foo/bar"), "**"), + (Path("foo/bar/baz"), "**/*"), + (Path("foo/bar"), "**/*"), + (Path("foo/bar"), "**/**"), +] + +NEGATIVE_TEST_CASES: list[tuple[Path, str]] = [ + (Path("foo/bar/baz"), "*/*"), +] + + +def test_compile() -> None: + for path, glob in POSITIVE_TEST_CASES: + assert compile(glob).match(str(path)) is not None + for path, glob in NEGATIVE_TEST_CASES: + assert compile(glob).match(str(path)) is None + + +def test_CompiledGroup_matches() -> None: + for path, glob in POSITIVE_TEST_CASES: + group = SnapshotGroup(root_glob=glob) + assert CompiledGroup.compile(group).matches(path) + group = SnapshotGroup(root_glob=glob, excluded_names=[os.path.basename(path)]) + assert not CompiledGroup.compile(group).matches(path) + for path, glob in NEGATIVE_TEST_CASES: + group = SnapshotGroup(root_glob=glob) + assert not CompiledGroup.compile(group).matches(path) + + +def test_CompiledGroups() -> None: + for path, glob in POSITIVE_TEST_CASES: + group1 = SnapshotGroup(root_glob=glob) + group2 = SnapshotGroup(root_glob=glob, excluded_names=[os.path.basename(path)]) + group3 = SnapshotGroup(root_glob="doesntmatch") + compiled = CompiledGroups.compile([group1, group2, group3]) + assert compiled.any_match(path) + assert compiled.get_matching(path) == [group1] + + +def test_CompiledGroup_glob(tmp_path: Path) -> None: + for p, _ in POSITIVE_TEST_CASES + NEGATIVE_TEST_CASES: + p = tmp_path / p + p.mkdir(parents=True, exist_ok=True) + p.touch() + for p, glob in POSITIVE_TEST_CASES: + group = SnapshotGroup(root_glob=glob) + assert str(p) in CompiledGroup.compile(group).glob(tmp_path) + for p, glob in NEGATIVE_TEST_CASES: + group = SnapshotGroup(root_glob=glob) + assert str(p) not in CompiledGroup.compile(group).glob(tmp_path)