Skip to content

Commit

Permalink
rohmu: add progress callback for transferred object keys
Browse files Browse the repository at this point in the history
Added callback progress_fn to copy_files_from
which tracks the progress of transferred objects with completed and total files.

This does not measure the number of bytes
in order to avoid a provider based implementation.
  • Loading branch information
tilman-aiven committed Aug 22, 2024
1 parent b6b332b commit 31f4254
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
18 changes: 15 additions & 3 deletions rohmu/object_storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TypeVar,
Union,
)
from typing_extensions import Self
from typing_extensions import Self, TypeAlias

import logging
import os
Expand All @@ -55,6 +55,9 @@ class IterKeyItem(NamedTuple):
# Argument is the additional number of bytes transferred
IncrementalProgressCallbackType = Optional[Callable[[int], None]]

# First argument is the number of transferred objects, second is the total number of objects
ObjectTransferProgressCallbackType: TypeAlias = Optional[Callable[[int, int], None]]


@dataclass(frozen=True, unsafe_hash=True)
class ConcurrentUpload:
Expand Down Expand Up @@ -202,10 +205,19 @@ def copy_file(
cannot be copied with this method. If no metadata is given copies the existing metadata."""
raise NotImplementedError

def copy_files_from(self, *, source: BaseTransfer[SourceStorageModelT], keys: Collection[str]) -> None:
def copy_files_from(
self,
*,
source: BaseTransfer[SourceStorageModelT],
keys: Collection[str],
progress_fn: ObjectTransferProgressCallbackType = None,
) -> None:
if isinstance(source, self.__class__):
for key in keys:
total_files = len(keys)
for index, key in enumerate(keys):
self._copy_file_from_bucket(source_bucket=source, source_key=key, destination_key=key, timeout=15)
if progress_fn is not None:
progress_fn(index + 1, total_files)
else:
raise NotImplementedError

Expand Down
14 changes: 12 additions & 2 deletions test/object_storage/test_object_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from rohmu import errors
from rohmu.object_storage.local import LocalTransfer
from typing import Any
from unittest import mock
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -69,18 +71,26 @@ def test_copy(transfer_type: str, request: Any) -> None:
assert transfer.get_contents_to_string("dummy_copy_metadata") == (DUMMY_CONTENT, {"new_k": "new_v"})


def test_copy_local_files_from(tmp_path: Path) -> None:
@pytest.mark.parametrize("with_progress_fn", [False, True])
def test_copy_local_files_from(tmp_path: Path, with_progress_fn: bool) -> None:
source = LocalTransfer(tmp_path / "source", prefix="s-prefix")
destination = LocalTransfer(tmp_path / "destination", prefix="d-prefix")
mock_progress_fn = MagicMock(return_value=None)

source.store_file_from_memory("some/a/key.ext", b"content_a", metadata={"info": "aaa"})
source.store_file_from_memory("some/b/key.ext", b"content_b", metadata={"info": "bbb"})
source.store_file_from_memory("some/c/key.ext", b"content_c", metadata={"info": "ccc"})
destination.copy_files_from(
source=source,
keys=["some/a/key.ext", "some/b/key.ext"],
keys=["some/a/key.ext", "some/b/key.ext", "some/c/key.ext"],
progress_fn=mock_progress_fn if with_progress_fn else None,
)

assert destination.get_contents_to_string("some/a/key.ext") == (b"content_a", {"info": "aaa", "Content-Length": "9"})
assert destination.get_contents_to_string("some/b/key.ext") == (b"content_b", {"info": "bbb", "Content-Length": "9"})
assert destination.get_contents_to_string("some/c/key.ext") == (b"content_c", {"info": "ccc", "Content-Length": "9"})
if with_progress_fn:
assert mock_progress_fn.call_args_list == [mock.call(1, 3), mock.call(2, 3), mock.call(3, 3)]


@pytest.mark.parametrize("transfer_type", ["local_transfer"])
Expand Down

0 comments on commit 31f4254

Please sign in to comment.