Skip to content

Commit

Permalink
More intelligent step wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
fraimondo committed Oct 8, 2024
1 parent 72fc30f commit e19794d
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 18 deletions.
58 changes: 44 additions & 14 deletions julearn/pipeline/pipeline_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,37 @@
from .target_pipeline_creator import TargetPipelineCreator


def _should_wrap_this_step(
X_types: Dict[str, List[str]], # noqa: N803
apply_to: ColumnTypesLike,
) -> bool:
"""Check if we should wrap the step.
Parameters
----------
X_types : Dict[str, List[str]]
The types of the columns in the data.
apply_to : ColumnTypesLike
The types to apply this step to.
Returns
-------
bool
Whether we should wrap the step.
"""

# If we have a wildcard, we will not wrap the step
if any(x in ["*", ".*"] for x in apply_to):
return False

# If any of the X_types is not in the apply_to, we will wrap the step
if any(x not in apply_to for x in X_types.keys()):
return True

return False


def _params_to_pipeline(
param: Any,
X_types: Dict[str, List], # noqa: N803
Expand Down Expand Up @@ -511,18 +542,16 @@ def to_pipeline(
logger.debug(f"\t Params to tune: {step_params_to_tune}")

# Wrap in a JuTransformer if needed
if self.wrap:
if step_dict.apply_to not in [
{"*"},
{".*"},
] and not isinstance(estimator, JuTransformer):
estimator = self._wrap_step(
name,
estimator,
step_dict.apply_to,
row_select_col_type=step_dict.row_select_col_type,
row_select_vals=step_dict.row_select_vals,
)
if _should_wrap_this_step(
X_types, step_dict.apply_to
) and not isinstance(estimator, JuTransformer):
estimator = self._wrap_step(
name,
estimator,
step_dict.apply_to,
row_select_col_type=step_dict.row_select_col_type,
row_select_vals=step_dict.row_select_vals,
)

# Check if a step with the same name was already added
pipeline_steps.append((name, estimator))
Expand All @@ -543,7 +572,9 @@ def to_pipeline(
for k, v in model_params.items()
}
model_estimator.set_params(**model_params)
if self.wrap and not isinstance(model_estimator, JuModelLike):
if _should_wrap_this_step(
X_types, model_step.apply_to
) and not isinstance(model_estimator, JuModelLike):
logger.debug(f"Wrapping {model_name}")
model_estimator = WrapModel(model_estimator, model_step.apply_to)

Expand Down Expand Up @@ -793,7 +824,6 @@ def _check_X_types(
"this type."
)

self.wrap = needed_types != {"continuous"}
return X_types

@staticmethod
Expand Down
53 changes: 50 additions & 3 deletions julearn/pipeline/tests/test_pipeline_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from sklearn.pipeline import Pipeline


def test_construction_working(
def test_construction_working_wrapping(
model: str, preprocess: Union[str, List[str]], problem_type: str
) -> None:
"""Test that the pipeline constructions works as expected.
"""Test that the pipeline constructions works as expected (wrapping).
Parameters
----------
Expand All @@ -46,7 +46,7 @@ def test_construction_working(
for step in preprocess:
creator.add(step, apply_to="categorical")
creator.add(model)
X_types = {"categorical": ["A"]}
X_types = {"categorical": ["A"], "continuous": ["B"]}
pipeline = creator.to_pipeline(X_types=X_types)

# check preprocessing steps
Expand All @@ -72,6 +72,53 @@ def test_construction_working(
assert len(preprocess) + 2 == len(pipeline.steps)


def test_construction_working_nowrapping(
    model: str, preprocess: Union[str, List[str]], problem_type: str
) -> None:
    """Check that steps covering every column type are left unwrapped.

    Parameters
    ----------
    model : str
        The model to test.
    preprocess : str or list of str
        The preprocessing steps to test.
    problem_type : str
        The problem type to test.

    """
    # Normalise a single preprocessing step name to a list
    if not isinstance(preprocess, list):
        preprocess = [preprocess]

    creator = PipelineCreator(problem_type=problem_type)
    for step_name in preprocess:
        creator.add(step_name, apply_to="*")
    creator.add(model, apply_to=["categorical", "continuous"])

    pipeline = creator.to_pipeline(
        X_types={"categorical": ["A"], "continuous": ["B"]}
    )

    # Preprocessing steps sit between the first step (column types) and
    # the last one (the model)
    middle_steps = pipeline.steps[1:-1]
    for step_name, (name, transformer) in zip(preprocess, middle_steps):
        assert name.startswith(f"{step_name}")
        assert not isinstance(transformer, JuColumnTransformer)
        expected_transformer_cls = get_transformer(step_name).__class__
        assert isinstance(transformer, expected_transformer_cls)

    # The model is the last step and must be the bare estimator
    final_name, final_estimator = pipeline.steps[-1]
    assert not isinstance(final_estimator, WrapModel)
    expected_model_cls = get_model(
        final_name,
        problem_type=problem_type,
    ).__class__
    assert isinstance(final_estimator, expected_model_cls)

    # One extra step for the column types and one for the model
    assert len(preprocess) + 2 == len(pipeline.steps)



def test_fit_and_transform_no_error(
X_iris: pd.DataFrame, # noqa: N803
y_iris: pd.Series,
Expand Down
2 changes: 1 addition & 1 deletion julearn/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def test_api_stacking_models() -> None:
# The final model should be a stacking model in which the first estimator
# is a grid search
assert isinstance(
final.steps[1][1].model.estimators[0][1], # type: ignore
final.steps[1][1].estimators[0][1], # type: ignore
GridSearchCV,
)

Expand Down

0 comments on commit e19794d

Please sign in to comment.