Skip to content

Commit

Permalink
update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouzaida committed Sep 4, 2023
1 parent 0ab2386 commit eeb8930
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions mmengine/utils/progressbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Iterable
from multiprocessing import Pool
from shutil import get_terminal_size
from typing import Callable, Sequence, Union
from typing import Callable, Sequence

from .timer import Timer

Expand Down Expand Up @@ -88,7 +88,7 @@ def update(self, num_tasks: int = 1):


def track_progress(func: Callable,
tasks: Union[tuple, Sequence],
tasks: Sequence,
bar_width: int = 50,
file=sys.stdout,
**kwargs):
Expand All @@ -98,8 +98,10 @@ def track_progress(func: Callable,
Args:
func (callable): The function to be applied to each task.
tasks (tuple or Size): A tuple contains two elements
(tasks, total num) or a sequence object.
tasks (Sequence): If tasks is a tuple, it must contain two elements,
the first being the tasks to be completed and the other being the
number of tasks. If it is not a tuple, it represents the tasks to
be completed.
bar_width (int): Width of progress bar.
Returns:
Expand Down Expand Up @@ -138,7 +140,7 @@ def init_pool(process_num, initializer=None, initargs=None):


def track_parallel_progress(func: Callable,
tasks: Union[tuple, Sequence],
tasks: Sequence,
nproc: int,
initializer: Callable = None,
initargs: tuple = None,
Expand All @@ -154,8 +156,10 @@ def track_parallel_progress(func: Callable,
Args:
func (callable): The function to be applied to each task.
tasks (tuple or Size): A tuple contains two elements
(tasks, total num) or a sequence object.
tasks (Sequence): If tasks is a tuple, it must contain two elements,
the first being the tasks to be completed and the other being the
number of tasks. If it is not a tuple, it represents the tasks to
be completed.
nproc (int): Process (worker) number.
initializer (None or callable): Refer to :class:`multiprocessing.Pool`
for details.
Expand Down Expand Up @@ -208,17 +212,17 @@ def track_parallel_progress(func: Callable,
return results


def track_iter_progress(tasks: Union[tuple, Sequence],
bar_width: int = 50,
file=sys.stdout):
def track_iter_progress(tasks: Sequence, bar_width: int = 50, file=sys.stdout):
"""Track the progress of tasks iteration or enumeration with a progress
bar.
Tasks are yielded with a simple for-loop.
Args:
tasks (tuple or Size): A tuple contains two elements
(tasks, total num) or a sequence object.
tasks (Sequence): If tasks is a tuple, it must contain two elements,
the first being the tasks to be completed and the other being the
number of tasks. If it is not a tuple, it represents the tasks to
be completed.
bar_width (int): Width of progress bar.
Yields:
Expand Down

0 comments on commit eeb8930

Please sign in to comment.