Skip to content

Commit

Permalink
[ENH] pipeline (#51)
Browse files Browse the repository at this point in the history
Implements #29 - a simple pipeline compatible with `sklearn` transformers.
  • Loading branch information
fkiraly authored Aug 27, 2023
1 parent 3947650 commit d0abb44
Show file tree
Hide file tree
Showing 4 changed files with 585 additions and 0 deletions.
6 changes: 6 additions & 0 deletions skpro/registry/_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@
"str",
"name of component list attribute for meta-objects",
),
(
"fitted_named_object_parameters",
"estimator",
"str",
"name of fitted component list attribute for meta-objects",
),
]

OBJECT_TAG_TABLE = pd.DataFrame(OBJECT_TAG_REGISTER)
Expand Down
25 changes: 25 additions & 0 deletions skpro/regression/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,31 @@ def __init__(self, index=None, columns=None):
super(BaseProbaRegressor, self).__init__()
_check_estimator_deps(self)

def __rmul__(self, other):
"""Magic * method, return (left) concatenated Pipeline.
Implemented for `other` being a transformer, otherwise returns `NotImplemented`.
Parameters
----------
other: `sklearn` transformer, must follow `sklearn` API
otherwise, `NotImplemented` is returned
Returns
-------
Pipeline object,
concatenation of `other` (first) with `self` (last).
not nested, contains only non-Pipeline `skpro` steps
"""
from skpro.regression.compose._pipeline import Pipeline

# we wrap self in a pipeline, and concatenate with the other
# the TransformedTargetForecaster does the rest, e.g., dispatch on other
if hasattr(other, "transform"):
return other * Pipeline([self])
else:
return NotImplemented

def fit(self, X, y):
"""Fit regressor to training data.
Expand Down
6 changes: 6 additions & 0 deletions skpro/regression/compose/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
"""Composition and pipelines for probabilistic supervised regression."""

from skpro.regression.compose._pipeline import Pipeline

__all__ = ["Pipeline"]
Loading

0 comments on commit d0abb44

Please sign in to comment.