diff --git a/audeer/core/tqdm.py b/audeer/core/tqdm.py index dd3eff0..c32dca2 100644 --- a/audeer/core/tqdm.py +++ b/audeer/core/tqdm.py @@ -1,4 +1,6 @@ -from typing import Sequence +import threading +import time +import typing from tqdm import tqdm @@ -42,11 +44,12 @@ def format_display_message(text: str, pbar: bool = False) -> str: def progress_bar( - iterable: Sequence = None, + iterable: typing.Sequence = None, *, total: int = None, desc: str = None, disable: bool = False, + maximum_refresh_time: float = None, ) -> tqdm: r"""Progress bar with optional text on the right. @@ -80,6 +83,11 @@ def progress_bar( total: total number of iterations desc: text shown on the right of the progress bar disable: don't show the display bar + maximum_refresh_time: refresh the progress bar + at least every ``maximum_refresh_time`` seconds, + using another thread. + If ``None``, + no refreshing is enforced Returns: progress bar object @@ -87,8 +95,9 @@ def progress_bar( """ if desc is None: desc = "" - return tqdm( + return tqdm_wrapper( iterable=iterable, + maximum_refresh_time=maximum_refresh_time, ncols=config.TQDM_COLUMNS, bar_format=config.TQDM_FORMAT, total=total, @@ -96,3 +105,45 @@ def progress_bar( desc=format_display_message(desc, pbar=True), leave=config.TQDM_LEAVE, ) + + +def tqdm_wrapper( + iterable: typing.Sequence, + maximum_refresh_time: float, + *args, + **kwargs, +) -> tqdm: + r"""Tqdm progress bar wrapper to enforce update once a second. + + When using tqdm with large time durations + between single steps of the iteration, + it will not automatically update the elapsed time, + but needs to be forced, + see https://github.com/tqdm/tqdm/issues/861#issuecomment-2197893883. + + Args: + iterable: sequence to iterate through + maximum_refresh_time: refresh the progress bar + at least every ``maximum_refresh_time`` seconds, + using another thread. + If ``None``, + no refreshing is enforced + args: arguments passed on to ``tqdm`` + kwargs: keyword arguments passed on to ``tqdm`` + + Returns: + progress bar object + + """ + pbar = tqdm(iterable, *args, **kwargs) + + def refresh(): + while not pbar.disable: + time.sleep(maximum_refresh_time) + pbar.refresh() + + if maximum_refresh_time is not None: + thread = threading.Thread(target=refresh, daemon=True) + thread.start() + + return pbar diff --git a/tests/test_tqdm.py b/tests/test_tqdm.py index 9d992a0..cbcd68e 100644 --- a/tests/test_tqdm.py +++ b/tests/test_tqdm.py @@ -37,3 +37,15 @@ def test_progress_bar(): pbar = audeer.progress_bar([0.1]) for step in pbar: time.sleep(step) + + +def test_progress_bar_update(): + r"""Ensure progress bar is refreshed. + + If the progress bar has to wait for a long time + until it would get updated, + we enforce an update by a given time. + + """ + for _ in audeer.progress_bar(range(2), maximum_refresh_time=0.01): + time.sleep(0.05)