From b2284b848a82a459dd6da3599e30cd9ab625de04 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 3 Sep 2024 08:52:38 +0200 Subject: [PATCH] reduce diff using inheritance --- skrub/_gap_encoder.py | 13 ++----------- skrub/_repr.py | 12 +++++++++--- skrub/_table_vectorizer.py | 15 ++++----------- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/skrub/_gap_encoder.py b/skrub/_gap_encoder.py index d31db886b..a36cd3f43 100644 --- a/skrub/_gap_encoder.py +++ b/skrub/_gap_encoder.py @@ -19,17 +19,12 @@ from . import _dataframe as sbd from ._on_each_column import RejectColumn, SingleColumnTransformer -from ._repr import ( - _HTMLDocumentationLinkMixin, - doc_link_module, - doc_link_template, - doc_link_url_param_generator, -) +from ._repr import _SkrubHTMLDocumentationLinkMixin from ._utils import unique_strings class GapEncoder( - _HTMLDocumentationLinkMixin, TransformerMixin, SingleColumnTransformer + _SkrubHTMLDocumentationLinkMixin, TransformerMixin, SingleColumnTransformer ): """Constructs latent topics with continuous encoding. @@ -185,10 +180,6 @@ class GapEncoder( The higher the value, the bigger the correspondence with the topic. """ - _doc_link_module = doc_link_module - _doc_link_template = doc_link_template - _doc_link_url_param_generator = doc_link_url_param_generator - def __init__( self, n_components=10, diff --git a/skrub/_repr.py b/skrub/_repr.py index 02ac76ca0..e320ffa07 100644 --- a/skrub/_repr.py +++ b/skrub/_repr.py @@ -5,9 +5,9 @@ sklearn_version = parse_version(sklearn.__version__) -# TODO: remove when scikit-learn 1.6 is the minimum supported version -# TODO: subsequently, we should remove the inheritance from _HTMLDocumentationLinkMixin -# for each estimator then. +# TODO: remove when scikit-learn 1.6 is the minimum supported version and only import +# We have this fix due to the following bug: +# https://github.com/scikit-learn/scikit-learn/pull/29774 if sklearn_version > parse_version("1.6"): from sklearn.utils._estimator_html_repr import _HTMLDocumentationLinkMixin else: @@ -101,3 +101,9 @@ def doc_link_url_param_generator(estimator): "estimator_module": estimator_module, "estimator_name": estimator_name, } + + +class _SkrubHTMLDocumentationLinkMixin(_HTMLDocumentationLinkMixin): + _doc_link_template = doc_link_template + _doc_link_module = doc_link_module + _doc_link_url_param_generator = doc_link_url_param_generator diff --git a/skrub/_table_vectorizer.py b/skrub/_table_vectorizer.py index 98df3b414..47d2e8820 100644 --- a/skrub/_table_vectorizer.py +++ b/skrub/_table_vectorizer.py @@ -17,12 +17,7 @@ from ._datetime_encoder import DatetimeEncoder from ._gap_encoder import GapEncoder from ._on_each_column import SingleColumnTransformer -from ._repr import ( - _HTMLDocumentationLinkMixin, - doc_link_module, - doc_link_template, - doc_link_url_param_generator, -) +from ._repr import _SkrubHTMLDocumentationLinkMixin from ._select_cols import Drop from ._to_datetime import ToDatetime from ._to_float32 import ToFloat32 @@ -116,7 +111,9 @@ def _check_transformer(transformer): return clone(transformer) -class TableVectorizer(_HTMLDocumentationLinkMixin, TransformerMixin, BaseEstimator): +class TableVectorizer( + _SkrubHTMLDocumentationLinkMixin, TransformerMixin, BaseEstimator +): """Transform a dataframe to a numerical (vectorized) representation. Applies a different transformation to each of several kinds of columns: @@ -411,10 +408,6 @@ class TableVectorizer(_HTMLDocumentationLinkMixin, TransformerMixin, BaseEstimat ValueError: Column 'A' used twice in 'specific_transformers', at indices 0 and 1. """ # noqa: E501 - _doc_link_module = doc_link_module - _doc_link_template = doc_link_template - _doc_link_url_param_generator = doc_link_url_param_generator - def __init__( self, *,