Skip to content

Commit

Permalink
Added parallel checksumming option (#125)
Browse files (browse the repository at this point in the history)
  • Loading branch information
JBorrow authored Dec 18, 2024
1 parent 6b5af73 commit b0b8029
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 23 deletions.
2 changes: 1 addition & 1 deletion hera_librarian/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def upload(
endpoint="upload/stage",
request=UploadInitiationRequest(
upload_size=get_size_from_path(local_path),
upload_checksum=get_checksum_from_path(local_path),
upload_checksum=get_checksum_from_path(local_path, threads=1),
upload_name=dest_path.name,
destination_location=dest_path,
uploader=self.user,
Expand Down
55 changes: 36 additions & 19 deletions hera_librarian/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import os
import os.path
import re
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import xxhash
Expand Down Expand Up @@ -33,7 +35,7 @@ def dirhash(
ignore_hidden=False,
followlinks=False,
excluded_extensions=None,
include_paths=False,
threads=2,
):
hash_func = HASH_FUNCS.get(hashfunc)
if not hash_func:
Expand All @@ -56,27 +58,40 @@ def dirhash(
dirs.sort()
files.sort()

for fname in files:
if ignore_hidden and fname.startswith("."):
continue
core = partial(
individual,
root=root,
ignore_hidden=ignore_hidden,
excluded_extensions=excluded_extensions,
excluded_files=excluded_files,
hashfunc=hashfunc,
)

if threads > 1:
with Pool(threads) as p:
hashvalues.extend(p.map(core, files))
else:
for fname in files:
hashvalues.append(core(fname))

return _reduce_hash(filter(lambda x: x is not None, hashvalues), hash_func)

if fname.split(".")[-1:][0] in excluded_extensions:
continue

if fname in excluded_files:
continue
def individual(
fname, root, ignore_hidden, excluded_extensions, excluded_files, hashfunc
):
if ignore_hidden and fname.startswith("."):
return None

if fname.split(".")[-1:][0] in excluded_extensions:
return None

hashvalues.append(_filehash(os.path.join(root, fname), hash_func))
if fname in excluded_files:
return None

if include_paths:
hasher = hash_func()
# get the resulting relative path into array of elements
path_list = os.path.relpath(os.path.join(root, fname)).split(os.sep)
# compute the hash on joined list, removes all os specific separators
hasher.update("".join(path_list).encode("utf-8"))
hashvalues.append(hasher.hexdigest())
hash_func = HASH_FUNCS.get(hashfunc)

return _reduce_hash(hashvalues, hash_func)
return _filehash(os.path.join(root, fname), hash_func)


def _filehash(filepath, hashfunc):
Expand Down Expand Up @@ -130,7 +145,9 @@ def get_md5_from_path(path):
return _filehash(path, HASH_FUNCS["md5"])


def get_checksum_from_path(path: str | Path, hash_function: str = "xxh3") -> str:
def get_checksum_from_path(
path: str | Path, hash_function: str = "xxh3", threads: int = 1
) -> str:
"""
Compute the checksum of a file from a path. This allows you to select
the underlying checksum function, which is by default the very fast
Expand All @@ -141,7 +158,7 @@ def get_checksum_from_path(path: str | Path, hash_function: str = "xxh3") -> str
path = Path(path).resolve()

if path.is_dir():
return hash_function + ":::" + dirhash(path, hash_function)
return hash_function + ":::" + dirhash(path, hash_function, threads=threads)
else:
# Just a single file. That's fine!
return hash_function + ":::" + _filehash(path, HASH_FUNCS[hash_function])
Expand Down
3 changes: 3 additions & 0 deletions librarian_server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ class ServerSettings(BaseSettings):
globus_client_secret: Optional[str] = None
globus_client_secret_file: Optional[Path] = None

# Checksumming options
checksum_threads: int = 4

model_config = SettingsConfigDict(env_prefix="librarian_server_")

def model_post_init(__context, *args, **kwargs):
Expand Down
20 changes: 17 additions & 3 deletions librarian_server/stores/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ def delete(self, path: Path):
return

def commit(self, staging_path: Path, store_path: Path):
# Must import after initialization, unfortunately.
from librarian_server.settings import server_settings

need_ownership_changes = self.own_after_commit or self.readonly_after_commit

resolved_path_staging = self._resolved_path_staging(staging_path)
Expand All @@ -210,15 +213,19 @@ def commit(self, staging_path: Path, store_path: Path):
retries = 0
copy_success = False

original_checksum = get_checksum_from_path(resolved_path_staging)
original_checksum = get_checksum_from_path(
resolved_path_staging, threads=server_settings.checksum_threads
)

while not copy_success and retries < self.max_copy_retries:
if resolved_path_staging.is_dir():
shutil.copytree(resolved_path_staging, resolved_path_store)
else:
shutil.copy2(resolved_path_staging, resolved_path_store)

new_checksum = get_checksum_from_path(resolved_path_store)
new_checksum = get_checksum_from_path(
resolved_path_store, threads=server_settings.checksum_threads
)

copy_success = compare_checksums(original_checksum, new_checksum)
retries += 1
Expand Down Expand Up @@ -289,6 +296,9 @@ def store(self, path: Path) -> Path:
return resolved_path

def path_info(self, path: Path, hash_function: str = "xxh3") -> PathInfo:
# Must import after initialization, unfortunately.
from librarian_server.settings import server_settings

# Promote path to object if required
path = Path(path)

Expand All @@ -299,7 +309,11 @@ def path_info(self, path: Path, hash_function: str = "xxh3") -> PathInfo:
# Use the old functions for consistency.
path=path,
filetype=get_type_from_path(str(path)),
checksum=get_checksum_from_path(path, hash_function=hash_function),
checksum=get_checksum_from_path(
path,
hash_function=hash_function,
threads=server_settings.checksum_threads,
),
size=get_size_from_path(path),
)

Expand Down

0 comments on commit b0b8029

Please sign in to comment.