From 47f649e0d5493a28d2fce5da6b3978a86594f2ba Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Thu, 19 Jan 2023 10:15:00 +0200 Subject: [PATCH] Optuna 3.1 support (#33) * optuna 3.1 debug * cleanup * cleanup * cleanup --- setup.cfg | 2 +- src/hyperimpute/utils/optimizer.py | 12 ++++++++---- src/hyperimpute/version.py | 6 ++++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/setup.cfg b/setup.cfg index e73657a..99a4180 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,7 +46,7 @@ install_requires = torch>=1.10.0 numpy>=1.20 catboost>=1.0.5 - optuna>=2.10 + optuna>=3.1 loguru==.0.6.0 xgboost>=1.6.1 miracle-imputation>=0.1.3 diff --git a/src/hyperimpute/utils/optimizer.py b/src/hyperimpute/utils/optimizer.py index a5b6d78..72f3603 100644 --- a/src/hyperimpute/utils/optimizer.py +++ b/src/hyperimpute/utils/optimizer.py @@ -4,6 +4,7 @@ # third party import optuna +from optuna.storages import JournalRedisStorage, JournalStorage import redis # hyperimpute absolute @@ -21,17 +22,16 @@ def __init__( ): self.url = f"redis://{host}:{port}/" - self._optuna_storage = optuna.storages.RedisStorage(url=self.url) + self._optuna_storage = JournalStorage(JournalRedisStorage(url=self.url)) self._client = redis.Redis.from_url(self.url) - def optuna(self) -> optuna.storages.RedisStorage: + def optuna(self) -> JournalStorage: return self._optuna_storage def client(self) -> redis.Redis: return self._client -backend = RedisBackend() threshold = 40 @@ -104,7 +104,11 @@ def create_study( patience: int = threshold, ) -> Tuple[optuna.Study, ParamRepeatPruner]: - storage_obj = backend.optuna() + try: + backend = RedisBackend() + storage_obj = backend.optuna() + except BaseException: + storage_obj = None try: study = optuna.create_study( diff --git a/src/hyperimpute/version.py b/src/hyperimpute/version.py index 6b07f98..3f40b97 100644 --- a/src/hyperimpute/version.py +++ b/src/hyperimpute/version.py @@ -1,2 +1,4 @@ -__version__ = "0.1.12" -MAJOR_VERSION = "0.1" +__version__ = "0.1.13" + +MAJOR_VERSION = ".".join(__version__.split(".")[:-1]) +MINOR_VERSION = __version__.split(".")[-1]