Added parallel checksumming option #125

Merged: 1 commit, Dec 18, 2024
hera_librarian/client.py (2 changes: 1 addition & 1 deletion)
@@ -399,7 +399,7 @@ def upload(
             endpoint="upload/stage",
             request=UploadInitiationRequest(
                 upload_size=get_size_from_path(local_path),
-                upload_checksum=get_checksum_from_path(local_path),
+                upload_checksum=get_checksum_from_path(local_path, threads=1),
                 upload_name=dest_path.name,
                 destination_location=dest_path,
                 uploader=self.user,
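On the client side the upload path now pins threads=1, which takes the serial branch of the reworked dirhash code, so client uploads keep their previous single-process hashing behaviour. A minimal sketch of that call in isolation (the path is hypothetical, and this assumes the hera_librarian package from this PR is importable):

from hera_librarian.utils import get_checksum_from_path

# Hypothetical path to the data being uploaded.
local_path = "/tmp/example_upload"

# threads=1 skips the multiprocessing Pool entirely and hashes each file
# in a plain loop, matching the pre-PR behaviour of the client.
checksum = get_checksum_from_path(local_path, threads=1)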
hera_librarian/utils.py (55 changes: 36 additions & 19 deletions)
@@ -6,6 +6,8 @@
 import os
 import os.path
 import re
+from functools import partial
+from multiprocessing import Pool
 from pathlib import Path
 
 import xxhash
@@ -33,7 +35,7 @@ def dirhash(
     ignore_hidden=False,
     followlinks=False,
     excluded_extensions=None,
-    include_paths=False,
+    threads=2,
 ):
     hash_func = HASH_FUNCS.get(hashfunc)
     if not hash_func:
@@ -56,27 +58,40 @@
         dirs.sort()
         files.sort()
 
-        for fname in files:
-            if ignore_hidden and fname.startswith("."):
-                continue
+        core = partial(
+            individual,
+            root=root,
+            ignore_hidden=ignore_hidden,
+            excluded_extensions=excluded_extensions,
+            excluded_files=excluded_files,
+            hashfunc=hashfunc,
+        )
+
+        if threads > 1:
+            with Pool(threads) as p:
+                hashvalues.extend(p.map(core, files))
+        else:
+            for fname in files:
+                hashvalues.append(core(fname))
+
+    return _reduce_hash(filter(lambda x: x is not None, hashvalues), hash_func)
 
-            if fname.split(".")[-1:][0] in excluded_extensions:
-                continue
 
-            if fname in excluded_files:
-                continue
+def individual(
+    fname, root, ignore_hidden, excluded_extensions, excluded_files, hashfunc
+):
+    if ignore_hidden and fname.startswith("."):
+        return None
 
+    if fname.split(".")[-1:][0] in excluded_extensions:
+        return None
+
-            hashvalues.append(_filehash(os.path.join(root, fname), hash_func))
+    if fname in excluded_files:
+        return None
 
-            if include_paths:
-                hasher = hash_func()
-                # get the resulting relative path into array of elements
-                path_list = os.path.relpath(os.path.join(root, fname)).split(os.sep)
-                # compute the hash on joined list, removes all os specific separators
-                hasher.update("".join(path_list).encode("utf-8"))
-                hashvalues.append(hasher.hexdigest())
+    hash_func = HASH_FUNCS.get(hashfunc)
 
-    return _reduce_hash(hashvalues, hash_func)
+    return _filehash(os.path.join(root, fname), hash_func)
 
 
 def _filehash(filepath, hashfunc):
@@ -130,7 +145,9 @@ def get_md5_from_path(path):
     return _filehash(path, HASH_FUNCS["md5"])
 
 
-def get_checksum_from_path(path: str | Path, hash_function: str = "xxh3") -> str:
+def get_checksum_from_path(
+    path: str | Path, hash_function: str = "xxh3", threads: int = 1
+) -> str:
     """
     Compute the checksum of a file from a path. This allows you to select
     the underlying checksum function, which is by default the very fast
@@ -141,7 +158,7 @@ def get_checksum_from_path(path: str | Path, hash_function: str = "xxh3") -> str:
     path = Path(path).resolve()
 
     if path.is_dir():
-        return hash_function + ":::" + dirhash(path, hash_function)
+        return hash_function + ":::" + dirhash(path, hash_function, threads=threads)
     else:
         # Just a single file. That's fine!
         return hash_function + ":::" + _filehash(path, HASH_FUNCS[hash_function])
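For context, the reworked dirhash farms the per-file hashing out to a multiprocessing Pool via functools.partial, with threads=1 falling back to a serial loop, and get_checksum_from_path forwards its new threads argument to dirhash when the path is a directory. A hedged usage sketch (the directory path and thread count are illustrative, not taken from this PR):

from hera_librarian.utils import get_checksum_from_path

# Hypothetical staged observation directory.
staged_dir = "/data/staging/obs_0001"

# Hash every file in the tree with four worker processes; the result is
# prefixed with the hash function name, e.g. "xxh3:::<hexdigest>".
checksum = get_checksum_from_path(staged_dir, hash_function="xxh3", threads=4)
print(checksum)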
librarian_server/settings.py (3 changes: 3 additions & 0 deletions)
@@ -177,6 +177,9 @@ class ServerSettings(BaseSettings):
     globus_client_secret: Optional[str] = None
     globus_client_secret_file: Optional[Path] = None
 
+    # Checksumming options
+    checksum_threads: int = 4
+
     model_config = SettingsConfigDict(env_prefix="librarian_server_")
 
     def model_post_init(__context, *args, **kwargs):
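The new checksum_threads field defaults to 4 and is loaded through the existing pydantic-settings configuration, so with env_prefix="librarian_server_" it should be overridable from the environment. A sketch under that assumption (the exact variable name is inferred from the prefix and field name, not shown elsewhere in this diff):

import os

# Inferred name: env_prefix "librarian_server_" + field name "checksum_threads".
# Must be set before the server settings object is constructed.
os.environ["LIBRARIAN_SERVER_CHECKSUM_THREADS"] = "8"

from librarian_server.settings import server_settings

print(server_settings.checksum_threads)  # expected: 8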
librarian_server/stores/local.py (20 changes: 17 additions & 3 deletions)
@@ -190,6 +190,9 @@ def delete(self, path: Path):
         return
 
     def commit(self, staging_path: Path, store_path: Path):
+        # Must import after initialization, unfortunately.
+        from librarian_server.settings import server_settings
+
         need_ownership_changes = self.own_after_commit or self.readonly_after_commit
 
         resolved_path_staging = self._resolved_path_staging(staging_path)
@@ -210,15 +213,19 @@ def commit(self, staging_path: Path, store_path: Path):
         retries = 0
         copy_success = False
 
-        original_checksum = get_checksum_from_path(resolved_path_staging)
+        original_checksum = get_checksum_from_path(
+            resolved_path_staging, threads=server_settings.checksum_threads
+        )
 
         while not copy_success and retries < self.max_copy_retries:
             if resolved_path_staging.is_dir():
                 shutil.copytree(resolved_path_staging, resolved_path_store)
             else:
                 shutil.copy2(resolved_path_staging, resolved_path_store)
 
-            new_checksum = get_checksum_from_path(resolved_path_store)
+            new_checksum = get_checksum_from_path(
+                resolved_path_store, threads=server_settings.checksum_threads
+            )
 
             copy_success = compare_checksums(original_checksum, new_checksum)
             retries += 1
@@ -289,6 +296,9 @@ def store(self, path: Path) -> Path:
         return resolved_path
 
     def path_info(self, path: Path, hash_function: str = "xxh3") -> PathInfo:
+        # Must import after initialization, unfortunately.
+        from librarian_server.settings import server_settings
+
         # Promote path to object if required
         path = Path(path)
 
@@ -299,7 +309,11 @@ def path_info(self, path: Path, hash_function: str = "xxh3") -> PathInfo:
             # Use the old functions for consistency.
             path=path,
             filetype=get_type_from_path(str(path)),
-            checksum=get_checksum_from_path(path, hash_function=hash_function),
+            checksum=get_checksum_from_path(
+                path,
+                hash_function=hash_function,
+                threads=server_settings.checksum_threads,
+            ),
             size=get_size_from_path(path),
         )

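Both commit and path_info now pass server_settings.checksum_threads through to get_checksum_from_path, so the server-side verify-after-copy step hashes the staged and stored copies with the configured number of workers. A standalone sketch of that verification step (paths are hypothetical; compare_checksums is assumed to be importable from hera_librarian.utils and to compare the "function:::digest" strings produced above):

from hera_librarian.utils import compare_checksums, get_checksum_from_path
from librarian_server.settings import server_settings

# Hypothetical staging and store copies of the same item.
staged = "/stores/staging/obs_0001"
stored = "/stores/store/obs_0001"

original = get_checksum_from_path(staged, threads=server_settings.checksum_threads)
copied = get_checksum_from_path(stored, threads=server_settings.checksum_threads)

# The store retries the copy when the checksums do not match.
if not compare_checksums(original, copied):
    print("Checksum mismatch; copy would be retried")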