Skip to content

Commit

Permalink
More tutorials, bugfixing (#30)
Browse files Browse the repository at this point in the history
* cleanup

* bump version

* cleanup

* cleanup

* cleanup

* improvements

* debug

* bugfixing

* debug

* debug
  • Loading branch information
bcebere authored Dec 14, 2022
1 parent 74ce6e7 commit 55bc11d
Show file tree
Hide file tree
Showing 25 changed files with 298 additions and 28 deletions.
11 changes: 11 additions & 0 deletions src/hyperimpute/plugins/core/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class Plugin(Serializable, metaclass=ABCMeta):
def __init__(self) -> None:
super().__init__()

self.drop_consts = []

@staticmethod
@abstractmethod
def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]:
Expand Down Expand Up @@ -119,6 +121,13 @@ def fit_predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFram

def fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> Any:
X = cast.to_dataframe(X)

for col in X.columns:
if len(X.loc[X[col].notna(), col].unique()) <= 1:
self.drop_consts.append(col)

X = X.drop(columns=self.drop_consts)
self.columns = X.columns
return self._fit(X, *args, **kwargs)

@abstractmethod
Expand All @@ -127,6 +136,7 @@ def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "Plugin":

def transform(self, X: pd.DataFrame) -> pd.DataFrame:
X = cast.to_dataframe(X)
X = X.drop(columns=self.drop_consts)
return pd.DataFrame(self._transform(X))

@abstractmethod
Expand All @@ -135,6 +145,7 @@ def _transform(self, X: pd.DataFrame) -> pd.DataFrame:

def predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame:
X = cast.to_dataframe(X)
X = X.drop(columns=self.drop_consts)
return pd.DataFrame(self._predict(X, *args, *kwargs))

@abstractmethod
Expand Down
5 changes: 4 additions & 1 deletion src/hyperimpute/plugins/imputers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# hyperimpute absolute
import hyperimpute.plugins.core.base_plugin as plugin
from hyperimpute.utils.distributions import enable_reproducible_results


class ImputerPlugin(_BaseImputer, plugin.Plugin):
Expand All @@ -24,8 +25,10 @@ class ImputerPlugin(_BaseImputer, plugin.Plugin):
"""

def __init__(self, random_state: int = 0) -> None:
super().__init__()
_BaseImputer.__init__(self)
plugin.Plugin.__init__(self)

enable_reproducible_results(random_state)
self.random_state = random_state

@staticmethod
Expand Down
11 changes: 9 additions & 2 deletions src/hyperimpute/plugins/prediction/classifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import hyperimpute.plugins.core.base_plugin as plugin
import hyperimpute.plugins.prediction.base as prediction_base
import hyperimpute.plugins.utils.cast as cast
from hyperimpute.utils.distributions import enable_reproducible_results
from hyperimpute.utils.tester import Eval


Expand All @@ -26,17 +27,23 @@ class ClassifierPlugin(
If any method implementation is missing, the class constructor will fail.
"""

def __init__(self, **kwargs: Any) -> None:
def __init__(self, random_state: int = 0, **kwargs: Any) -> None:
self.args = kwargs
self.random_state = random_state

super().__init__()
enable_reproducible_results(self.random_state)

ClassifierMixin.__init__(self)
BaseEstimator.__init__(self)
prediction_base.PredictionPlugin.__init__(self)

@staticmethod
def subtype() -> str:
return "classifier"

def fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> plugin.Plugin:
X = cast.to_dataframe(X)
enable_reproducible_results(self.random_state)

if len(args) == 0:
raise RuntimeError("Please provide the training labels as well")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
random_strength: float = 1,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
super().__init__(random_state=random_state, **kwargs)
if model is not None:
self.model = model
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
hyperparam_search_iterations: Optional[int] = None,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
super().__init__(random_state=random_state, **kwargs)
if hyperparam_search_iterations:
n_estimators = int(hyperparam_search_iterations)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
model: Any = None,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
super().__init__(random_state=random_state, **kwargs)
if model is not None:
self.model = model
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
random_state: int = 0,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
super().__init__(random_state=random_state, **kwargs)
if model is not None:
self.model = model
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
hyperparam_search_iterations: Optional[int] = None,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
super().__init__(random_state=random_state, **kwargs)
if model is not None:
self.model = model
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __init__(
hyperparam_search_iterations: Optional[int] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
super().__init__(random_state=random_state, **kwargs)

enable_reproducible_results(random_state)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
hyperparam_search_iterations: Optional[int] = None,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
super().__init__(random_state=random_state, **kwargs)
if hyperparam_search_iterations:
n_estimators = int(hyperparam_search_iterations)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
random_state: int = 0,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
super().__init__(random_state=random_state, **kwargs)

if hyperparam_search_iterations:
max_iter = int(hyperparam_search_iterations) * 100
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
hyperparam_search_iterations: Optional[int] = None,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
super().__init__(random_state=random_state, **kwargs)
if hyperparam_search_iterations:
n_estimators = int(hyperparam_search_iterations)

Expand Down
4 changes: 3 additions & 1 deletion src/hyperimpute/plugins/prediction/regression/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def __init__(
self,
**kwargs: Any,
) -> None:
super().__init__()
RegressorMixin.__init__(self)
BaseEstimator.__init__(self)
prediction_base.PredictionPlugin.__init__(self)

self.args = kwargs

Expand Down
2 changes: 1 addition & 1 deletion src/hyperimpute/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = "0.1.10"
__version__ = "0.1.11"
MAJOR_VERSION = "0.1"
2 changes: 1 addition & 1 deletion tests/imputers/test_em.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 10
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_gain.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 10
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
9 changes: 4 additions & 5 deletions tests/imputers/test_hyperimpute.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def test_hyperimpute_plugin_fit_transform(test_plugin: ImputerPlugin) -> None:
[[1, 1, 1, 1], [np.nan, np.nan, np.nan, np.nan], [3, 3, 9, 9], [2, 2, 2, 2]]
)
)

assert not np.all(np.isnan(res))
assert not np.any(np.isnan(res))


@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_serde()])
Expand All @@ -92,7 +91,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 20
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down Expand Up @@ -127,7 +126,7 @@ def test_compare_optimizers(

np.random.seed(0)

n = 20
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down Expand Up @@ -166,7 +165,7 @@ def test_imputation_order(

np.random.seed(0)

n = 20
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 10
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_mice.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 100
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_missforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 20
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_miwae.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 100
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 10
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_sklearn_ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 10
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
2 changes: 1 addition & 1 deletion tests/imputers/test_softimpute.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_compare_methods_perf(
) -> None:
np.random.seed(0)

n = 20
n = 50
p = 4

mean = np.repeat(0, p)
Expand Down
Loading

0 comments on commit 55bc11d

Please sign in to comment.