diff --git a/docs/changes/newsfragments/270.enh b/docs/changes/newsfragments/270.enh new file mode 100644 index 000000000..63cd68890 --- /dev/null +++ b/docs/changes/newsfragments/270.enh @@ -0,0 +1 @@ +Remove final model fit requirement for inspector to be returned by :func:`.run_cross_validation` by `Fede Raimondo`_. \ No newline at end of file diff --git a/docs/getting_started.rst b/docs/getting_started.rst index aa7aa851c..47bd86601 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -96,5 +96,5 @@ The following optional dependencies are available: module is not compatible with newer Python versions and it is unmaintained. * ``skopt``: Using the ``"bayes"`` searcher (:class:`~skopt.BayesSearchCV`) requires the `scikit-optimize`_ package. -* ``optuna``: Using the ``"optuna"`` searcher (:class:`~optuna_integration.sklearn.OptunaSearchCV`) requires the `Optuna`_ and `optuna_integration`_ packages. +* ``optuna``: Using the ``"optuna"`` searcher (:class:`~optuna_integration.OptunaSearchCV`) requires the `Optuna`_ and `optuna_integration`_ packages. * ``all``: Install all optional functional dependencies (except ``deslib``). diff --git a/docs/whats_new.rst b/docs/whats_new.rst index d1137d51d..dd36ea0c1 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -56,7 +56,7 @@ Enhancements Features ^^^^^^^^ -- Add :class:`~optuna_integration.sklearn.OptunaSearchCV` to the list of +- Add :class:`~optuna_integration.OptunaSearchCV` to the list of available searchers as ``optuna`` by `Fede Raimondo`_ (:gh:`262`) diff --git a/examples/99_docs/run_hyperparameters_docs.py b/examples/99_docs/run_hyperparameters_docs.py index ce476262a..04bdf6a13 100644 --- a/examples/99_docs/run_hyperparameters_docs.py +++ b/examples/99_docs/run_hyperparameters_docs.py @@ -255,7 +255,7 @@ # Other searchers that ``julearn`` provides are the # :class:`~sklearn.model_selection.RandomizedSearchCV`, # :class:`~skopt.BayesSearchCV` and -# :class:`~optuna_integration.sklearn.OptunaSearchCV`. +# :class:`~optuna_integration.OptunaSearchCV`. # # The randomized searcher # (:class:`~sklearn.model_selection.RandomizedSearchCV`) is similar to the @@ -275,7 +275,7 @@ # :class:`~skopt.BayesSearchCV` documentation, including how to specify # the prior distributions of the hyperparameters. # -# The Optuna searcher (:class:`~optuna_integration.sklearn.OptunaSearchCV`) +# The Optuna searcher (:class:`~optuna_integration.OptunaSearchCV`) # uses the Optuna library to find the best hyperparameter set. Optuna is a # hyperparameter optimization framework that has several algorithms to find # the best hyperparameter set. For more information, see the diff --git a/julearn/api.py b/julearn/api.py index 130b57e0a..c86086d36 100644 --- a/julearn/api.py +++ b/julearn/api.py @@ -142,7 +142,7 @@ def run_cross_validation( # noqa: C901 :class:`~sklearn.model_selection.RandomizedSearchCV` * ``"bayes"`` : :class:`~skopt.BayesSearchCV` * ``"optuna"`` : - :class:`~optuna_integration.sklearn.OptunaSearchCV` + :class:`~optuna_integration.OptunaSearchCV` * user-registered searcher name : see :func:`~julearn.model_selection.register_searcher` * ``scikit-learn``-compatible searcher @@ -194,11 +194,11 @@ def run_cross_validation( # noqa: C901 ) if return_inspector: if return_estimator is None: - logger.info("Inspector requested: setting return_estimator='all'") return_estimator = "all" - if return_estimator != "all": + if return_estimator not in ["all", "cv"]: raise_error( - "return_inspector=True requires return_estimator to be `all`." + "return_inspector=True requires return_estimator to be `all` " + "or `cv`" ) X_types = {} if X_types is None else X_types @@ -441,6 +441,9 @@ def run_cross_validation( # noqa: C901 groups=df_groups, cv=cv_outer, ) - out = scores_df, pipeline, inspector + if isinstance(out, tuple): + out = (*out, inspector) + else: + out = out, inspector return out diff --git a/julearn/inspect/tests/test_inspector.py b/julearn/inspect/tests/test_inspector.py index 8643cee1d..6f069324b 100644 --- a/julearn/inspect/tests/test_inspector.py +++ b/julearn/inspect/tests/test_inspector.py @@ -54,7 +54,9 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None: """ X = list(df_iris.iloc[:, :-1].columns) - scores, pipe, inspect = run_cross_validation( + + # All estimators + out = run_cross_validation( X=X, y="species", data=df_iris, @@ -63,6 +65,7 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None: return_inspector=True, problem_type="classification", ) + scores, pipe, inspect = out assert pipe == inspect.model._model # type: ignore for (_, score), inspect_fold in zip( scores.iterrows(), # type: ignore @@ -70,6 +73,24 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None: ): assert score["estimator"] == inspect_fold.model._model + del pipe + # only CV estimators + out = run_cross_validation( + X=X, + y="species", + data=df_iris, + model="svm", + return_estimator="cv", + return_inspector=True, + problem_type="classification", + ) + scores, inspect = out + for (_, score), inspect_fold in zip( + scores.iterrows(), # type: ignore + inspect.folds, # type: ignore + ): + assert score["estimator"] == inspect_fold.model._model + def test_normal_usage_with_search(df_iris: "pd.DataFrame") -> None: """Test inspector with search. diff --git a/julearn/pipeline/pipeline_creator.py b/julearn/pipeline/pipeline_creator.py index 60d0be052..9cc9e2fd4 100644 --- a/julearn/pipeline/pipeline_creator.py +++ b/julearn/pipeline/pipeline_creator.py @@ -944,7 +944,7 @@ def _prepare_hyperparameter_tuning( :class:`~sklearn.model_selection.RandomizedSearchCV` * ``"bayes"`` : :class:`~skopt.BayesSearchCV` * ``"optuna"`` : - :class:`~optuna_integration.sklearn.OptunaSearchCV` + :class:`~optuna_integration.OptunaSearchCV` * user-registered searcher name : see :func:`~julearn.model_selection.register_searcher` * ``scikit-learn``-compatible searcher diff --git a/pyproject.toml b/pyproject.toml index 9615834bf..bcb3b73bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ docs = [ "towncrier<24", "scikit-optimize>=0.10.0,<0.11", "optuna>=3.6.0,<3.7", - "optuna_integration>=3.6.0,<3.7", + "optuna_integration>=3.6.0,<4.1", ] deslib = ["deslib>=0.3.5,<0.4"] viz = [ @@ -72,7 +72,7 @@ viz = [ skopt = ["scikit-optimize>=0.10.0,<0.11"] optuna = [ "optuna>=3.6.0,<3.7", - "optuna_integration>=3.6.0,<3.7", + "optuna_integration>=3.6.0,<4.1", ] # Add all optional functional dependencies (skip deslib until its fixed) # This does not include dev/docs building dependencies diff --git a/tox.ini b/tox.ini index 09bb9c0aa..89030b7b4 100644 --- a/tox.ini +++ b/tox.ini @@ -16,7 +16,7 @@ deps = seaborn scikit-optimize>=0.10.0,<0.11 optuna>=3.6.0,<3.7 - optuna_integration>=3.6.0,<3.7 + optuna_integration>=3.6.0,<4.1 commands = pytest {toxinidir}/julearn @@ -45,7 +45,7 @@ deps = param scikit-optimize>=0.10.0,<0.11 optuna>=3.6.0,<3.7 - optuna_integration>=3.6.0,<3.7 + optuna_integration>=3.6.0,<4.1 commands = pytest -vv {toxinidir}/julearn @@ -69,7 +69,7 @@ deps = param scikit-optimize>=0.10.0,<0.11 optuna>=3.6.0,<3.7 - optuna_integration>=3.6.0,<3.7 + optuna_integration>=3.6.0,<4.1 commands = pytest --cov={envsitepackagesdir}/julearn --cov=./julearn --cov-report=xml --cov-report=term -vv