Merge pull request #20 from DataResponsibly/parallelize

Add parallelization to ShaRP

joaopfonseca authored Feb 20, 2024
2 parents 3fdcb28 + d9e931f commit dcd3063

Showing 4 changed files with 135 additions and 10 deletions.
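
In practice, the change surfaces as two new ``ShaRP`` constructor parameters,
``n_jobs`` and ``verbose``. A minimal usage sketch (the ``qoi``, ``measure``,
and ``target_function`` values are hypothetical placeholders, and a
scikit-learn-style ``fit``/``all`` interface is assumed):

import numpy as np
from sharp import ShaRP  # assumes ShaRP is exported at the package top level

X = np.random.default_rng(0).normal(size=(50, 4))

xai = ShaRP(
    qoi="rank",                               # hypothetical QOI identifier
    target_function=lambda X: X.sum(axis=1),  # hypothetical scoring function
    measure="shapley",                        # hypothetical measure identifier
    sample_size=10,
    random_state=42,
    n_jobs=-1,   # new in this PR: parallelize across all available CPUs
    verbose=1,   # new in this PR: display a tqdm progress bar
)
xai.fit(X)               # assumed scikit-learn-style fit
influences = xai.all(X)  # per-row feature influences, now computed in parallel
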
1 change: 1 addition & 0 deletions README.md
@@ -15,6 +15,7 @@ project. ``ShaRP`` requires:
 - numpy (>= 1.20.0)
 - pandas (>= 1.3.5)
 - scikit-learn (>= 1.2.0)
+- ml-research (>= 0.4.2)
 
 Some functions require Matplotlib (>= 2.2.3) for plotting.

3 changes: 2 additions & 1 deletion sharp/_min_dependencies.py
@@ -5,15 +5,16 @@
 NUMPY_MIN_VERSION = "1.20.0"
 PANDAS_MIN_VERSION = "1.3.5"
 SKLEARN_MIN_VERSION = "1.2.0"
+TQDM_MIN_VERSION = "4.46.0"
 MATPLOTLIB_MIN_VERSION = "2.2.3"
 
 # The values are (version_spec, comma separated tags)
 dependent_packages = {
     "numpy": (NUMPY_MIN_VERSION, "install"),
     "pandas": (PANDAS_MIN_VERSION, "install"),
     "scikit-learn": (SKLEARN_MIN_VERSION, "install"),
+    "tqdm": (TQDM_MIN_VERSION, "install"),
     "matplotlib": (MATPLOTLIB_MIN_VERSION, "optional, docs"),
     # "ml-research": ("0.4.2", "optional"),
     "pytest-cov": ("3.0.0", "tests"),
     "flake8": ("3.8.2", "tests"),
     "black": ("22.3", "tests"),

29 changes: 20 additions & 9 deletions sharp/base.py
@@ -18,6 +18,7 @@
 import numpy as np
 from sklearn.base import BaseEstimator
 from sklearn.utils import check_random_state
+from .utils._parallelize import parallel_loop
 from .utils import check_feature_names, check_inputs, check_measure, check_qoi
 from .visualization._visualization import ShaRPViz

Expand Down Expand Up @@ -71,6 +72,8 @@ def __init__(
sample_size=None,
replace=False,
random_state=None,
n_jobs=1,
verbose=0,
**kwargs
):
self.qoi = qoi
@@ -79,6 +82,8 @@
         self.sample_size = sample_size
         self.replace = replace
         self.random_state = random_state
+        self.n_jobs = n_jobs
+        self.verbose = verbose
         self.plot = ShaRPViz(self)
         self._X = kwargs["X"] if "X" in kwargs.keys() else None
         self._y = kwargs["y"] if "y" in kwargs.keys() else None
@@ -125,9 +130,9 @@ def individual(self, sample, X=None, y=None, **kwargs):
         else:
             sample_size = X_.shape[0]
 
-        influences = []
-        for col_idx in range(len(self.feature_names_)):
-            cell_influence = self.measure_(
+        verbosity = kwargs["verbose"] if "verbose" in kwargs.keys() else self.verbose
+        influences = parallel_loop(
+            lambda col_idx: self.measure_(
                 row=sample,
                 col_idx=col_idx,
                 set_cols_idx=set_cols_idx,
@@ -136,8 +141,12 @@
                 sample_size=sample_size,
                 replace=self.replace,
                 rng=self._rng,
-            )
-            influences.append(cell_influence)
+            ),
+            range(len(self.feature_names_)),
+            n_jobs=self.n_jobs,
+            progress_bar=verbosity,
+        )
 
         return influences
 
     def feature(self, feature, X=None, y=None, **kwargs):
@@ -185,10 +194,12 @@ def all(self, X=None, y=None, **kwargs):
         """
         X_, y_ = check_inputs(X, y)
 
-        influences = []
-        for sample_idx in range(X_.shape[0]):
-            individual_influence = self.individual(sample_idx, X_, **kwargs)
-            influences.append(individual_influence)
+        influences = parallel_loop(
+            lambda sample_idx: self.individual(sample_idx, X_, verbose=False, **kwargs),
+            range(X_.shape[0]),
+            n_jobs=self.n_jobs,
+            progress_bar=self.verbose,
+        )
 
         return np.array(influences)

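Both hunks apply the same refactoring: a sequential append-to-a-list loop
becomes a ``parallel_loop`` call over an index range. Stripped of the ShaRP
specifics, the transformation looks like this (``influence`` is an illustrative
stand-in, not a function from the codebase):

from joblib import Parallel, delayed

def influence(idx):
    return idx * idx  # stand-in for self.measure_(...) or self.individual(...)

indices = range(4)

# Before: sequential accumulation.
sequential = []
for idx in indices:
    sequential.append(influence(idx))

# After: the joblib call that parallel_loop(influence, indices, n_jobs=2) makes.
parallel = Parallel(n_jobs=2)(delayed(influence)(i) for i in indices)
assert sequential == parallel
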
112 changes: 112 additions & 0 deletions sharp/utils/_parallelize.py
@@ -0,0 +1,112 @@
"""
Taken from the ``ml-research`` package.
Author: Joao Fonseca
"""

import os
import contextlib
from joblib import Parallel, delayed
import types


def _optional_import(module: str) -> types.ModuleType:
"""
Import an optional dependency.
Parameters
----------
module : str
The identifier for the backend. Either an entrypoint item registered
with importlib.metadata, "matplotlib", or a module name.
Returns
-------
types.ModuleType
The imported backend.
"""
# This function was adapted from the _load_backend function from the pandas.plotting
# source code.
import importlib

# Attempt an import of an optional dependency here and raise an ImportError if
# needed.
try:
module_ = importlib.import_module(module)
except ImportError:
mod = module.split(".")[0]
raise ImportError(f"{mod} is required to use this functionality.") from None

return module_
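
# Illustrative usage of _optional_import (hypothetical example, not part of the
# committed file): tqdm is resolved lazily, only when it is actually needed.
#
#     tqdm = _optional_import("tqdm.auto").tqdm  # succeeds if tqdm is installed
#     _optional_import("no_such_package")
#     # ImportError: no_such_package is required to use this functionality.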


def _get_n_jobs(n_jobs):
    """Resolve the number of jobs to run in parallel."""
    max_jobs = os.cpu_count()
    n_jobs = 1 if n_jobs is None else int(n_jobs)
    if n_jobs > max_jobs:
        raise RuntimeError("Cannot assign more jobs than the number of CPUs.")
    elif n_jobs == -1:
        return max_jobs
    else:
        return n_jobs
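
# Illustrative behaviour (hypothetical example, assuming os.cpu_count() == 8):
#
#     _get_n_jobs(None)  # -> 1 (None falls back to a single job)
#     _get_n_jobs(-1)    # -> 8 (-1 expands to every available CPU)
#     _get_n_jobs(4)     # -> 4 (explicit counts pass through)
#     _get_n_jobs(16)    # RuntimeError: cannot exceed the CPU count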


@contextlib.contextmanager
def _tqdm_joblib(tqdm_object):
    """
    Context manager to patch joblib to report into tqdm progress bar given as argument.
    """

    def tqdm_print_progress(self):
        if self.n_completed_tasks > tqdm_object.n:
            n_completed = self.n_completed_tasks - tqdm_object.n
            tqdm_object.update(n=n_completed)

    original_print_progress = Parallel.print_progress
    Parallel.print_progress = tqdm_print_progress

    try:
        yield tqdm_object
    finally:
        Parallel.print_progress = original_print_progress
        tqdm_object.close()
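
# Illustrative usage (hypothetical example, not part of the committed file):
# while the with-block is active, every completed joblib batch advances the bar.
#
#     from tqdm.auto import tqdm
#     with _tqdm_joblib(tqdm(desc="demo", total=100)):
#         results = Parallel(n_jobs=2)(delayed(abs)(i) for i in range(-100, 0))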


def parallel_loop(
    function, iterable, n_jobs=None, progress_bar=False, description=None
):
    """
    Parallelize a loop and optionally add a progress bar.

    .. warning::
        The progress bar tracks job starts, not completions.

    Parameters
    ----------
    function : function
        The function to which the elements in the iterable will be passed.
        Must have a single parameter.
    iterable : iterable
        Object to be looped over.
    n_jobs : int, default=None
        Number of jobs to run in parallel. None means 1 unless in a
        joblib.parallel_backend context. -1 means using all processors.
    progress_bar : bool, default=False
        Whether to display a ``tqdm`` progress bar while the loop runs.
    description : str, default=None
        Label shown on the progress bar, if one is displayed.

    Returns
    -------
    output : list
        The list with the results produced using ``function`` across ``iterable``.
    """
    n_jobs = _get_n_jobs(n_jobs)

    if progress_bar:
        tqdm = _optional_import("tqdm.auto").tqdm

        with _tqdm_joblib(tqdm(desc=description, total=len(iterable))) as progress_bar:
            return Parallel(n_jobs=n_jobs)(delayed(function)(i) for i in iterable)

    else:
        return Parallel(n_jobs=n_jobs)(delayed(function)(i) for i in iterable)
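
Taken together, the module composes as follows. A minimal, self-contained
sketch (``slow_double`` is a hypothetical toy function; joblib is required and
tqdm must be installed for the progress bar):

from sharp.utils._parallelize import parallel_loop

def slow_double(x):
    return 2 * x

if __name__ == "__main__":  # guard needed for process-based joblib backends
    out = parallel_loop(slow_double, range(8), n_jobs=2, progress_bar=True,
                        description="doubling")
    print(out)  # [0, 2, 4, 6, 8, 10, 12, 14]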
