Skip to content

Commit

Permalink
reduce diff using inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Sep 3, 2024
1 parent 4887823 commit b2284b8
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 25 deletions.
13 changes: 2 additions & 11 deletions skrub/_gap_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,12 @@

from . import _dataframe as sbd
from ._on_each_column import RejectColumn, SingleColumnTransformer
from ._repr import (
_HTMLDocumentationLinkMixin,
doc_link_module,
doc_link_template,
doc_link_url_param_generator,
)
from ._repr import _SkrubHTMLDocumentationLinkMixin
from ._utils import unique_strings


class GapEncoder(
_HTMLDocumentationLinkMixin, TransformerMixin, SingleColumnTransformer
_SkrubHTMLDocumentationLinkMixin, TransformerMixin, SingleColumnTransformer
):
"""Constructs latent topics with continuous encoding.
Expand Down Expand Up @@ -185,10 +180,6 @@ class GapEncoder(
The higher the value, the bigger the correspondence with the topic.
"""

_doc_link_module = doc_link_module
_doc_link_template = doc_link_template
_doc_link_url_param_generator = doc_link_url_param_generator

def __init__(
self,
n_components=10,
Expand Down
12 changes: 9 additions & 3 deletions skrub/_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

sklearn_version = parse_version(sklearn.__version__)

# TODO: remove when scikit-learn 1.6 is the minimum supported version
# TODO: subsequently, we should remove the inheritance from _HTMLDocumentationLinkMixin
# for each estimator then.
# TODO: remove when scikit-learn 1.6 is the minimum supported version and only import
# _HTMLDocumentationLinkMixin from scikit-learn.
# We have this fix due to the following bug:
# https://github.com/scikit-learn/scikit-learn/pull/29774
if sklearn_version > parse_version("1.6"):
from sklearn.utils._estimator_html_repr import _HTMLDocumentationLinkMixin
else:
Expand Down Expand Up @@ -101,3 +101,9 @@ def doc_link_url_param_generator(estimator):
"estimator_module": estimator_module,
"estimator_name": estimator_name,
}


class _SkrubHTMLDocumentationLinkMixin(_HTMLDocumentationLinkMixin):
    """Documentation-link mixin with the skrub settings already applied.

    Estimators inherit from this class instead of assigning the three
    ``_doc_link_*`` class attributes themselves.
    """

    # Module-level setting paired with the URL template below.
    _doc_link_module = doc_link_module
    # Callable returning the parameters used for the URL; it yields a mapping
    # with the ``estimator_module`` and ``estimator_name`` keys.
    _doc_link_url_param_generator = doc_link_url_param_generator
    # URL template for the estimator documentation link.
    _doc_link_template = doc_link_template
15 changes: 4 additions & 11 deletions skrub/_table_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
from ._datetime_encoder import DatetimeEncoder
from ._gap_encoder import GapEncoder
from ._on_each_column import SingleColumnTransformer
from ._repr import (
_HTMLDocumentationLinkMixin,
doc_link_module,
doc_link_template,
doc_link_url_param_generator,
)
from ._repr import _SkrubHTMLDocumentationLinkMixin
from ._select_cols import Drop
from ._to_datetime import ToDatetime
from ._to_float32 import ToFloat32
Expand Down Expand Up @@ -116,7 +111,9 @@ def _check_transformer(transformer):
return clone(transformer)


class TableVectorizer(_HTMLDocumentationLinkMixin, TransformerMixin, BaseEstimator):
class TableVectorizer(
_SkrubHTMLDocumentationLinkMixin, TransformerMixin, BaseEstimator
):
"""Transform a dataframe to a numerical (vectorized) representation.
Applies a different transformation to each of several kinds of columns:
Expand Down Expand Up @@ -411,10 +408,6 @@ class TableVectorizer(_HTMLDocumentationLinkMixin, TransformerMixin, BaseEstimat
ValueError: Column 'A' used twice in 'specific_transformers', at indices 0 and 1.
""" # noqa: E501

_doc_link_module = doc_link_module
_doc_link_template = doc_link_template
_doc_link_url_param_generator = doc_link_url_param_generator

def __init__(
self,
*,
Expand Down

0 comments on commit b2284b8

Please sign in to comment.