Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(maidr.show): support py-shiny renderer #67

Merged
merged 7 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions example/py-shiny/example_pyshiny_reactive_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import matplotlib.pyplot as plt
import seaborn as sns
from shiny import App, ui

from maidr.widget.shiny import render_maidr

# Load the dataset
iris = sns.load_dataset("iris")

# Define the UI components for the Shiny application
app_ui = ui.page_fluid(
ui.row(
ui.column(
3,
ui.input_select(
"x_var",
"Select X variable:",
choices=iris.select_dtypes(include=["float64"]).columns.tolist(),
selected="sepal_length",
),
ui.input_select(
"y_var",
"Select Y variable:",
choices=iris.select_dtypes(include=["float64"]).columns.tolist(),
selected="sepal_width",
),
),
ui.column(9, ui.output_ui("create_reactivebarplot")),
)
)


# Define the server
def server(input, output, session):
@render_maidr
def create_reactivebarplot():
fig, ax = plt.subplots(figsize=(10, 6))
s_plot = sns.scatterplot(
data=iris, x=input.x_var(), y=input.y_var(), hue="species", ax=ax
)
ax.set_title(f"Iris {input.y_var()} vs {input.x_var()}")
ax.set_xlabel(input.x_var().replace("_", " ").title())
ax.set_ylabel(input.y_var().replace("_", " ").title())
return s_plot


# Create the app
app = App(app_ui, server)

# Run the app
if __name__ == "__main__":
app.run()
3 changes: 2 additions & 1 deletion maidr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
lineplot,
scatterplot,
)
from .api import close, save_html, show, stacked
from .api import close, render, save_html, show, stacked

__all__ = [
"close",
"render",
"save_html",
"show",
"stacked",
Expand Down
9 changes: 9 additions & 0 deletions maidr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Literal, Any

from htmltools import Tag
from matplotlib.axes import Axes
from matplotlib.container import BarContainer

Expand All @@ -10,6 +11,14 @@
from maidr.core.figure_manager import FigureManager


def render(
plot: Any, *, lib_prefix: str | None = "lib", include_version: bool = True
) -> Tag:
ax = FigureManager.get_axes(plot)
maidr = FigureManager.get_maidr(ax.get_figure())
return maidr.render()


def show(plot: Any, renderer: Literal["auto", "ipython", "browser"] = "auto") -> object:
ax = FigureManager.get_axes(plot)
maidr = FigureManager.get_maidr(ax.get_figure())
Expand Down
55 changes: 21 additions & 34 deletions maidr/core/maidr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
import json
import uuid

from htmltools import HTML, HTMLDocument, RenderedHTML, tags, Tag
from htmltools import HTML, HTMLDocument, Tag, tags
from lxml import etree

from matplotlib.figure import Figure

from maidr.core.context_manager import HighlightContextManager
from maidr.core.plot import MaidrPlot
from maidr.utils.environment import Environment


class Maidr:
Expand Down Expand Up @@ -52,21 +51,9 @@ def plots(self) -> list[MaidrPlot]:
"""Return the list of plots extracted from the ``fig``."""
return self._plots

def render(
self, *, lib_prefix: str | None = "lib", include_version: bool = True
) -> RenderedHTML:
"""
Render the document.

Parameters
----------
lib_prefix : str, default="lib"
A prefix to add to relative paths to dependency files.
include_version : bool, default=True
Whether to include the version number in the dependency's folder name.
"""
html = self._create_html_doc()
return html.render(lib_prefix=lib_prefix, include_version=include_version)
def render(self) -> Tag:
"""Return the maidr plot inside an iframe."""
return self._create_html_tag()

def save_html(
self, file: str, *, lib_dir: str | None = "lib", include_version: bool = True
Expand All @@ -79,7 +66,8 @@ def save_html(
file : str
The file to save to.
lib_dir : str, default="lib"
The directory to save the dependencies to (relative to the file's directory).
The directory to save the dependencies to
(relative to the file's directory).
include_version : bool, default=True
Whether to include the version number in the dependency folder name.
"""
Expand Down Expand Up @@ -163,7 +151,6 @@ def _unique_id() -> str:
@staticmethod
def _inject_plot(plot: HTML, maidr: str) -> Tag:
"""Embed the plot and associated MAIDR scripts into the HTML structure."""

base_html = tags.html(
tags.head(
tags.meta(charset="UTF-8"),
Expand All @@ -183,20 +170,20 @@ def _inject_plot(plot: HTML, maidr: str) -> Tag:
tags.script(maidr),
)

if Environment.is_interactive_shell():
# If running in an interactive environment (e.g., Jupyter Notebook),
# display the HTML content using an iframe to ensure proper rendering
# and interactivity. The iframe's height is dynamically adjusted
base_html = tags.iframe(
srcdoc=str(base_html.get_html_string()),
width="100%",
height="100%",
scrolling="auto",
style="background-color: #fff",
frameBorder=0,
onload="""
this.style.height = this.contentWindow.document.body.scrollHeight + 100 + 'px';
""",
)
# If running in an interactive environment (e.g., Jupyter Notebook),
# display the HTML content using an iframe to ensure proper rendering
# and interactivity. The iframe's height is dynamically adjusted
base_html = tags.iframe(
srcdoc=str(base_html.get_html_string()),
width="100%",
height="100%",
scrolling="auto",
style="background-color: #fff",
frameBorder=0,
onload="""
this.style.height = this.contentWindow.document.body.scrollHeight +
100 + 'px';
""",
)

return base_html
4 changes: 1 addition & 3 deletions maidr/patch/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@


@wrapt.patch_function_wrapper(Axes, "hist")
def mpl_hist(
wrapped, _, args, kwargs
) -> tuple[
def mpl_hist(wrapped, _, args, kwargs) -> tuple[
np.ndarray | list[np.ndarray],
np.ndarray,
BarContainer | Polygon | list[BarContainer | Polygon],
Expand Down
29 changes: 29 additions & 0 deletions maidr/widget/shiny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from shiny.render import ui
from shiny.types import Jsonifiable

import maidr


class render_maidr(ui):
Copy link
Member

@jooyoungseo jooyoungseo Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use PascalCase for class name in our package. Why do you use snake_case here? Is this the Shiny convention?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked Shiny and there are no such conventions. I have renamed the class name to adhere to PascalCase.

"""
A custom UI rendering class for Maidr objects in Shiny applications.

This class extends the Shiny UI rendering functionality to handle Maidr objects.

Methods
-------
render()
Asynchronously renders the Maidr object.
"""

async def render(self) -> Jsonifiable:
"""Return maidr rendered object for a given plot."""
initial_value = await self.fn()
if initial_value is None:
return None

maidr_rendered = maidr.render(initial_value)
transformed = await self.transform(maidr_rendered)
return transformed
Loading