Skip to content

Commit

Permalink
Slight rework of parallel_progress_bar function
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed May 17, 2024
1 parent 5ad2737 commit a9b4a69
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 108 deletions.
87 changes: 63 additions & 24 deletions milatools/utils/parallel_progress.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from __future__ import annotations

import asyncio
import functools
from logging import getLogger as get_logger
from typing import Coroutine, Iterable, Protocol, TypedDict, TypeVar
from typing import (
Coroutine,
Protocol,
TypedDict,
TypeVar,
)

from rich.progress import (
BarColumn,
Expand All @@ -29,46 +35,75 @@ class ProgressDict(TypedDict):
info: NotRequired[str]


class AsyncTaskFn(Protocol[OutT_co]):
"""Protocol for a function that can be run in parallel and reports its progress.
class ReportProgressFn(Protocol):
"""A function to be called inside a task to show information in the progress bar."""

The function should periodically set a dict containing info about it's progress in
the `progress_dict` at key `task_id`. For example:
def __call__(self, progress: int, total: int, info: str | None = None) -> None:
... # pragma: no cover

```python
async def _example_task_fn(progress_dict: DictProxy[TaskID, ProgressDict], task_id: TaskID):
import random
import time
progress_dict[task_id] = {"progress": 0, "total": len_of_task, "info": "Starting."}

len_of_task = random.randint(3, 20) # take some random length of time
for n in range(len_of_task):
await asyncio.sleep(1) # sleep for a bit to simulate work
progress_dict[task_id] = {"progress": n + 1, "total": len_of_task}
def report_progress(
progress: int,
total: int,
info: str | None = None,
*,
task_id: TaskID,
progress_dict: dict[TaskID, ProgressDict],
):
if info is not None:
progress_dict[task_id] = {"progress": progress, "total": total, "info": info}
else:
progress_dict[task_id] = {"progress": progress, "total": total}

progress_dict[task_id] = {"progress": len_of_task, "total": len_of_task, "info": "Done."}
return f"Some result for task {task_id}."

async for result in parallel_progress_bar([_example_task_fn, _example_task_fn]):
print(result)
class AsyncTaskFn(Protocol[OutT_co]):
"""Protocol for a function that can be run in parallel and reports its progress.
The function can (should) periodically report info about it's progress by calling
the `report_progress` function. For example:
"""

def __call__(
self, task_progress_dict: dict[TaskID, ProgressDict], task_id: TaskID
self, report_progress: ReportProgressFn
) -> Coroutine[None, None, OutT_co]:
...
... # pragma: no cover


async def async_progress_bar(
async def run_async_tasks_with_progress_bar(
async_task_fns: list[AsyncTaskFn[OutT_co]],
task_descriptions: list[str] | None = None,
overall_progress_task_description: str = "[green]All jobs progress:",
) -> Iterable[OutT_co]:
"""Parallel progress bar where each task is a coroutine.
) -> list[OutT_co]:
"""Run a sequence of async tasks in "parallel" and display a progress bar.
Adapted from the example at:
https://www.deanmontgomery.com/2022/03/24/rich-progress-and-multiprocessing/
NOTE: This differs from the usual progress bar: the results are returned as a list
(all at the same time) instead of one at a time.
>>> async def example_task_fn(report_progress: ReportProgressFn, len_of_task: int):
... import random
... report_progress(progress=0, total=len_of_task, info="Starting.")
... for n in range(len_of_task):
... await asyncio.sleep(1) # sleep for a bit to simulate work
... report_progress(progress=n + 1, total=len_of_task, info="working...")
... report_progress(progress=len_of_task, total=len_of_task, info="Done.")
... return f"Done after {len_of_task} seconds."
>>> import functools
>>> tasks = [functools.partial(example_task_fn, len_of_task=i) for i in range(1, 4)]
>>> import time
>>> start_time = time.time()
>>> results = asyncio.run(run_async_tasks_with_progress_bar(tasks))
✓ All jobs progress: 6/6 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 0:00:03
✓ Task 0 - Done. 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 0:00:00
✓ Task 1 - Done. 2/2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 0:00:01
✓ Task 2 - Done. 3/3 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 0:00:02
>>> results
['Done after 1 seconds.', 'Done after 2 seconds.', 'Done after 3 seconds.']
>>> f"Finished all tasks in {round(time.time() - start_time)} seconds."
'Finished all tasks in 3 seconds.'
"""
if task_descriptions is None:
task_descriptions = [f"Task {i}" for i in range(len(async_task_fns))]
Expand Down Expand Up @@ -103,7 +138,11 @@ async def async_progress_bar(
visible=True,
start=False,
)
coroutine = async_task_fn(_progress_dict, task_id)
report_progress_fn = functools.partial(
report_progress, task_id=task_id, progress_dict=_progress_dict
)
coroutine = async_task_fn(report_progress=report_progress_fn)

