From 72fc30fed6b101d0bc2ed7930342a8e937654068 Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Sat, 28 Sep 2024 11:38:07 +0200 Subject: [PATCH] Wrap only if we do not apply to everything --- julearn/pipeline/pipeline_creator.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/julearn/pipeline/pipeline_creator.py b/julearn/pipeline/pipeline_creator.py index 9cc9e2fd4..c2c572590 100644 --- a/julearn/pipeline/pipeline_creator.py +++ b/julearn/pipeline/pipeline_creator.py @@ -511,14 +511,18 @@ def to_pipeline( logger.debug(f"\t Params to tune: {step_params_to_tune}") # Wrap in a JuTransformer if needed - if self.wrap and not isinstance(estimator, JuTransformer): - estimator = self._wrap_step( - name, - estimator, - step_dict.apply_to, - row_select_col_type=step_dict.row_select_col_type, - row_select_vals=step_dict.row_select_vals, - ) + if self.wrap: + if step_dict.apply_to not in [ + {"*"}, + {".*"}, + ] and not isinstance(estimator, JuTransformer): + estimator = self._wrap_step( + name, + estimator, + step_dict.apply_to, + row_select_col_type=step_dict.row_select_col_type, + row_select_vals=step_dict.row_select_vals, + ) # Check if a step with the same name was already added pipeline_steps.append((name, estimator)) @@ -794,7 +798,7 @@ def _check_X_types( @staticmethod def _is_transformer_step( - step: Union[str, EstimatorLike, TargetPipelineCreator] + step: Union[str, EstimatorLike, TargetPipelineCreator], ) -> bool: """Check if a step is a transformer.""" if step in list_transformers(): @@ -805,7 +809,7 @@ def _is_transformer_step( @staticmethod def _is_model_step( - step: Union[EstimatorLike, str, TargetPipelineCreator] + step: Union[EstimatorLike, str, TargetPipelineCreator], ) -> bool: """Check if a step is a model.""" if step in list_models():