Skip to content

Commit

Permalink
Merge pull request #270 from juaml/fix/inspector_requirements
Browse files Browse the repository at this point in the history
[ENH] Remove final model fit requirement for inspector
  • Loading branch information
fraimondo authored Sep 4, 2024
2 parents be6a63b + 5e1e407 commit b79dc25
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/270.enh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove final model fit requirement for inspector to be returned by :func:`.run_cross_validation` by `Fede Raimondo`_.
2 changes: 1 addition & 1 deletion docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,5 @@ The following optional dependencies are available:
module is not compatible with newer Python versions and it is unmaintained.
* ``skopt``: Using the ``"bayes"`` searcher (:class:`~skopt.BayesSearchCV`)
requires the `scikit-optimize`_ package.
* ``optuna``: Using the ``"optuna"`` searcher (:class:`~optuna_integration.sklearn.OptunaSearchCV`) requires the `Optuna`_ and `optuna_integration`_ packages.
* ``optuna``: Using the ``"optuna"`` searcher (:class:`~optuna_integration.OptunaSearchCV`) requires the `Optuna`_ and `optuna_integration`_ packages.
* ``all``: Install all optional functional dependencies (except ``deslib``).
2 changes: 1 addition & 1 deletion docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Enhancements
Features
^^^^^^^^

- Add :class:`~optuna_integration.sklearn.OptunaSearchCV` to the list of
- Add :class:`~optuna_integration.OptunaSearchCV` to the list of
available searchers as ``optuna`` by `Fede Raimondo`_ (:gh:`262`)


Expand Down
4 changes: 2 additions & 2 deletions examples/99_docs/run_hyperparameters_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@
# Other searchers that ``julearn`` provides are the
# :class:`~sklearn.model_selection.RandomizedSearchCV`,
# :class:`~skopt.BayesSearchCV` and
# :class:`~optuna_integration.sklearn.OptunaSearchCV`.
# :class:`~optuna_integration.OptunaSearchCV`.
#
# The randomized searcher
# (:class:`~sklearn.model_selection.RandomizedSearchCV`) is similar to the
Expand All @@ -275,7 +275,7 @@
# :class:`~skopt.BayesSearchCV` documentation, including how to specify
# the prior distributions of the hyperparameters.
#
# The Optuna searcher (:class:`~optuna_integration.sklearn.OptunaSearchCV`)
# The Optuna searcher (:class:`~optuna_integration.OptunaSearchCV`)
# uses the Optuna library to find the best hyperparameter set. Optuna is a
# hyperparameter optimization framework that has several algorithms to find
# the best hyperparameter set. For more information, see the
Expand Down
13 changes: 8 additions & 5 deletions julearn/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def run_cross_validation( # noqa: C901
:class:`~sklearn.model_selection.RandomizedSearchCV`
* ``"bayes"`` : :class:`~skopt.BayesSearchCV`
* ``"optuna"`` :
:class:`~optuna_integration.sklearn.OptunaSearchCV`
:class:`~optuna_integration.OptunaSearchCV`
* user-registered searcher name : see
:func:`~julearn.model_selection.register_searcher`
* ``scikit-learn``-compatible searcher
Expand Down Expand Up @@ -194,11 +194,11 @@ def run_cross_validation( # noqa: C901
)
if return_inspector:
if return_estimator is None:
logger.info("Inspector requested: setting return_estimator='all'")
return_estimator = "all"
if return_estimator != "all":
if return_estimator not in ["all", "cv"]:
raise_error(
"return_inspector=True requires return_estimator to be `all`."
"return_inspector=True requires return_estimator to be `all` "
"or `cv`"
)

X_types = {} if X_types is None else X_types
Expand Down Expand Up @@ -441,6 +441,9 @@ def run_cross_validation( # noqa: C901
groups=df_groups,
cv=cv_outer,
)
out = scores_df, pipeline, inspector
if isinstance(out, tuple):
out = (*out, inspector)
else:
out = out, inspector

return out
23 changes: 22 additions & 1 deletion julearn/inspect/tests/test_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None:
"""
X = list(df_iris.iloc[:, :-1].columns)
scores, pipe, inspect = run_cross_validation(

# All estimators
out = run_cross_validation(
X=X,
y="species",
data=df_iris,
Expand All @@ -63,13 +65,32 @@ def test_normal_usage(df_iris: "pd.DataFrame") -> None:
return_inspector=True,
problem_type="classification",
)
scores, pipe, inspect = out
assert pipe == inspect.model._model # type: ignore
for (_, score), inspect_fold in zip(
scores.iterrows(), # type: ignore
inspect.folds, # type: ignore
):
assert score["estimator"] == inspect_fold.model._model

del pipe
# only CV estimators
out = run_cross_validation(
X=X,
y="species",
data=df_iris,
model="svm",
return_estimator="cv",
return_inspector=True,
problem_type="classification",
)
scores, inspect = out
for (_, score), inspect_fold in zip(
scores.iterrows(), # type: ignore
inspect.folds, # type: ignore
):
assert score["estimator"] == inspect_fold.model._model


def test_normal_usage_with_search(df_iris: "pd.DataFrame") -> None:
"""Test inspector with search.
Expand Down
2 changes: 1 addition & 1 deletion julearn/pipeline/pipeline_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ def _prepare_hyperparameter_tuning(
:class:`~sklearn.model_selection.RandomizedSearchCV`
* ``"bayes"`` : :class:`~skopt.BayesSearchCV`
* ``"optuna"`` :
:class:`~optuna_integration.sklearn.OptunaSearchCV`
:class:`~optuna_integration.OptunaSearchCV`
* user-registered searcher name : see
:func:`~julearn.model_selection.register_searcher`
* ``scikit-learn``-compatible searcher
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ docs = [
"towncrier<24",
"scikit-optimize>=0.10.0,<0.11",
"optuna>=3.6.0,<3.7",
"optuna_integration>=3.6.0,<3.7",
"optuna_integration>=3.6.0,<4.1",
]
deslib = ["deslib>=0.3.5,<0.4"]
viz = [
Expand All @@ -72,7 +72,7 @@ viz = [
skopt = ["scikit-optimize>=0.10.0,<0.11"]
optuna = [
"optuna>=3.6.0,<3.7",
"optuna_integration>=3.6.0,<3.7",
"optuna_integration>=3.6.0,<4.1",
]
# Add all optional functional dependencies (skip deslib until its fixed)
# This does not include dev/docs building dependencies
Expand Down
6 changes: 3 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ deps =
seaborn
scikit-optimize>=0.10.0,<0.11
optuna>=3.6.0,<3.7
optuna_integration>=3.6.0,<3.7
optuna_integration>=3.6.0,<4.1
commands =
pytest {toxinidir}/julearn

Expand Down Expand Up @@ -45,7 +45,7 @@ deps =
param
scikit-optimize>=0.10.0,<0.11
optuna>=3.6.0,<3.7
optuna_integration>=3.6.0,<3.7
optuna_integration>=3.6.0,<4.1
commands =
pytest -vv {toxinidir}/julearn

Expand All @@ -69,7 +69,7 @@ deps =
param
scikit-optimize>=0.10.0,<0.11
optuna>=3.6.0,<3.7
optuna_integration>=3.6.0,<3.7
optuna_integration>=3.6.0,<4.1
commands =
pytest --cov={envsitepackagesdir}/julearn --cov=./julearn --cov-report=xml --cov-report=term -vv

Expand Down

0 comments on commit b79dc25

Please sign in to comment.