tasks[task_id] = asyncio.create_task(coroutine, name=task_description)

update_pbar_task = asyncio.create_task(
Expand Down
76 changes: 37 additions & 39 deletions milatools/utils/vscode_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
from milatools.utils.local_v2 import LocalV2
from milatools.utils.parallel_progress import (
AsyncTaskFn,
ProgressDict,
TaskID,
async_progress_bar,
ReportProgressFn,
run_async_tasks_with_progress_bar,
)
from milatools.utils.remote_v2 import RemoteV2

Expand Down Expand Up @@ -79,11 +78,11 @@ def vscode_installed() -> bool:
async def sync_vscode_extensions(
source: str | LocalV2 | RemoteV2,
destinations: Sequence[str | LocalV2 | RemoteV2],
) -> dict[str, ProgressDict]:
) -> dict[str, list[str]]:
"""Syncs vscode extensions between `source` all all the clusters in `dest`.
This spawns a thread for each cluster in `dest` and displays a parallel progress bar
for the syncing of vscode extensions to each cluster.
This creates an async task for each cluster in `dest` and displays a progress bar.
Returns the extensions that were installed on each cluster.
"""
if isinstance(source, str):
if source == "localhost":
Expand All @@ -100,7 +99,7 @@ async def sync_vscode_extensions(
source_extensions = await _get_vscode_extensions(source)

task_hostnames: list[str] = []
tasks: list[AsyncTaskFn[ProgressDict]] = []
tasks: list[AsyncTaskFn[list[str]]] = []
task_descriptions: list[str] = []

# Connect to the remotes in parallel.
Expand All @@ -120,17 +119,12 @@ async def sync_vscode_extensions(
)
task_descriptions.append(f"{source.hostname} -> {dest_hostname}")

return {
hostname: result
for hostname, result in zip(
task_hostnames,
await async_progress_bar(
async_task_fns=tasks,
task_descriptions=task_descriptions,
overall_progress_task_description="[green]Syncing vscode extensions:",
),
)
}
results = await run_async_tasks_with_progress_bar(
async_task_fns=tasks,
task_descriptions=task_descriptions,
overall_progress_task_description="[green]Syncing vscode extensions:",
)
return {hostname: result for hostname, result in zip(task_hostnames, results)}


def _remove_source_from_destinations(
Expand Down Expand Up @@ -178,14 +172,13 @@ async def _get_vscode_extensions(


async def _install_vscode_extensions_task_function(
task_progress_dict: dict[TaskID, ProgressDict],
task_id: TaskID,
report_progress: ReportProgressFn,
dest_hostname: str | Literal["localhost"],
source_extensions: dict[str, str],
remote: RemoteV2 | LocalV2,
source_name: str,
verbose: bool = False,
) -> ProgressDict:
) -> list[str]:
"""Installs vscode extensions on the remote cluster.
1. Finds the `code-server` executable on the remote;
Expand All @@ -194,20 +187,17 @@ async def _install_vscode_extensions_task_function(
extensions on the source;
4. Install the extensions that are missing or out of date on the remote, updating
the progress dict as it goes.
Returns the list of installed extensions, in the form 'extension_name@version'.
"""
installed: list[str] = []

def _update_progress(
progress: int, status: str, total: int = len(source_extensions)
):
"""Shows progress to the parent process by setting an item in the task progress
dict."""
progress_dict: ProgressDict = {
"progress": progress,
"total": total,
"info": textwrap.shorten(status, 50, placeholder="..."),
}
task_progress_dict[task_id] = progress_dict
return progress_dict
info = textwrap.shorten(status, 50, placeholder="...")
report_progress(progress=progress, total=total, info=info)

if not remote:
if dest_hostname == "localhost":
Expand Down Expand Up @@ -235,13 +225,15 @@ def _update_progress(
f"The vscode-server executable was not found on {remote.hostname}."
f"Skipping syncing extensions to {remote.hostname}."
)
return _update_progress(
_update_progress(
# IDEA: Use a progress of `-1` to signify an error, and use a "X"
# instead of a checkmark?
progress=0,
total=0,
status="code-server executable not found!",
)
return installed

_update_progress(0, status="fetching installed extensions...")
extensions_on_dest = await _get_vscode_extensions_dict(
remote, code_server_executable
Expand Down Expand Up @@ -269,27 +261,33 @@ def _update_progress(
total=len(to_install),
status=f"Installing {extension_name}",
)
extension = f"{extension_name}@{extension_version}"
result = await _install_vscode_extension(
remote,
code_server_executable,
extension=f"{extension_name}@{extension_version}",
code_server_executable=code_server_executable,
extension=extension,
verbose=verbose,
)
except KeyboardInterrupt:
return _update_progress(
if result.returncode != 0:
logger.debug(
f"Unable to install extension {extension} on {dest_hostname}: {result.stderr}"
)
else:
installed.append(extension)
except (KeyboardInterrupt, asyncio.CancelledError):
_update_progress(
progress=index,
total=len(to_install),
status="Interrupted.",
)
return installed

if result.returncode != 0:
logger.debug(f"{dest_hostname}: {result.stderr}")

return _update_progress(
_update_progress(
progress=len(to_install),
total=len(to_install),
status="Done.",
)
return installed


async def _install_vscode_extension(
Expand Down
37 changes: 33 additions & 4 deletions tests/integration/test_sync_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import importlib
import inspect
import subprocess
from logging import getLogger as get_logger
from typing import Callable
from unittest.mock import AsyncMock, Mock
from unittest.mock import ANY, AsyncMock, Mock

import pytest
from typing_extensions import ParamSpec

from milatools.utils import vscode_utils
from milatools.utils.local_v2 import LocalV2
from milatools.utils.remote_v2 import RemoteV2
from milatools.utils.vscode_utils import (
Expand All @@ -26,6 +28,7 @@
logger = get_logger(__name__)


@pytest.mark.slow
@pytest.mark.parametrize(
"source",
[
Expand All @@ -45,6 +48,7 @@ async def test_sync_vscode_extensions(
source: str,
dest: str,
cluster: str,
login_node_v2: RemoteV2,
monkeypatch: pytest.MonkeyPatch,
):
if source == "cluster":
Expand All @@ -67,19 +71,44 @@ def mock_and_patch(wraps: Callable, *mock_args, **mock_kwargs):
mock_task_function = mock_and_patch(
wraps=_install_vscode_extensions_task_function,
)
extension, version = "ms-python.python", "v2024.0.1"

# Make it so we only need to install this particular extension.
mock_extensions_to_install = mock_and_patch(
wraps=_extensions_to_install,
return_value={"ms-python.python": "v2024.0.1"},
return_value={extension: version},
)
mock_find_code_server_executable = mock_and_patch(
wraps=_find_code_server_executable,
)
from milatools.utils.vscode_utils import _install_vscode_extension

mock_install_extension = AsyncMock(
spec=_install_vscode_extension,
return_value=subprocess.CompletedProcess(
args=["..."],
returncode=0,
stdout=f"Successfully installed {extension}@{version}",
),
)
monkeypatch.setattr(
vscode_utils, _install_vscode_extension.__name__, mock_install_extension
)

await sync_vscode_extensions(
source=LocalV2() if source == "localhost" else RemoteV2(source),
# Avoid actually installing this (possibly oudated?) extension.
extensions_per_cluster = await sync_vscode_extensions(
source=LocalV2() if source == "localhost" else login_node_v2,
destinations=[dest],
)
assert extensions_per_cluster == {dest: [f"{extension}@{version}"]}

mock_install_extension.assert_called_once_with(
LocalV2() if dest == "localhost" else login_node_v2,
code_server_executable=ANY,
extension=f"{extension}@{version}",
verbose=ANY,
)

mock_task_function.assert_called_once()
mock_extensions_to_install.assert_called_once()
if source == "localhost":
Expand Down
Loading

0 comments on commit a9b4a69

Please sign in to comment.