Commit 4d9a5a9: Remove parallel_progress_bar

Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed May 16, 2024
1 parent 45dfbb6 commit 4d9a5a9
Showing 3 changed files with 9 additions and 244 deletions.
168 changes: 8 additions & 160 deletions milatools/utils/parallel_progress.py
@@ -1,11 +1,7 @@
 from __future__ import annotations
 
 import asyncio
-import multiprocessing
-import time
-from concurrent.futures import Future, ThreadPoolExecutor
 from logging import getLogger as get_logger
-from multiprocessing.managers import DictProxy
 from typing import Coroutine, Iterable, Protocol, TypedDict, TypeVar
 
 from rich.progress import (
@@ -33,57 +29,27 @@ class ProgressDict(TypedDict):
     info: NotRequired[str]
 
 
-class TaskFn(Protocol[OutT_co]):
-    """Protocol for a function that can be run in parallel and reports its progress.
-
-    The function should periodically set a dict containing info about its progress in
-    the `progress_dict` at key `task_id`. For example:
-
-    ```python
-    def _example_task_fn(progress_dict: DictProxy[TaskID, ProgressDict], task_id: TaskID):
-        import random
-        import time
-
-        len_of_task = random.randint(3, 20)  # take some random length of time
-        progress_dict[task_id] = {"progress": 0, "total": len_of_task, "info": "Starting."}
-        for n in range(len_of_task):
-            time.sleep(1)  # sleep for a bit to simulate work
-            progress_dict[task_id] = {"progress": n + 1, "total": len_of_task}
-        progress_dict[task_id] = {"progress": len_of_task, "total": len_of_task, "info": "Done."}
-        return f"Some result for task {task_id}."
-
-    for result in parallel_progress_bar([_example_task_fn, _example_task_fn]):
-        print(result)
-    ```
-    """
-
-    def __call__(
-        self, task_progress_dict: DictProxy[TaskID, ProgressDict], task_id: TaskID
-    ) -> OutT_co:
-        ...
-
-
 class AsyncTaskFn(Protocol[OutT_co]):
     """Protocol for a function that can be run in parallel and reports its progress.
 
     The function should periodically set a dict containing info about its progress in
     the `progress_dict` at key `task_id`. For example:
 
     ```python
-    def _example_task_fn(progress_dict: DictProxy[TaskID, ProgressDict], task_id: TaskID):
+    async def _example_task_fn(progress_dict: DictProxy[TaskID, ProgressDict], task_id: TaskID):
         import random
         import time
 
         len_of_task = random.randint(3, 20)  # take some random length of time
         progress_dict[task_id] = {"progress": 0, "total": len_of_task, "info": "Starting."}
         for n in range(len_of_task):
-            time.sleep(1)  # sleep for a bit to simulate work
+            await asyncio.sleep(1)  # sleep for a bit to simulate work
             progress_dict[task_id] = {"progress": n + 1, "total": len_of_task}
         progress_dict[task_id] = {"progress": len_of_task, "total": len_of_task, "info": "Done."}
         return f"Some result for task {task_id}."
 
-    for result in parallel_progress_bar([_example_task_fn, _example_task_fn]):
+    async for result in parallel_progress_bar([_example_task_fn, _example_task_fn]):
        print(result)
    ```
    """

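For orientation: after this commit, a "task" is simply a coroutine function matching the `AsyncTaskFn` protocol above. Here is a minimal runnable sketch of such a task; the name `example_task` is hypothetical, and it assumes the post-commit signature accepts a plain `dict[TaskID, ProgressDict]` (as the updated tests further down suggest) rather than a `DictProxy`:

```python
import asyncio
import random

from milatools.utils.parallel_progress import ProgressDict, TaskID


async def example_task(
    progress_dict: dict[TaskID, ProgressDict], task_id: TaskID
) -> str:
    # Pick a random amount of "work", then report progress as it completes.
    total = random.randint(3, 10)
    progress_dict[task_id] = {"progress": 0, "total": total, "info": "Starting."}
    for n in range(total):
        await asyncio.sleep(0.1)  # simulate one unit of work
        progress_dict[task_id] = {"progress": n + 1, "total": total}
    progress_dict[task_id] = {"progress": total, "total": total, "info": "Done."}
    return f"Result of task {task_id}."
```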
@@ -93,135 +59,17 @@ def __call__(
         ...
 
 
-def parallel_progress_bar(
-    task_fns: list[TaskFn[OutT_co]],
+async def async_progress_bar(
+    async_task_fns: list[AsyncTaskFn[OutT_co]],
     task_descriptions: list[str] | None = None,
     overall_progress_task_description: str = "[green]All jobs progress:",
-    n_workers: int = 8,
 ) -> Iterable[OutT_co]:
-    """Adapted from the example at:
+    """Parallel progress bar where each task is a coroutine.
+
+    Adapted from the example at:
     https://www.deanmontgomery.com/2022/03/24/rich-progress-and-multiprocessing/
     """
     if task_descriptions is None:
-        task_descriptions = [f"Task {i}" for i in range(len(task_fns))]
-
-    assert task_fns
-    assert len(task_fns) == len(task_descriptions)
-
-    futures: dict[TaskID, Future[OutT_co]] = {}
-    num_yielded_results: int = 0
-
-    # NOTE: Could also use a ProcessPoolExecutor here:
-    # executor = ProcessPoolExecutor(max_workers=n_workers)
-    executor = ThreadPoolExecutor(
-        max_workers=n_workers, thread_name_prefix="mila_sync_worker"
-    )
-    manager = multiprocessing.Manager()
-    progress = Progress(
-        SpinnerColumn(finished_text="[green]✓"),
-        TextColumn("[progress.description]{task.description}"),
-        MofNCompleteColumn(),
-        BarColumn(bar_width=None),
-        TaskProgressColumn(),
-        TimeRemainingColumn(),
-        TimeElapsedColumn(),
-        console=console,
-        transient=False,
-        refresh_per_second=10,
-        expand=False,
-    )
-    with executor, manager, progress:
-        # We share some state between our main process and our worker
-        # functions
-        _progress_dict: DictProxy[TaskID, ProgressDict] = manager.dict()
-
-        overall_progress_task = progress.add_task(
-            overall_progress_task_description,
-            visible=True,
-            start=True,
-        )
-
-        # iterate over the jobs we need to run
-        for task_name, task_fn in zip(task_descriptions, task_fns):
-            # NOTE: Could set visible=false so we don't have a lot of bars all at once.
-            task_id = progress.add_task(
-                description=task_name, visible=True, start=False
-            )
-            futures[task_id] = executor.submit(task_fn, _progress_dict, task_id)
-
-        _started_task_ids: list[TaskID] = []
-
-        # monitor the progress:
-        while num_yielded_results < len(futures):
-            total_progress = 0
-            total_task_lengths = 0
-
-            for (task_id, future), task_description in zip(
-                futures.items(), task_descriptions
-            ):
-                if task_id not in _progress_dict:
-                    # No progress reported yet by the task function.
-                    continue
-
-                update_data = _progress_dict[task_id]
-                task_progress = update_data["progress"]
-                total = update_data["total"]
-
-                # Start the task in the progress bar when the first update is received.
-                # This allows us to have a nice per-task elapsed time instead of the
-                # same elapsed time in all tasks.
-                if task_id not in _started_task_ids and task_progress > 0:
-                    # Note: calling `start_task` multiple times doesn't cause issues,
-                    # but we're still doing this just to be explicit.
-                    progress.start_task(task_id)
-                    _started_task_ids.append(task_id)
-
-                # Update the progress bar for this task:
-                progress.update(
-                    task_id=task_id,
-                    completed=task_progress,
-                    total=total,
-                    description=task_description
-                    + (f" - {info}" if (info := update_data.get("info")) else ""),
-                    visible=True,
-                )
-                total_progress += task_progress
-                total_task_lengths += total
-
-            if total_progress or total_task_lengths:
-                progress.update(
-                    task_id=overall_progress_task,
-                    completed=total_progress,
-                    total=total_task_lengths,
-                    visible=True,
-                )
-
-            next_task_id_to_yield, next_future_to_resolve = list(futures.items())[
-                num_yielded_results
-            ]
-            if next_future_to_resolve.done():
-                logger.debug(f"Task {next_task_id_to_yield} is done, yielding result.")
-                yield next_future_to_resolve.result()
-                num_yielded_results += 1
-
-            try:
-                time.sleep(0.01)
-            except KeyboardInterrupt:
-                logger.info(
-                    "Received keyboard interrupt, cancelling tasks that haven't started yet."
-                )
-                for future in futures.values():
-                    future.cancel()
-                break
-
-
-async def async_progress_bar(
-    async_task_fns: list[AsyncTaskFn[OutT_co]],
-    task_descriptions: list[str] | None = None,
-    overall_progress_task_description: str = "[green]All jobs progress:",
-) -> Iterable[OutT_co]:
-    """Like parallel_progress_bar, but where each task is a coroutine."""
-    if task_descriptions is None:
         task_descriptions = [f"Task {i}" for i in range(len(async_task_fns))]

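With the synchronous `parallel_progress_bar` removed, callers drive the surviving coroutine-based API from an event loop. A hedged usage sketch, assuming `async_progress_bar` is consumed with `async for` (as its docstring example does) and reusing the hypothetical `example_task` sketched above:

```python
import asyncio

from milatools.utils.parallel_progress import async_progress_bar


async def main() -> None:
    # Three concurrent tasks, one progress bar each plus the overall bar.
    async for result in async_progress_bar(
        [example_task, example_task, example_task],
        task_descriptions=["first", "second", "third"],
    ):
        print(result)


if __name__ == "__main__":
    asyncio.run(main())
```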
80 changes: 1 addition & 79 deletions tests/utils/test_parallel_progress.py
@@ -13,12 +13,9 @@
 from milatools.cli.utils import removesuffix
 from milatools.utils.parallel_progress import (
     AsyncTaskFn,
-    DictProxy,
     ProgressDict,
-    TaskFn,
     TaskID,
     async_progress_bar,
-    parallel_progress_bar,
 )
 
 from ..cli.common import xfails_on_windows
@@ -28,33 +25,8 @@
 OutT = TypeVar("OutT")
 
 
-def _task_fn(
-    task_progress_dict: DictProxy[TaskID, ProgressDict],
-    task_id: TaskID,
-    task_length: int,
-    result: OutT,
-) -> OutT:
-    task_progress_dict[task_id] = {
-        "progress": 0,
-        "total": task_length,
-        "info": "Starting task.",
-    }
-
-    for n in range(task_length):
-        time.sleep(1.0)  # sleep for a bit to simulate work
-        logger.debug(f"Task {task_id} is {n+1}/{task_length} done.")
-        task_progress_dict[task_id] = {"progress": n + 1, "total": task_length}
-
-    task_progress_dict[task_id] = {
-        "progress": task_length,
-        "total": task_length,
-        "info": "Done.",
-    }
-    return result
-
-
 async def _async_task_fn(
-    task_progress_dict: DictProxy[TaskID, ProgressDict],
+    task_progress_dict: dict[TaskID, ProgressDict],
     task_id: TaskID,
     task_length: int,
     result: OutT,
@@ -78,56 +50,6 @@ async def _async_task_fn(
     return result
 
 
-@xfails_on_windows(
-    raises=AssertionError,
-    reason="Output is weird on windows? something to do with linebreaks perhaps.",
-    strict=True,
-)
-def test_parallel_progress_bar(file_regression: FileRegressionFixture):
-    num_tasks = 4
-    task_length = 5
-    task_lengths = [task_length for _ in range(num_tasks)]
-    task_results = [i for i in range(num_tasks)]
-
-    task_fns: list[TaskFn[int]] = [
-        # pylance doesn't see this as `Partial[int]` because it doesn't "save" the rest
-        # of the signature. Ignoring the type error here.
-        functools.partial(_task_fn, task_length=task_length, result=result)  # type: ignore
-        for task_length, result in zip(task_lengths, task_results)
-    ]
-
-    start_time = time.time()
-
-    console.begin_capture()
-
-    time_to_results: list[float] = []
-    results: list[int] = []
-    for result in parallel_progress_bar(task_fns, n_workers=num_tasks):
-        results.append(result)
-        time_to_result = time.time() - start_time
-        time_to_results.append(time_to_result)
-
-    assert results == task_results
-
-    all_output = console.end_capture()
-
-    # Remove the elapsed column since its values can vary a little bit between runs.
-    all_output_without_elapsed = "\n".join(
-        removesuffix(line, last_part).rstrip()
-        if (parts := line.split()) and (last_part := parts[-1]).count(":") == 2
-        else line
-        for line in all_output.splitlines()
-    )
-
-    file_regression.check(all_output_without_elapsed, encoding="utf-8")
-
-    total_time_seconds = time.time() - start_time
-
-    # All tasks sleep for `task_length` seconds, so the total time should still be
-    # roughly `task_length` seconds.
-    assert total_time_seconds < 2 * task_length
-
-
 @xfails_on_windows(
     raises=AssertionError,
     reason="Output is weird on windows? something to do with linebreaks perhaps.",
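One detail worth preserving from the deleted test is how it normalized the captured rich output before the regression check: any line whose last whitespace-separated token contains two colons (an `H:MM:SS` elapsed time, which varies from run to run) has that token stripped. A self-contained sketch of the same idea using only the stdlib; the helper name is hypothetical (the deleted code used milatools' `removesuffix` inside a generator expression instead):

```python
def strip_trailing_elapsed(output: str) -> str:
    """Drop a trailing H:MM:SS token from each line, if present."""
    cleaned: list[str] = []
    for line in output.splitlines():
        parts = line.split()
        if parts and parts[-1].count(":") == 2:
            # Cut the line at the final occurrence of the timestamp token.
            line = line[: line.rfind(parts[-1])].rstrip()
        cleaned.append(line)
    return "\n".join(cleaned)


assert strip_trailing_elapsed("Task 0 ━━━ 5/5 0:00:05") == "Task 0 ━━━ 5/5"
```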

The third changed file was deleted.
