Skip to content

Commit

Permalink
Wrap only if we do not apply to everything
Browse files Browse the repository at this point in the history
  • Loading branch information
fraimondo committed Sep 28, 2024
1 parent eac0462 commit 72fc30f
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions julearn/pipeline/pipeline_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,14 +511,18 @@ def to_pipeline(
logger.debug(f"\t Params to tune: {step_params_to_tune}")

# Wrap in a JuTransformer if needed
if self.wrap and not isinstance(estimator, JuTransformer):
estimator = self._wrap_step(
name,
estimator,
step_dict.apply_to,
row_select_col_type=step_dict.row_select_col_type,
row_select_vals=step_dict.row_select_vals,
)
if self.wrap:
if step_dict.apply_to not in [
{"*"},
{".*"},
] and not isinstance(estimator, JuTransformer):
estimator = self._wrap_step(
name,
estimator,
step_dict.apply_to,
row_select_col_type=step_dict.row_select_col_type,
row_select_vals=step_dict.row_select_vals,
)

# Check if a step with the same name was already added
pipeline_steps.append((name, estimator))
Expand Down Expand Up @@ -794,7 +798,7 @@ def _check_X_types(

@staticmethod
def _is_transformer_step(
step: Union[str, EstimatorLike, TargetPipelineCreator]
step: Union[str, EstimatorLike, TargetPipelineCreator],
) -> bool:
"""Check if a step is a transformer."""
if step in list_transformers():
Expand All @@ -805,7 +809,7 @@ def _is_transformer_step(

@staticmethod
def _is_model_step(
step: Union[EstimatorLike, str, TargetPipelineCreator]
step: Union[EstimatorLike, str, TargetPipelineCreator],
) -> bool:
"""Check if a step is a model."""
if step in list_models():
Expand Down

0 comments on commit 72fc30f

Please sign in to comment.