diff --git a/src/hyperimpute/plugins/core/base_plugin.py b/src/hyperimpute/plugins/core/base_plugin.py index 6d74f94..bb6320d 100644 --- a/src/hyperimpute/plugins/core/base_plugin.py +++ b/src/hyperimpute/plugins/core/base_plugin.py @@ -38,8 +38,6 @@ class Plugin(Serializable, metaclass=ABCMeta): def __init__(self) -> None: super().__init__() - self.drop_consts = [] - @staticmethod @abstractmethod def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]: @@ -122,11 +120,6 @@ def fit_predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFram def fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> Any: X = cast.to_dataframe(X) - for col in X.columns: - if len(X.loc[X[col].notna(), col].unique()) <= 1: - self.drop_consts.append(col) - - X = X.drop(columns=self.drop_consts) self.columns = X.columns return self._fit(X, *args, **kwargs) @@ -136,7 +129,6 @@ def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "Plugin": def transform(self, X: pd.DataFrame) -> pd.DataFrame: X = cast.to_dataframe(X) - X = X.drop(columns=self.drop_consts) return pd.DataFrame(self._transform(X)) @abstractmethod @@ -145,7 +137,6 @@ def _transform(self, X: pd.DataFrame) -> pd.DataFrame: def predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame: X = cast.to_dataframe(X) - X = X.drop(columns=self.drop_consts) return pd.DataFrame(self._predict(X, *args, *kwargs)) @abstractmethod diff --git a/src/hyperimpute/plugins/prediction/classifiers/plugin_lgbm.py b/src/hyperimpute/plugins/prediction/classifiers/plugin_lgbm.py deleted file mode 100644 index 3909acc..0000000 --- a/src/hyperimpute/plugins/prediction/classifiers/plugin_lgbm.py +++ /dev/null @@ -1,127 +0,0 @@ -# stdlib -from typing import Any, List - -# third party -import lightgbm as lgbm -import pandas as pd - -# hyperimpute absolute -import hyperimpute.plugins.core.params as params -import hyperimpute.plugins.prediction.classifiers.base as base -import hyperimpute.utils.serialization as serialization - - -class LightGBMPlugin(base.ClassifierPlugin): - """Classification plugin based on LightGBM. - - Method: - Gradient boosting is a machine learning technique for regression and classification problems, which produces a prediction model in the form of an ensemble of weak prediction models, typically decision trees. When a decision tree is the weak learner, the resulting algorithm is called gradient boosted trees, which usually outperforms random forest. - - Args: - n_estimators: int - The number of boosting stages to perform. Gradient boosting is fairly robust to over-fitting so a large number usually results in better performance. - learning_rate: float - Learning rate shrinks the contribution of each tree by learning_rate. There is a trade-off between learning_rate and n_estimators. - max_depth: int - The maximum depth of the individual regression estimators. - boosting_type: str - ‘gbdt’, traditional Gradient Boosting Decision Tree. ‘dart’, Dropouts meet Multiple Additive Regression Trees. ‘goss’, Gradient-based One-Side Sampling. ‘rf’, Random Forest. - objective:str - Specify the learning task and the corresponding learning objective or a custom objective function to be used. - reg_lambda:float - L2 regularization term on weights. - reg_alpha:float - L1 regularization term on weights. - colsample_bytree:float - Subsample ratio of columns when constructing each tree. - subsample:float - Subsample ratio of the training instance. - num_leaves:int - Maximum tree leaves for base learners. - min_child_samples:int - Minimum sum of instance weight (hessian) needed in a child (leaf). - - Example: - >>> from hyperimpute.plugins.prediction import Predictions - >>> plugin = Predictions(category="classifiers").get("lgbm") - >>> from sklearn.datasets import load_iris - >>> X, y = load_iris(return_X_y=True) - >>> plugin.fit_predict(X, y) # returns the probabilities for each class - """ - - def __init__( - self, - n_estimators: int = 100, - boosting_type: str = "gbdt", - learning_rate: float = 1e-2, - max_depth: int = 6, - reg_lambda: float = 1e-3, - reg_alpha: float = 1e-3, - colsample_bytree: float = 0.1, - subsample: float = 0.1, - num_leaves: int = 31, - min_child_samples: int = 1, - model: Any = None, - random_state: int = 0, - **kwargs: Any - ) -> None: - super().__init__(random_state=random_state, **kwargs) - if model is not None: - self.model = model - return - - self.model = lgbm.LGBMClassifier( - n_estimators=n_estimators, - boosting_type=boosting_type, - learning_rate=learning_rate, - max_depth=max_depth, - reg_lambda=reg_lambda, - reg_alpha=reg_alpha, - colsample_bytree=colsample_bytree, - subsample=subsample, - num_leaves=num_leaves, - min_child_samples=min_child_samples, - random_state=random_state, - ) - - @staticmethod - def name() -> str: - return "lgbm" - - @staticmethod - def hyperparameter_space(*args: Any, **kwargs: Any) -> List[params.Params]: - return [ - params.Integer("n_estimators", 5, 100), - params.Float("reg_lambda", 1e-3, 1e3), - params.Float("reg_alpha", 1e-3, 1e3), - params.Float("colsample_bytree", 0.1, 1.0), - params.Float("subsample", 0.1, 1.0), - params.Integer("num_leaves", 31, 256), - params.Integer("min_child_samples", 1, 500), - params.Categorical("learning_rate", [1e-4, 1e-3, 1e-2, 2e-4]), - params.Integer("max_depth", 1, 6), - ] - - def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "LightGBMPlugin": - self.model.fit(X, *args, **kwargs) - return self - - def _predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame: - return self.model.predict(X, *args, **kwargs) - - def _predict_proba( - self, X: pd.DataFrame, *args: Any, **kwargs: Any - ) -> pd.DataFrame: - return self.model.predict_proba(X, *args, **kwargs) - - def save(self) -> bytes: - return serialization.save(self.model) - - @classmethod - def load(cls, buff: bytes) -> "LightGBMPlugin": - model = serialization.load(buff) - - return cls(model=model) - - -plugin = LightGBMPlugin diff --git a/src/hyperimpute/plugins/prediction/classifiers/plugin_random_forest.py b/src/hyperimpute/plugins/prediction/classifiers/plugin_random_forest.py index 12d5e59..392d89a 100644 --- a/src/hyperimpute/plugins/prediction/classifiers/plugin_random_forest.py +++ b/src/hyperimpute/plugins/prediction/classifiers/plugin_random_forest.py @@ -41,7 +41,7 @@ class RandomForestPlugin(base.ClassifierPlugin): """ criterions = ["gini", "entropy"] - features = ["auto", "sqrt", "log2"] + features = ["sqrt", "log2", None] def __init__( self, @@ -97,11 +97,13 @@ def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "RandomForestPlugi return self def _predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame: + X = np.asarray(X) return self.model.predict(X, *args, **kwargs) def _predict_proba( self, X: pd.DataFrame, *args: Any, **kwargs: Any ) -> pd.DataFrame: + X = np.asarray(X) return self.model.predict_proba(X, *args, **kwargs) diff --git a/src/hyperimpute/plugins/prediction/regression/plugin_lgbm_regressor.py b/src/hyperimpute/plugins/prediction/regression/plugin_lgbm_regressor.py deleted file mode 100644 index 4bb1fa5..0000000 --- a/src/hyperimpute/plugins/prediction/regression/plugin_lgbm_regressor.py +++ /dev/null @@ -1,112 +0,0 @@ -# stdlib -from typing import Any, List - -# third party -import lightgbm as lgbm -import pandas as pd - -# hyperimpute absolute -import hyperimpute.plugins.core.params as params -import hyperimpute.plugins.prediction.regression.base as base - - -class LGBMRegressorPlugin(base.RegressionPlugin): - """Regression plugin based on LGBMRegressor. - - Method: - Gradient boosting is a machine learning technique for regression and classification problems, which produces a prediction model in the form of an ensemble of weak prediction models, typically decision trees. When a decision tree is the weak learner, the resulting algorithm is called gradient boosted trees, which usually outperforms random forest. - - Args: - n_estimators: int - The number of boosting stages to perform. Gradient boosting is fairly robust to over-fitting so a large number usually results in better performance. - learning_rate: float - Learning rate shrinks the contribution of each tree by learning_rate. There is a trade-off between learning_rate and n_estimators. - max_depth: int - The maximum depth of the individual regression estimators. - boosting_type: str - ‘gbdt’, traditional Gradient Boosting Decision Tree. ‘dart’, Dropouts meet Multiple Additive Regression Trees. ‘goss’, Gradient-based One-Side Sampling. ‘rf’, Random Forest. - objective:str - Specify the learning task and the corresponding learning objective or a custom objective function to be used. - reg_lambda:float - L2 regularization term on weights. - reg_alpha:float - L1 regularization term on weights. - colsample_bytree:float - Subsample ratio of columns when constructing each tree. - subsample:float - Subsample ratio of the training instance. - num_leaves:int - Maximum tree leaves for base learners. - min_child_samples:int - Minimum sum of instance weight (hessian) needed in a child (leaf). - - Example: - >>> from hyperimpute.plugins.prediction import Predictions - >>> plugin = Predictions(category="classifiers").get("lgbm") - >>> from sklearn.datasets import load_iris - >>> X, y = load_iris(return_X_y=True) - >>> plugin.fit_predict(X, y) # returns the probabilities for each class - """ - - def __init__( - self, - n_estimators: int = 100, - boosting_type: str = "gbdt", - learning_rate: float = 1e-2, - max_depth: int = 6, - reg_lambda: float = 1e-3, - reg_alpha: float = 1e-3, - colsample_bytree: float = 0.1, - subsample: float = 0.1, - num_leaves: int = 31, - min_child_samples: int = 1, - model: Any = None, - random_state: int = 0, - **kwargs: Any - ) -> None: - super().__init__(**kwargs) - if model is not None: - self.model = model - return - - self.model = lgbm.LGBMRegressor( - n_estimators=n_estimators, - boosting_type=boosting_type, - learning_rate=learning_rate, - max_depth=max_depth, - reg_lambda=reg_lambda, - reg_alpha=reg_alpha, - colsample_bytree=colsample_bytree, - subsample=subsample, - num_leaves=num_leaves, - min_child_samples=min_child_samples, - random_state=random_state, - ) - - @staticmethod - def name() -> str: - return "lgbm_regressor" - - @staticmethod - def hyperparameter_space(*args: Any, **kwargs: Any) -> List[params.Params]: - return [ - params.Integer("n_estimators", 5, 100), - params.Float("reg_lambda", 1e-3, 1e3), - params.Float("reg_alpha", 1e-3, 1e3), - params.Float("colsample_bytree", 0.1, 1.0), - params.Float("subsample", 0.1, 1.0), - params.Integer("num_leaves", 31, 256), - params.Integer("min_child_samples", 1, 500), - params.Categorical("learning_rate", [1e-4, 1e-3, 1e-2, 2e-4]), - params.Integer("max_depth", 1, 6), - ] - - def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "LGBMRegressorPlugin": - self.model.fit(X, *args, **kwargs) - return self - - def _predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame: - return self.model.predict(X, *args, **kwargs) - - -plugin = LGBMRegressorPlugin diff --git a/src/hyperimpute/plugins/prediction/regression/plugin_random_forest_regressor.py b/src/hyperimpute/plugins/prediction/regression/plugin_random_forest_regressor.py index dfb2d60..148fec6 100644 --- a/src/hyperimpute/plugins/prediction/regression/plugin_random_forest_regressor.py +++ b/src/hyperimpute/plugins/prediction/regression/plugin_random_forest_regressor.py @@ -41,7 +41,7 @@ class RandomForestRegressionPlugin(base.RegressionPlugin): """ criterions = ["squared_error", "absolute_error", "poisson"] - features = ["auto", "sqrt", "log2"] + features = ["sqrt", "log2", None] def __init__( self, @@ -103,6 +103,8 @@ def _fit( return self def _predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame: + X = np.asarray(X) + return self.model.predict(X, *args, **kwargs) diff --git a/src/hyperimpute/utils/serialization.py b/src/hyperimpute/utils/serialization.py index 5ab22a5..902c0d3 100644 --- a/src/hyperimpute/utils/serialization.py +++ b/src/hyperimpute/utils/serialization.py @@ -138,19 +138,42 @@ def version() -> str: return MAJOR_VERSION +def _add_version(obj: Any) -> Any: + obj._serde_version = MAJOR_VERSION + return obj + + +def _check_version(obj: Any) -> Any: + local_version = obj._serde_version + + if not hasattr(obj, "_serde_version"): + raise RuntimeError("Missing serialization version") + + if local_version != MAJOR_VERSION: + raise ValueError( + f"Serialized object mismatch. Current major version is {MAJOR_VERSION}, but the serialized object has version {local_version}." + ) + + def save(model: Any) -> bytes: + _add_version(model) return cloudpickle.dumps(model) def load(buff: bytes) -> Any: - return cloudpickle.loads(buff) + obj = cloudpickle.loads(buff) + _check_version(obj) + return obj def save_to_file(path: Union[str, Path], model: Any) -> Any: + _add_version(model) with open(path, "wb") as f: return cloudpickle.dump(model, f) def load_from_file(path: Union[str, Path]) -> Any: with open(path, "rb") as f: - return cloudpickle.load(f) + obj = cloudpickle.load(f) + _check_version(obj) + return obj diff --git a/src/hyperimpute/version.py b/src/hyperimpute/version.py index e156496..db9479b 100644 --- a/src/hyperimpute/version.py +++ b/src/hyperimpute/version.py @@ -1,4 +1,4 @@ -__version__ = "0.1.14" +__version__ = "0.1.15" MAJOR_VERSION = ".".join(__version__.split(".")[:-1]) MINOR_VERSION = __version__.split(".")[-1] diff --git a/tests/imputers/test_imputation_serde.py b/tests/imputers/test_imputation_serde.py index 3bbe523..4d52d23 100644 --- a/tests/imputers/test_imputation_serde.py +++ b/tests/imputers/test_imputation_serde.py @@ -2,6 +2,7 @@ from typing import Tuple # third party +import cloudpickle import numpy as np import pandas as pd import pytest @@ -29,16 +30,25 @@ def dataset(mechanism: str, p_miss: float) -> Tuple[np.ndarray, np.ndarray, np.n return pd.DataFrame(x), pd.DataFrame(x_miss) -@pytest.mark.slow @pytest.mark.parametrize("plugin", Imputers().list()) def test_pickle(plugin: str) -> None: x, x_miss = dataset("MAR", 0.3) estimator = Imputers().get(plugin) + buff = save(estimator) + estimator_new = load(buff) + estimator.fit_transform(x_miss) buff = save(estimator) estimator_new = load(buff) estimator_new.transform(x_miss) + + # load wrong version + estimator._serde_version = "sdfsfs" + buff = cloudpickle.dumps(estimator) + + with pytest.raises(ValueError): + load(buff) diff --git a/tests/prediction/classifiers/test_clf_serde.py b/tests/prediction/classifiers/test_clf_serde.py index 3278443..c026422 100644 --- a/tests/prediction/classifiers/test_clf_serde.py +++ b/tests/prediction/classifiers/test_clf_serde.py @@ -1,41 +1,24 @@ -# stdlib -from typing import Tuple - # third party -import numpy as np -import pandas as pd import pytest -from sklearn.model_selection import train_test_split +from sklearn.datasets import load_iris # hyperimpute absolute from hyperimpute.plugins import Predictions from hyperimpute.utils.serialization import load, save -def dataset() -> Tuple[pd.DataFrame, pd.Series]: - rng = np.random.RandomState(1) - - N = 1000 - X = rng.randint(N, size=(N, 3)) - y = rng.randint(2, size=(N)) - - return pd.DataFrame(X), pd.Series(y) - - -@pytest.mark.parametrize("plugin", ["xgboost", "catboost"]) +@pytest.mark.parametrize("plugin", Predictions(category="classifier").list()) def test_pickle(plugin: str) -> None: - X, y = dataset() + X, y = load_iris(return_X_y=True, as_frame=True) - X_train, X_test, y_train, y_test = train_test_split( - X, y, test_size=0.2, random_state=42 - ) + estimator = Predictions(category="classifier").get(plugin) - estimator = Predictions().get(plugin) + buff = save(estimator) + estimator_new = load(buff) - estimator.fit(X_train, y_train) - estimator.predict(X_test) + estimator.fit(X, y) buff = save(estimator) estimator_new = load(buff) - estimator_new.predict(X_test) + estimator_new.predict(X) diff --git a/tests/prediction/classifiers/test_lgbm.py b/tests/prediction/classifiers/test_lgbm.py deleted file mode 100644 index eb5ca9c..0000000 --- a/tests/prediction/classifiers/test_lgbm.py +++ /dev/null @@ -1,100 +0,0 @@ -# stdlib -import sys -from typing import Any - -# third party -import numpy as np -import optuna -import pytest -from sklearn.datasets import load_iris -from sklearn.model_selection import train_test_split - -# hyperimpute absolute -from hyperimpute.plugins.prediction import PredictionPlugin, Predictions -from hyperimpute.plugins.prediction.classifiers.plugin_lgbm import plugin -from hyperimpute.utils.serialization import load, save -from hyperimpute.utils.tester import evaluate_estimator - - -def from_api() -> PredictionPlugin: - return Predictions().get("lgbm", iterations=100) - - -def from_module() -> PredictionPlugin: - return plugin(iterations=100) - - -def from_pickle() -> PredictionPlugin: - buff = save(plugin()) - return load(buff) - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -def test_lgbm_plugin_sanity(test_plugin: PredictionPlugin) -> None: - assert test_plugin is not None - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -def test_lgbm_plugin_name(test_plugin: PredictionPlugin) -> None: - assert test_plugin.name() == "lgbm" - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -def test_lgbm_plugin_type(test_plugin: PredictionPlugin) -> None: - assert test_plugin.type() == "prediction" - assert test_plugin.subtype() == "classifier" - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -def test_lgbm_plugin_hyperparams(test_plugin: PredictionPlugin) -> None: - assert len(test_plugin.hyperparameter_space()) == 9 - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -@pytest.mark.skipif(sys.platform == "darwin", reason="LGBM crash on OSX") -def test_lgbm_plugin_fit_predict(test_plugin: PredictionPlugin) -> None: - X, y = load_iris(return_X_y=True, as_frame=True) - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) - - y_pred = test_plugin.fit(X_train, y_train).predict(X_test) - - assert np.abs((y_pred.values - y_test.values).mean()) < 1 - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -@pytest.mark.skipif(sys.platform == "darwin", reason="LGBM crash on OSX") -def test_lgbm_plugin_score(test_plugin: PredictionPlugin) -> None: - X, y = load_iris(return_X_y=True, as_frame=True) - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) - - test_plugin.fit(X_train, y_train) - - assert test_plugin.score(X_test, y_test) > 0.5 - - -@pytest.mark.skipif(sys.platform == "darwin", reason="LGBM crash on OSX") -def test_param_search() -> None: - if len(plugin.hyperparameter_space()) == 0: - return - - X, y = load_iris(return_X_y=True, as_frame=True) - - def evaluate_args(**kwargs: Any) -> float: - kwargs["iterations"] = 100 - model = plugin(**kwargs) - metrics = evaluate_estimator(model, X, y) - - return metrics["clf"]["aucroc"][0] - - def objective(trial: optuna.Trial) -> float: - args = plugin.sample_hyperparameters(trial) - return evaluate_args(**args) - - study = optuna.create_study( - load_if_exists=True, - directions=["maximize"], - study_name=f"test_param_search_{plugin.name()}", - ) - study.optimize(objective, n_trials=10, timeout=60) - - assert len(study.trials) == 10 diff --git a/tests/prediction/regression/test_lgbm_regressor.py b/tests/prediction/regression/test_lgbm_regressor.py deleted file mode 100644 index bffb492..0000000 --- a/tests/prediction/regression/test_lgbm_regressor.py +++ /dev/null @@ -1,87 +0,0 @@ -# stdlib -import sys -from typing import Any - -# third party -import optuna -import pytest -from sklearn.datasets import load_diabetes - -# hyperimpute absolute -from hyperimpute.plugins.prediction import PredictionPlugin, Predictions -from hyperimpute.plugins.prediction.regression.plugin_lgbm_regressor import plugin -from hyperimpute.utils.serialization import load, save -from hyperimpute.utils.tester import evaluate_regression - - -def from_api() -> PredictionPlugin: - return Predictions(category="regression").get("lgbm_regressor") - - -def from_module() -> PredictionPlugin: - return plugin() - - -def from_pickle() -> PredictionPlugin: - buff = save(plugin()) - return load(buff) - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -def test_lgbm_regressor_plugin_sanity(test_plugin: PredictionPlugin) -> None: - assert test_plugin is not None - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -def test_lgbm_regressor_plugin_name(test_plugin: PredictionPlugin) -> None: - assert test_plugin.name() == "lgbm_regressor" - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -def test_lgbm_regressor_plugin_type(test_plugin: PredictionPlugin) -> None: - assert test_plugin.type() == "prediction" - assert test_plugin.subtype() == "regression" - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -def test_lgbm_regressor_plugin_hyperparams(test_plugin: PredictionPlugin) -> None: - assert len(test_plugin.hyperparameter_space()) == 9 - - -@pytest.mark.parametrize("test_plugin", [from_api(), from_module(), from_pickle()]) -@pytest.mark.skipif(sys.platform == "darwin", reason="LGBM crash on OSX") -def test_lgbm_regressor_plugin_fit_predict(test_plugin: PredictionPlugin) -> None: - X, y = load_diabetes(return_X_y=True) - - score = evaluate_regression(test_plugin, X, y) - - assert score["clf"]["rmse"][0] < 5000 - - -@pytest.mark.skipif(sys.platform == "darwin", reason="LGBM crash on OSX") -def test_param_search() -> None: - if len(plugin.hyperparameter_space()) == 0: - return - - X, y = load_diabetes(return_X_y=True) - - def evaluate_args(**kwargs: Any) -> float: - kwargs["n_estimators"] = 10 - - model = plugin(**kwargs) - metrics = evaluate_regression(model, X, y) - - return metrics["clf"]["rmse"][0] - - def objective(trial: optuna.Trial) -> float: - args = plugin.sample_hyperparameters(trial) - return evaluate_args(**args) - - study = optuna.create_study( - load_if_exists=True, - directions=["maximize"], - study_name=f"test_param_search_{plugin.name()}", - ) - study.optimize(objective, n_trials=10, timeout=60) - - assert len(study.trials) == 10 diff --git a/tests/prediction/regression/test_reg_serde.py b/tests/prediction/regression/test_reg_serde.py new file mode 100644 index 0000000..b7c0940 --- /dev/null +++ b/tests/prediction/regression/test_reg_serde.py @@ -0,0 +1,24 @@ +# third party +import pytest +from sklearn.datasets import load_iris + +# hyperimpute absolute +from hyperimpute.plugins import Predictions +from hyperimpute.utils.serialization import load, save + + +@pytest.mark.parametrize("plugin", Predictions(category="regression").list()) +def test_pickle(plugin: str) -> None: + X, y = load_iris(return_X_y=True, as_frame=True) + + estimator = Predictions(category="regression").get(plugin) + + buff = save(estimator) + estimator_new = load(buff) + + estimator.fit(X, y) + + buff = save(estimator) + estimator_new = load(buff) + + estimator_new.predict(X)