Skip to content

Commit

Permalink
Serialization bugfixing (#35)
Browse files Browse the repository at this point in the history
* serde error

* more tests

* improvements

* drop lgbm

* cleanup
  • Loading branch information
bcebere authored Jan 31, 2023
1 parent ead5c4b commit 825635d
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 466 deletions.
9 changes: 0 additions & 9 deletions src/hyperimpute/plugins/core/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ class Plugin(Serializable, metaclass=ABCMeta):
def __init__(self) -> None:
super().__init__()

self.drop_consts = []

@staticmethod
@abstractmethod
def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]:
Expand Down Expand Up @@ -122,11 +120,6 @@ def fit_predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFram
def fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> Any:
X = cast.to_dataframe(X)

for col in X.columns:
if len(X.loc[X[col].notna(), col].unique()) <= 1:
self.drop_consts.append(col)

X = X.drop(columns=self.drop_consts)
self.columns = X.columns
return self._fit(X, *args, **kwargs)

Expand All @@ -136,7 +129,6 @@ def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "Plugin":

def transform(self, X: pd.DataFrame) -> pd.DataFrame:
X = cast.to_dataframe(X)
X = X.drop(columns=self.drop_consts)
return pd.DataFrame(self._transform(X))

@abstractmethod
Expand All @@ -145,7 +137,6 @@ def _transform(self, X: pd.DataFrame) -> pd.DataFrame:

def predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame:
X = cast.to_dataframe(X)
X = X.drop(columns=self.drop_consts)
return pd.DataFrame(self._predict(X, *args, *kwargs))

@abstractmethod
Expand Down
127 changes: 0 additions & 127 deletions src/hyperimpute/plugins/prediction/classifiers/plugin_lgbm.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class RandomForestPlugin(base.ClassifierPlugin):
"""

criterions = ["gini", "entropy"]
features = ["auto", "sqrt", "log2"]
features = ["sqrt", "log2", None]

def __init__(
self,
Expand Down Expand Up @@ -97,11 +97,13 @@ def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "RandomForestPlugi
return self

def _predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame:
X = np.asarray(X)
return self.model.predict(X, *args, **kwargs)

def _predict_proba(
self, X: pd.DataFrame, *args: Any, **kwargs: Any
) -> pd.DataFrame:
X = np.asarray(X)
return self.model.predict_proba(X, *args, **kwargs)


Expand Down
112 changes: 0 additions & 112 deletions src/hyperimpute/plugins/prediction/regression/plugin_lgbm_regressor.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class RandomForestRegressionPlugin(base.RegressionPlugin):
"""

criterions = ["squared_error", "absolute_error", "poisson"]
features = ["auto", "sqrt", "log2"]
features = ["sqrt", "log2", None]

def __init__(
self,
Expand Down Expand Up @@ -103,6 +103,8 @@ def _fit(
return self

def _predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame:
X = np.asarray(X)

return self.model.predict(X, *args, **kwargs)


Expand Down
27 changes: 25 additions & 2 deletions src/hyperimpute/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,42 @@ def version() -> str:
return MAJOR_VERSION


def _add_version(obj: Any) -> Any:
obj._serde_version = MAJOR_VERSION
return obj


def _check_version(obj: Any) -> Any:
local_version = obj._serde_version

if not hasattr(obj, "_serde_version"):
raise RuntimeError("Missing serialization version")

if local_version != MAJOR_VERSION:
raise ValueError(
f"Serialized object mismatch. Current major version is {MAJOR_VERSION}, but the serialized object has version {local_version}."
)


def save(model: Any) -> bytes:
_add_version(model)
return cloudpickle.dumps(model)


def load(buff: bytes) -> Any:
return cloudpickle.loads(buff)
obj = cloudpickle.loads(buff)
_check_version(obj)
return obj


def save_to_file(path: Union[str, Path], model: Any) -> Any:
_add_version(model)
with open(path, "wb") as f:
return cloudpickle.dump(model, f)


def load_from_file(path: Union[str, Path]) -> Any:
with open(path, "rb") as f:
return cloudpickle.load(f)
obj = cloudpickle.load(f)
_check_version(obj)
return obj
2 changes: 1 addition & 1 deletion src/hyperimpute/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.14"
__version__ = "0.1.15"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
MINOR_VERSION = __version__.split(".")[-1]
Loading

0 comments on commit 825635d

Please sign in to comment.