From b0b80297db1276fee4c56c101dfcd6017d1cba1d Mon Sep 17 00:00:00 2001 From: Josh Borrow Date: Wed, 18 Dec 2024 09:02:14 -0500 Subject: [PATCH] Added parallel checksumming option (#125) --- hera_librarian/client.py | 2 +- hera_librarian/utils.py | 55 +++++++++++++++++++++----------- librarian_server/settings.py | 3 ++ librarian_server/stores/local.py | 20 ++++++++++-- 4 files changed, 57 insertions(+), 23 deletions(-) diff --git a/hera_librarian/client.py b/hera_librarian/client.py index 35ba864..cbc2158 100644 --- a/hera_librarian/client.py +++ b/hera_librarian/client.py @@ -399,7 +399,7 @@ def upload( endpoint="upload/stage", request=UploadInitiationRequest( upload_size=get_size_from_path(local_path), - upload_checksum=get_checksum_from_path(local_path), + upload_checksum=get_checksum_from_path(local_path, threads=1), upload_name=dest_path.name, destination_location=dest_path, uploader=self.user, diff --git a/hera_librarian/utils.py b/hera_librarian/utils.py index bb06755..16a3ef0 100644 --- a/hera_librarian/utils.py +++ b/hera_librarian/utils.py @@ -6,6 +6,8 @@ import os import os.path import re +from functools import partial +from multiprocessing import Pool from pathlib import Path import xxhash @@ -33,7 +35,7 @@ def dirhash( ignore_hidden=False, followlinks=False, excluded_extensions=None, - include_paths=False, + threads=2, ): hash_func = HASH_FUNCS.get(hashfunc) if not hash_func: @@ -56,27 +58,40 @@ def dirhash( dirs.sort() files.sort() - for fname in files: - if ignore_hidden and fname.startswith("."): - continue + core = partial( + individual, + root=root, + ignore_hidden=ignore_hidden, + excluded_extensions=excluded_extensions, + excluded_files=excluded_files, + hashfunc=hashfunc, + ) + + if threads > 1: + with Pool(threads) as p: + hashvalues.extend(p.map(core, files)) + else: + for fname in files: + hashvalues.append(core(fname)) + + return _reduce_hash(filter(lambda x: x is not None, hashvalues), hash_func) - if fname.split(".")[-1:][0] in excluded_extensions: - continue - if fname in excluded_files: - continue +def individual( + fname, root, ignore_hidden, excluded_extensions, excluded_files, hashfunc +): + if ignore_hidden and fname.startswith("."): + return None + + if fname.split(".")[-1:][0] in excluded_extensions: + return None - hashvalues.append(_filehash(os.path.join(root, fname), hash_func)) + if fname in excluded_files: + return None - if include_paths: - hasher = hash_func() - # get the resulting relative path into array of elements - path_list = os.path.relpath(os.path.join(root, fname)).split(os.sep) - # compute the hash on joined list, removes all os specific separators - hasher.update("".join(path_list).encode("utf-8")) - hashvalues.append(hasher.hexdigest()) + hash_func = HASH_FUNCS.get(hashfunc) - return _reduce_hash(hashvalues, hash_func) + return _filehash(os.path.join(root, fname), hash_func) def _filehash(filepath, hashfunc): @@ -130,7 +145,9 @@ def get_md5_from_path(path): return _filehash(path, HASH_FUNCS["md5"]) -def get_checksum_from_path(path: str | Path, hash_function: str = "xxh3") -> str: +def get_checksum_from_path( + path: str | Path, hash_function: str = "xxh3", threads: int = 1 +) -> str: """ Compute the checksum of a file from a path. This allows you to select the underlying checksum function, which is by default the very fast @@ -141,7 +158,7 @@ def get_checksum_from_path(path: str | Path, hash_function: str = "xxh3") -> str path = Path(path).resolve() if path.is_dir(): - return hash_function + ":::" + dirhash(path, hash_function) + return hash_function + ":::" + dirhash(path, hash_function, threads=threads) else: # Just a single file. That's fine! return hash_function + ":::" + _filehash(path, HASH_FUNCS[hash_function]) diff --git a/librarian_server/settings.py b/librarian_server/settings.py index f70df34..2b82274 100644 --- a/librarian_server/settings.py +++ b/librarian_server/settings.py @@ -177,6 +177,9 @@ class ServerSettings(BaseSettings): globus_client_secret: Optional[str] = None globus_client_secret_file: Optional[Path] = None + # Checksumming options + checksum_threads: int = 4 + model_config = SettingsConfigDict(env_prefix="librarian_server_") def model_post_init(__context, *args, **kwargs): diff --git a/librarian_server/stores/local.py b/librarian_server/stores/local.py index fe801a2..f66bb54 100644 --- a/librarian_server/stores/local.py +++ b/librarian_server/stores/local.py @@ -190,6 +190,9 @@ def delete(self, path: Path): return def commit(self, staging_path: Path, store_path: Path): + # Must import after initialization, unfortunately. + from librarian_server.settings import server_settings + need_ownership_changes = self.own_after_commit or self.readonly_after_commit resolved_path_staging = self._resolved_path_staging(staging_path) @@ -210,7 +213,9 @@ def commit(self, staging_path: Path, store_path: Path): retries = 0 copy_success = False - original_checksum = get_checksum_from_path(resolved_path_staging) + original_checksum = get_checksum_from_path( + resolved_path_staging, threads=server_settings.checksum_threads + ) while not copy_success and retries < self.max_copy_retries: if resolved_path_staging.is_dir(): @@ -218,7 +223,9 @@ def commit(self, staging_path: Path, store_path: Path): else: shutil.copy2(resolved_path_staging, resolved_path_store) - new_checksum = get_checksum_from_path(resolved_path_store) + new_checksum = get_checksum_from_path( + resolved_path_store, threads=server_settings.checksum_threads + ) copy_success = compare_checksums(original_checksum, new_checksum) retries += 1 @@ -289,6 +296,9 @@ def store(self, path: Path) -> Path: return resolved_path def path_info(self, path: Path, hash_function: str = "xxh3") -> PathInfo: + # Must import after initialization, unfortunately. + from librarian_server.settings import server_settings + # Promote path to object if required path = Path(path) @@ -299,7 +309,11 @@ def path_info(self, path: Path, hash_function: str = "xxh3") -> PathInfo: # Use the old functions for consistency. path=path, filetype=get_type_from_path(str(path)), - checksum=get_checksum_from_path(path, hash_function=hash_function), + checksum=get_checksum_from_path( + path, + hash_function=hash_function, + threads=server_settings.checksum_threads, + ), size=get_size_from_path(path), )