Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA add link for HTML representation #1051

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions skrub/_gap_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,12 @@

from . import _dataframe as sbd
from ._on_each_column import RejectColumn, SingleColumnTransformer
from ._repr import (
_HTMLDocumentationLinkMixin,
doc_link_module,
doc_link_template,
doc_link_url_param_generator,
)
from ._repr import _SkrubHTMLDocumentationLinkMixin
from ._utils import unique_strings


class GapEncoder(
_HTMLDocumentationLinkMixin, TransformerMixin, SingleColumnTransformer
_SkrubHTMLDocumentationLinkMixin, TransformerMixin, SingleColumnTransformer
):
"""Constructs latent topics with continuous encoding.

Expand Down Expand Up @@ -185,10 +180,6 @@ class GapEncoder(
The higher the value, the bigger the correspondence with the topic.
"""

_doc_link_module = doc_link_module
_doc_link_template = doc_link_template
_doc_link_url_param_generator = doc_link_url_param_generator

def __init__(
self,
n_components=10,
Expand Down
12 changes: 9 additions & 3 deletions skrub/_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

sklearn_version = parse_version(sklearn.__version__)

# TODO: remove when scikit-learn 1.6 is the minimum supported version
# TODO: subsequently, we should remove the inheritance from _HTMLDocumentationLinkMixin
# for each estimator then.
# TODO: remove when scikit-learn 1.6 is the minimum supported version and only import
# We have this fix due to the following bug:
# https://github.com/scikit-learn/scikit-learn/pull/29774
if sklearn_version > parse_version("1.6"):
from sklearn.utils._estimator_html_repr import _HTMLDocumentationLinkMixin

Check warning on line 12 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L12

Added line #L12 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC we will now have 2 of those in every skrub estimator's parents, one directly and one through the BaseEstimator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. The one on the left in the inheritance will be the one used.

else:

class _HTMLDocumentationLinkMixin:
Expand All @@ -20,12 +20,12 @@

@property
def _doc_link_template(self):
sklearn_version = parse_version(sklearn.__version__)

Check warning on line 23 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L23

Added line #L23 was not covered by tests
if sklearn_version.dev is None:
version_url = f"{sklearn_version.major}.{sklearn_version.minor}"

Check warning on line 25 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L25

Added line #L25 was not covered by tests
else:
version_url = "dev"
return getattr(

Check warning on line 28 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L27-L28

Added lines #L27 - L28 were not covered by tests
self,
"__doc_link_template",
(
Expand All @@ -36,7 +36,7 @@

@_doc_link_template.setter
def _doc_link_template(self, value):
setattr(self, "__doc_link_template", value)

Check warning on line 39 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L39

Added line #L39 was not covered by tests

def _get_doc_link(self):
"""Generates a link to the API documentation for a given estimator.
Expand All @@ -52,10 +52,10 @@
`""`) is returned.
"""
if self.__class__.__module__.split(".")[0] != self._doc_link_module:
return ""

Check warning on line 55 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L55

Added line #L55 was not covered by tests

if self._doc_link_url_param_generator is None:
estimator_name = self.__class__.__name__

Check warning on line 58 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L58

Added line #L58 was not covered by tests
# Construct the estimator's module name, up to the first private
# submodule. This works because in scikit-learn all public estimators
# are exposed at that level, even if they actually live in a private
Expand All @@ -66,10 +66,10 @@
self.__class__.__module__.split("."),
)
)
return self._doc_link_template.format(

Check warning on line 69 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L69

Added line #L69 was not covered by tests
estimator_module=estimator_module, estimator_name=estimator_name
)
return self._doc_link_template.format(

Check warning on line 72 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L72

Added line #L72 was not covered by tests
**self._doc_link_url_param_generator()
)

Expand All @@ -82,22 +82,28 @@


def doc_link_url_param_generator(estimator):
from skrub import __version__

Check warning on line 85 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L85

Added line #L85 was not covered by tests

skrub_version = parse_version(__version__)

Check warning on line 87 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L87

Added line #L87 was not covered by tests
if skrub_version.dev is None:
version_url = f"{skrub_version.major}.{skrub_version.minor}"

Check warning on line 89 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L89

Added line #L89 was not covered by tests
else:
version_url = "dev"
estimator_name = estimator.__class__.__name__

Check warning on line 92 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L91-L92

Added lines #L91 - L92 were not covered by tests
estimator_module = ".".join(
itertools.takewhile(
lambda part: not part.startswith("_"),
estimator.__class__.__module__.split("."),
)
)
return {

Check warning on line 99 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L99

Added line #L99 was not covered by tests
"version": version_url,
"estimator_module": estimator_module,
"estimator_name": estimator_name,
}


class _SkrubHTMLDocumentationLinkMixin(_HTMLDocumentationLinkMixin):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even when the scikit-learn fix is released, we will most likely want this mixin to avoid repeating these 3 attributes everywhere. so in the end all we will have to remove is the HTMLDocumentationLinkMixin which is confined to this module.

so after all I am +1 for this PR

_doc_link_template = doc_link_template
_doc_link_module = doc_link_module
_doc_link_url_param_generator = doc_link_url_param_generator
15 changes: 4 additions & 11 deletions skrub/_table_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
from ._datetime_encoder import DatetimeEncoder
from ._gap_encoder import GapEncoder
from ._on_each_column import SingleColumnTransformer
from ._repr import (
_HTMLDocumentationLinkMixin,
doc_link_module,
doc_link_template,
doc_link_url_param_generator,
)
from ._repr import _SkrubHTMLDocumentationLinkMixin
from ._select_cols import Drop
from ._to_datetime import ToDatetime
from ._to_float32 import ToFloat32
Expand Down Expand Up @@ -116,7 +111,9 @@ def _check_transformer(transformer):
return clone(transformer)


class TableVectorizer(_HTMLDocumentationLinkMixin, TransformerMixin, BaseEstimator):
class TableVectorizer(
_SkrubHTMLDocumentationLinkMixin, TransformerMixin, BaseEstimator
):
"""Transform a dataframe to a numerical (vectorized) representation.

Applies a different transformation to each of several kinds of columns:
Expand Down Expand Up @@ -411,10 +408,6 @@ class TableVectorizer(_HTMLDocumentationLinkMixin, TransformerMixin, BaseEstimat
ValueError: Column 'A' used twice in 'specific_transformers', at indices 0 and 1.
""" # noqa: E501

_doc_link_module = doc_link_module
_doc_link_template = doc_link_template
_doc_link_url_param_generator = doc_link_url_param_generator

def __init__(
self,
*,
Expand Down