-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
43 lines (33 loc) · 1.24 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from typing import Any
import joblib
from tqdm.auto import tqdm
class EasyDict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    ``d.key`` mirrors ``d["key"]``, ``d.key = v`` mirrors ``d["key"] = v`` and
    ``del d.key`` mirrors ``del d["key"]``. Nothing is stored in the instance
    ``__dict__``; every attribute assignment goes into the mapping itself.
    """

    def __getattr__(self, name: str) -> Any:
        """Return the value stored under ``name``.

        Python only calls this after normal lookup fails, so genuine dict
        methods such as ``keys`` or ``items`` still take precedence over keys
        of the same name.

        Raises:
            AttributeError: if ``name`` is not a key. This is translated from
                ``KeyError`` so that ``hasattr`` and ``getattr(obj, name,
                default)`` behave as usual.
        """
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        """Store ``value`` under the key ``name``, not as an instance attribute."""
        self[name] = value

    def __delattr__(self, name: str) -> None:
        """Remove the key ``name``.

        Note: unlike attribute reads, a missing key here surfaces as
        ``KeyError``, not ``AttributeError``.
        """
        del self[name]
class ProgressParallel(joblib.Parallel):
    """``joblib.Parallel`` that reports task progress through a ``tqdm`` bar.

    Extra constructor arguments; everything else is forwarded unchanged to
    ``joblib.Parallel``:

    use_tqdm: show the bar when truthy. A falsy value still creates the bar,
        but with ``disable=True``.
    total: fixed bar length. When ``None``, the length is re-synced to the
        number of tasks dispatched so far on every progress update.
    leave: forwarded to ``tqdm(leave=...)``.

    NOTE(review): these three parameters come first positionally, so a
    positional argument meant for ``joblib.Parallel`` (e.g. ``n_jobs``) is
    bound to ``use_tqdm`` instead. Pass joblib options by keyword.
    """

    def __init__(self, use_tqdm=True, total=None, leave=True, *args, **kwargs):
        # Only the bar settings are kept here; the bar itself is created
        # fresh on each __call__.
        self._use_tqdm, self._total, self._leave = use_tqdm, total, leave
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Run the parallel job with a progress bar open for its duration.

        The bar is exposed as ``self._pbar`` so ``print_progress`` can update
        it. It is closed when this method returns.

        NOTE(review): if joblib hands back a lazy iterator here rather than a
        finished list (newer joblib's ``return_as="generator"``), the bar is
        already closed while results are consumed. Confirm before using that mode.
        """
        bar = tqdm(disable=not self._use_tqdm, total=self._total, leave=self._leave)
        with bar as self._pbar:
            return super().__call__(*args, **kwargs)

    def print_progress(self):
        """Sync the tqdm bar with joblib's dispatched/completed task counters.

        This overrides ``joblib.Parallel.print_progress`` without calling the
        base implementation. Presumably joblib invokes this hook as tasks
        finish; verify against the installed joblib version. It relies on
        ``self._pbar``, which only exists once ``__call__`` has started.
        """
        pbar = self._pbar
        if self._total is None:
            # No fixed length was requested: let the bar grow as work is dispatched.
            pbar.total = self.n_dispatched_tasks
        pbar.n = self.n_completed_tasks
        pbar.refresh()