From e19794dab94aa051523243da2d873c25f8688bf9 Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Tue, 8 Oct 2024 15:07:30 +0200 Subject: [PATCH] More intelligent step wrapping --- julearn/pipeline/pipeline_creator.py | 58 ++++++++++++++----- .../pipeline/tests/test_pipeline_creator.py | 53 ++++++++++++++++- julearn/tests/test_api.py | 2 +- 3 files changed, 95 insertions(+), 18 deletions(-) diff --git a/julearn/pipeline/pipeline_creator.py b/julearn/pipeline/pipeline_creator.py index c2c572590..566b5c3ac 100644 --- a/julearn/pipeline/pipeline_creator.py +++ b/julearn/pipeline/pipeline_creator.py @@ -42,6 +42,37 @@ from .target_pipeline_creator import TargetPipelineCreator +def _should_wrap_this_step( + X_types: Dict[str, List[str]], # noqa: N803 + apply_to: ColumnTypesLike, +) -> bool: + """Check if we should wrap the step. + + Parameters + ---------- + X_types : Dict[str, List[str]] + The types of the columns in the data. + apply_to : ColumnTypesLike + The types to apply this step to. + + Returns + ------- + bool + Whether we should wrap the step. + + """ + + # If we have a wildcard, we will not wrap the step + if any(x in ["*", ".*"] for x in apply_to): + return False + + # If any of the X_types is not in the apply_to, we will wrap the step + if any(x not in apply_to for x in X_types.keys()): + return True + + return False + + def _params_to_pipeline( param: Any, X_types: Dict[str, List], # noqa: N803 @@ -511,18 +542,16 @@ def to_pipeline( logger.debug(f"\t Params to tune: {step_params_to_tune}") # Wrap in a JuTransformer if needed - if self.wrap: - if step_dict.apply_to not in [ - {"*"}, - {".*"}, - ] and not isinstance(estimator, JuTransformer): - estimator = self._wrap_step( - name, - estimator, - step_dict.apply_to, - row_select_col_type=step_dict.row_select_col_type, - row_select_vals=step_dict.row_select_vals, - ) + if _should_wrap_this_step( + X_types, step_dict.apply_to + ) and not isinstance(estimator, JuTransformer): + estimator = self._wrap_step( + name, + estimator, + step_dict.apply_to, + row_select_col_type=step_dict.row_select_col_type, + row_select_vals=step_dict.row_select_vals, + ) # Check if a step with the same name was already added pipeline_steps.append((name, estimator)) @@ -543,7 +572,9 @@ def to_pipeline( for k, v in model_params.items() } model_estimator.set_params(**model_params) - if self.wrap and not isinstance(model_estimator, JuModelLike): + if _should_wrap_this_step( + X_types, model_step.apply_to + ) and not isinstance(model_estimator, JuModelLike): logger.debug(f"Wrapping {model_name}") model_estimator = WrapModel(model_estimator, model_step.apply_to) @@ -793,7 +824,6 @@ def _check_X_types( "this type." ) - self.wrap = needed_types != {"continuous"} return X_types @staticmethod diff --git a/julearn/pipeline/tests/test_pipeline_creator.py b/julearn/pipeline/tests/test_pipeline_creator.py index b2b8e7253..99dab55f9 100644 --- a/julearn/pipeline/tests/test_pipeline_creator.py +++ b/julearn/pipeline/tests/test_pipeline_creator.py @@ -26,10 +26,10 @@ from sklearn.pipeline import Pipeline -def test_construction_working( +def test_construction_working_wrapping( model: str, preprocess: Union[str, List[str]], problem_type: str ) -> None: - """Test that the pipeline constructions works as expected. + """Test that the pipeline constructions works as expected (wrapping). Parameters ---------- @@ -46,7 +46,7 @@ def test_construction_working( for step in preprocess: creator.add(step, apply_to="categorical") creator.add(model) - X_types = {"categorical": ["A"]} + X_types = {"categorical": ["A"], "continuous": ["B"]} pipeline = creator.to_pipeline(X_types=X_types) # check preprocessing steps @@ -72,6 +72,53 @@ def test_construction_working( assert len(preprocess) + 2 == len(pipeline.steps) +def test_construction_working_nowrapping( + model: str, preprocess: Union[str, List[str]], problem_type: str +) -> None: + """Test that the pipeline constructions works as expected (no wrapping). + + Parameters + ---------- + model : str + The model to test. + preprocess : str or list of str + The preprocessing steps to test. + problem_type : str + The problem type to test. + + """ + creator = PipelineCreator(problem_type=problem_type) + preprocess = preprocess if isinstance(preprocess, list) else [preprocess] + for step in preprocess: + creator.add(step, apply_to="*") + creator.add(model, apply_to=["categorical", "continuous"]) + X_types = {"categorical": ["A"], "continuous": ["B"]} + pipeline = creator.to_pipeline(X_types=X_types) + + # check preprocessing steps + # ignoring first step for types and last for model + for element in zip(preprocess, pipeline.steps[1:-1]): + _preprocess, (name, transformer) = element + assert name.startswith(f"{_preprocess}") + assert not isinstance(transformer, JuColumnTransformer) + assert isinstance( + transformer, get_transformer(_preprocess).__class__ + ) + + # check model step + model_name, model = pipeline.steps[-1] + assert not isinstance(model, WrapModel) + assert isinstance( + model, + get_model( + model_name, + problem_type=problem_type, + ).__class__, + ) + assert len(preprocess) + 2 == len(pipeline.steps) + + + def test_fit_and_transform_no_error( X_iris: pd.DataFrame, # noqa: N803 y_iris: pd.Series, diff --git a/julearn/tests/test_api.py b/julearn/tests/test_api.py index 1719c6bbd..af7a3641a 100644 --- a/julearn/tests/test_api.py +++ b/julearn/tests/test_api.py @@ -1227,7 +1227,7 @@ def test_api_stacking_models() -> None: # The final model should be a stacking model im which the first estimator # is a grid search assert isinstance( - final.steps[1][1].model.estimators[0][1], # type: ignore + final.steps[1][1].estimators[0][1], # type: ignore GridSearchCV, )