diff --git a/sharp/base.py b/sharp/base.py index c1dd13c..3b577a1 100644 --- a/sharp/base.py +++ b/sharp/base.py @@ -25,7 +25,7 @@ class ShaRP(BaseEstimator): ---------- estimator : ML classifier - qoi : Quantity of interest + qoi : Quantity of interest, default: "rank" measure : measure used to estimate feature contributions (unary, set, banzhaf, etc.) @@ -85,7 +85,13 @@ def fit(self, X, y=None): self._rng = check_random_state(self.random_state) - if isinstance(self.qoi, str): + if self.qoi is None: + self.qoi_ = check_qoi( + "rank", + target_function=self.target_function, + X=X_, + ) + elif isinstance(self.qoi, str): self.qoi_ = check_qoi( self.qoi, target_function=self.target_function, diff --git a/sharp/tests/test_base.py b/sharp/tests/test_base.py new file mode 100644 index 0000000..004b9e9 --- /dev/null +++ b/sharp/tests/test_base.py @@ -0,0 +1,16 @@ +""" +Tests code in `base.py`. +""" + +import numpy as np +from sharp import ShaRP + + +def test_default_qoi(): + """ + Reproduces issue #44: Defining ShaRP without an explicit QoI raises an AttributeError + """ + _X = np.random.random((100, 3)) + sharp = ShaRP(target_function=lambda x: x.sum(axis=1)) + sharp.fit(_X) + sharp.all(_X[:5])