diff --git a/mmengine/utils/progressbar.py b/mmengine/utils/progressbar.py index f45f92d4b9..36172f04dd 100644 --- a/mmengine/utils/progressbar.py +++ b/mmengine/utils/progressbar.py @@ -3,7 +3,7 @@ from collections.abc import Iterable from multiprocessing import Pool from shutil import get_terminal_size -from typing import Callable, Sequence, Union +from typing import Callable, Sequence from .timer import Timer @@ -88,7 +88,7 @@ def update(self, num_tasks: int = 1): def track_progress(func: Callable, - tasks: Union[tuple, Sequence], + tasks: Sequence, bar_width: int = 50, file=sys.stdout, **kwargs): @@ -98,8 +98,10 @@ def track_progress(func: Callable, Args: func (callable): The function to be applied to each task. - tasks (tuple or Size): A tuple contains two elements - (tasks, total num) or a sequence object. + tasks (Sequence): If tasks is a tuple, it must contain two elements, + the first being the tasks to be completed and the other being the + number of tasks. If it is not a tuple, it represents the tasks to + be completed. bar_width (int): Width of progress bar. Returns: @@ -138,7 +140,7 @@ def init_pool(process_num, initializer=None, initargs=None): def track_parallel_progress(func: Callable, - tasks: Union[tuple, Sequence], + tasks: Sequence, nproc: int, initializer: Callable = None, initargs: tuple = None, @@ -154,8 +156,10 @@ def track_parallel_progress(func: Callable, Args: func (callable): The function to be applied to each task. - tasks (tuple or Size): A tuple contains two elements - (tasks, total num) or a sequence object. + tasks (Sequence): If tasks is a tuple, it must contain two elements, + the first being the tasks to be completed and the other being the + number of tasks. If it is not a tuple, it represents the tasks to + be completed. nproc (int): Process (worker) number. initializer (None or callable): Refer to :class:`multiprocessing.Pool` for details. @@ -208,17 +212,17 @@ def track_parallel_progress(func: Callable, return results -def track_iter_progress(tasks: Union[tuple, Sequence], - bar_width: int = 50, - file=sys.stdout): +def track_iter_progress(tasks: Sequence, bar_width: int = 50, file=sys.stdout): """Track the progress of tasks iteration or enumeration with a progress bar. Tasks are yielded with a simple for-loop. Args: - tasks (tuple or Size): A tuple contains two elements - (tasks, total num) or a sequence object. + tasks (Sequence): If tasks is a tuple, it must contain two elements, + the first being the tasks to be completed and the other being the + number of tasks. If it is not a tuple, it represents the tasks to + be completed. bar_width (int): Width of progress bar. Yields: