-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from DataResponsibly/parallelize
Add parallelization to ShaRP
- Loading branch information
Showing
4 changed files
with
135 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
""" | ||
Taken from the ``ml-research`` package. | ||
Author: Joao Fonseca | ||
""" | ||
|
||
import os | ||
import contextlib | ||
from joblib import Parallel, delayed | ||
import types | ||
|
||
|
||
def _optional_import(module: str) -> types.ModuleType: | ||
""" | ||
Import an optional dependency. | ||
Parameters | ||
---------- | ||
module : str | ||
The identifier for the backend. Either an entrypoint item registered | ||
with importlib.metadata, "matplotlib", or a module name. | ||
Returns | ||
------- | ||
types.ModuleType | ||
The imported backend. | ||
""" | ||
# This function was adapted from the _load_backend function from the pandas.plotting | ||
# source code. | ||
import importlib | ||
|
||
# Attempt an import of an optional dependency here and raise an ImportError if | ||
# needed. | ||
try: | ||
module_ = importlib.import_module(module) | ||
except ImportError: | ||
mod = module.split(".")[0] | ||
raise ImportError(f"{mod} is required to use this functionality.") from None | ||
|
||
return module_ | ||
|
||
|
||
def _get_n_jobs(n_jobs): | ||
"""Assign number of jobs to be assigned in parallel.""" | ||
max_jobs = os.cpu_count() | ||
n_jobs = 1 if n_jobs is None else int(n_jobs) | ||
if n_jobs > max_jobs: | ||
raise RuntimeError("Cannot assign more jobs than the number of CPUs.") | ||
elif n_jobs == -1: | ||
return max_jobs | ||
else: | ||
return n_jobs | ||
|
||
|
||
@contextlib.contextmanager | ||
def _tqdm_joblib(tqdm_object): | ||
""" | ||
Context manager to patch joblib to report into tqdm progress bar given as argument. | ||
""" | ||
|
||
def tqdm_print_progress(self): | ||
if self.n_completed_tasks > tqdm_object.n: | ||
n_completed = self.n_completed_tasks - tqdm_object.n | ||
tqdm_object.update(n=n_completed) | ||
|
||
original_print_progress = Parallel.print_progress | ||
Parallel.print_progress = tqdm_print_progress | ||
|
||
try: | ||
yield tqdm_object | ||
finally: | ||
Parallel.print_progress = original_print_progress | ||
tqdm_object.close() | ||
|
||
|
||
def parallel_loop( | ||
function, iterable, n_jobs=None, progress_bar=False, description=None | ||
): | ||
""" | ||
Parallelize a loop and optionally add a progress bar. | ||
.. warning:: | ||
The progress bar tracks job starts, not completions. | ||
Parameters | ||
---------- | ||
function : function | ||
The function to which the elements in the iterable will passed to. Must have a | ||
single parameter. | ||
iterable : iterable | ||
Object to be looped over. | ||
n_jobs : int, default=None | ||
Number of jobs to run in parallel. None means 1 unless in a | ||
joblib.parallel_backend context. -1 means using all processors. | ||
Returns | ||
------- | ||
output : list | ||
The list with the results produced using ``function`` across ``iterable``. | ||
""" | ||
n_jobs = _get_n_jobs(n_jobs) | ||
|
||
if progress_bar: | ||
tqdm = _optional_import("tqdm.auto").tqdm | ||
|
||
with _tqdm_joblib(tqdm(desc=description, total=len(iterable))) as progress_bar: | ||
return Parallel(n_jobs=n_jobs)(delayed(function)(i) for i in iterable) | ||
|
||
else: | ||
return Parallel(n_jobs=n_jobs)(delayed(function)(i) for i in iterable) |