Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move update of graph theme to Graph.__call__ #174

Merged
merged 16 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Added

- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
3 changes: 0 additions & 3 deletions vizro-core/src/vizro/actions/_actions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ def _get_modified_page_figures(
ctds_filter: List[CallbackTriggerDict],
ctds_filter_interaction: List[Dict[str, CallbackTriggerDict]],
ctds_parameters: List[CallbackTriggerDict],
ctd_theme: CallbackTriggerDict,
targets: Optional[List[ModelID]] = None,
) -> Dict[ModelID, Any]:
if not targets:
Expand All @@ -267,7 +266,5 @@ def _get_modified_page_figures(
outputs[target] = model_manager[target]( # type: ignore[operator]
data_frame=filtered_data[target], **parameterized_config[target]
)
if hasattr(outputs[target], "update_layout"):
outputs[target].update_layout(template="vizro_dark" if ctd_theme["value"] else "vizro_light")

return outputs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _get_action_callback_inputs(action_id: ModelID) -> Dict[str, Any]:
if "filter_interaction" in include_inputs
else []
),
"theme_selector": (State("theme_selector", "on") if "theme_selector" in include_inputs else []),
"theme_selector": State("theme_selector", "on") if "theme_selector" in include_inputs else [],
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
}
return action_input_mapping

Expand Down
1 change: 0 additions & 1 deletion vizro-core/src/vizro/actions/_filter_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,4 @@ def _filter(
ctds_filter=ctx.args_grouping["filters"],
ctds_filter_interaction=ctx.args_grouping["filter_interaction"],
ctds_parameters=ctx.args_grouping["parameters"],
ctd_theme=ctx.args_grouping["theme_selector"],
)
1 change: 0 additions & 1 deletion vizro-core/src/vizro/actions/_on_page_load_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,4 @@ def _on_page_load(page_id: ModelID, **inputs: Dict[str, Any]) -> Dict[ModelID, A
ctds_filter=ctx.args_grouping["filters"],
ctds_filter_interaction=ctx.args_grouping["filter_interaction"],
ctds_parameters=ctx.args_grouping["parameters"],
ctd_theme=ctx.args_grouping["theme_selector"],
)
1 change: 0 additions & 1 deletion vizro-core/src/vizro/actions/_parameter_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,4 @@ def _parameter(targets: List[str], **inputs: Dict[str, Any]) -> Dict[ModelID, An
ctds_filter=ctx.args_grouping["filters"],
ctds_filter_interaction=ctx.args_grouping["filter_interaction"],
ctds_parameters=ctx.args_grouping["parameters"],
ctd_theme=ctx.args_grouping["theme_selector"],
)
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,4 @@ def filter_interaction(
ctds_filter=ctx.args_grouping["filters"],
ctds_filter_interaction=ctx.args_grouping["filter_interaction"],
ctds_parameters=ctx.args_grouping["parameters"],
ctd_theme=ctx.args_grouping["theme_selector"],
)
21 changes: 20 additions & 1 deletion vizro-core/src/vizro/models/_components/graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
from typing import List, Literal

from dash import dcc
from dash import ctx, dcc
from dash.exceptions import MissingCallbackContextException
from plotly import graph_objects as go
from pydantic import Field, PrivateAttr, validator

import vizro.plotly.express as px
from vizro import _themes as themes
from vizro.managers import data_manager
from vizro.models import Action, VizroBaseModel
from vizro.models._action._actions_chain import _action_validator_factory
Expand Down Expand Up @@ -44,6 +46,16 @@ def __call__(self, **kwargs):
# Remove top margin if title is provided
if fig.layout.title.text is None:
fig.update_layout(margin_t=24)

# Possibly we should enforce that __call__ can only be used within the context of a callback, but it's easy
# to just swallow up the error here as it doesn't cause any problems.
try:
# At the moment theme_selector is always present so this if statement is redundant, but possibly in
# future we'll have callbacks that do Graph.__call__() without theme_selector set.
if "theme_selector" in ctx.args_grouping:
fig = self._update_theme(fig, ctx.args_grouping["theme_selector"]["value"])
except MissingCallbackContextException:
logger.info("fig.update_layout called outside of callback context.")
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
return fig

# Convenience wrapper/syntactic sugar.
Expand Down Expand Up @@ -81,3 +93,10 @@ def build(self):
color="grey",
parent_className="loading-container",
)

@staticmethod
def _update_theme(fig: go.Figure, theme_selector: bool):
# Basically the same as doing fig.update_layout(template="vizro_light/dark") but works for both the call in
# self.__call__ and in the update_graph_theme callback.
fig["layout"]["template"] = themes.dark if theme_selector else themes.light
return fig
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
35 changes: 18 additions & 17 deletions vizro-core/src/vizro/models/_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from dash import Input, Output, Patch, callback, dcc, html
from pydantic import Field, root_validator, validator

import vizro._themes as themes
from vizro._constants import ON_PAGE_LOAD_ACTION_PREFIX
from vizro.actions import _on_page_load
from vizro.managers._model_manager import DuplicateIDError
from vizro.models import Action, Graph, Layout, VizroBaseModel
from vizro.models import Action, Layout, VizroBaseModel
from vizro.models._action._actions_chain import ActionsChain, Trigger
from vizro.models._models_utils import _log_call, get_unique_grid_component_ids

Expand Down Expand Up @@ -139,24 +138,26 @@ def _update_graph_theme(self):
# The obvious way to do this would be to alter pio.templates.default, but this changes global state and so is
# not good.
# Putting graphs as inputs here would be a nice way to trigger the theme change automatically so that we don't
# need the update_layout call in Graph.__call__, but this results in an extra callback and the graph

# need the call to _update_theme inside Graph.__call__ also, but this results in an extra callback and the graph
# flickering.
# TODO: consider making this clientside callback and then possibly we can remove the update_layout in
# The code is written to be generic and extensible so that it runs _update_theme on any component with such a
# method defined. But at the moment this just means Graphs.
# TODO: consider making this clientside callback and then possibly we can remove the call to _update_theme in
# Graph.__call__ without any flickering.
# TODO: consider putting the Graph-specific logic here in the Graph model itself (whether clientside or
# serverside) to keep the code here abstract.
outputs = [
Output(component.id, "figure", allow_duplicate=True)
for component in self.components
if isinstance(component, Graph)
]
if outputs:
# TODO: if we do this then we should *consider* defining the callback in Graph itself rather than at Page
# level. This would mean multiple callbacks on one page but if it's clientside that probably doesn't matter.

@callback(outputs, Input("theme_selector", "on"), prevent_initial_call="initial_duplicate")
def update_graph_theme(theme_selector_on: bool):
patched_figure = Patch()
patched_figure["layout"]["template"] = themes.dark if theme_selector_on else themes.light
return [patched_figure] * len(outputs)
themed_components = [component for component in self.components if hasattr(component, "_update_theme")]
if themed_components:

@callback(
[Output(component.id, "figure", allow_duplicate=True) for component in themed_components],
Input("theme_selector", "on"),
prevent_initial_call="initial_duplicate",
)
def update_graph_theme(theme_selector: bool):
return [component._update_theme(Patch(), theme_selector) for component in themed_components]

def _create_component_container(self, components_content):
component_container = html.Div(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def callback_context_on_page_load(request):
"theme_selector": CallbackTriggerDict(
id="theme_selector",
property="on",
value=True if template == "vizro_dark" else False,
value=template == "vizro_dark",
str_id="theme_selector",
triggered=False,
),
Expand Down
70 changes: 35 additions & 35 deletions vizro-core/tests/unit/vizro/models/_components/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import plotly.graph_objects as go
import pytest
from dash import dcc
from dash._callback_context import context_value
from dash._utils import AttributeDict
from pydantic import ValidationError

import vizro.models as vm
import vizro.plotly.express as px
from vizro.actions._actions_utils import CallbackTriggerDict
from vizro.managers import data_manager
from vizro.models._action._action import Action

Expand All @@ -26,18 +29,6 @@ def standard_px_chart_with_str_dataframe():
)


@pytest.fixture
def expected_empty_chart():
figure = go.Figure()
figure.add_trace(go.Scatter(x=[None], y=[None], showlegend=False, hoverinfo="none"))
figure.update_layout(
xaxis={"visible": False},
yaxis={"visible": False},
annotations=[{"text": "NO DATA", "showarrow": False, "font": {"size": 16}}],
)
return figure


@pytest.fixture
def expected_graph():
return dcc.Loading(
Expand Down Expand Up @@ -74,11 +65,7 @@ def test_create_graph_mandatory_only(self, standard_px_chart):

@pytest.mark.parametrize("id", ["id_1", "id_2"])
def test_create_graph_mandatory_and_optional(self, standard_px_chart, id):
graph = vm.Graph(
figure=standard_px_chart,
id=id,
actions=[],
)
graph = vm.Graph(figure=standard_px_chart, id=id, actions=[])

assert graph.id == id
assert graph.type == "graph"
Expand All @@ -90,9 +77,7 @@ def test_mandatory_figure_missing(self):

def test_failed_graph_with_wrong_figure(self, standard_go_chart):
with pytest.raises(ValidationError, match="must provide a valid CapturedCallable object"):
vm.Graph(
figure=standard_go_chart,
)
vm.Graph(figure=standard_go_chart)

def test_getitem_known_args(self, standard_px_chart):
graph = vm.Graph(figure=standard_px_chart)
Expand All @@ -107,13 +92,34 @@ def test_getitem_unknown_args(self, standard_px_chart):

@pytest.mark.parametrize("title, expected", [(None, 24), ("Test", None)])
def test_title_margin_adjustment(self, gapminder, title, expected):
figure = vm.Graph(figure=px.bar(data_frame=gapminder, x="year", y="pop", title=title)).__call__()

assert figure.layout.margin.t == expected
assert figure.layout.template.layout.margin.t == 64
assert figure.layout.template.layout.margin.l == 80
assert figure.layout.template.layout.margin.b == 64
assert figure.layout.template.layout.margin.r == 12
graph = vm.Graph(figure=px.bar(data_frame=gapminder, x="year", y="pop", title=title)).__call__()

assert graph.layout.margin.t == expected
assert graph.layout.template.layout.margin.t == 64
assert graph.layout.template.layout.margin.l == 80
assert graph.layout.template.layout.margin.b == 64
assert graph.layout.template.layout.margin.r == 12

def test_update_theme_outside_callback(self, standard_px_chart):
graph = vm.Graph(figure=standard_px_chart).__call__()
assert graph == standard_px_chart.update_layout(margin_t=24, template="vizro_dark")

@pytest.mark.parametrize("template", ["vizro_dark", "vizro_light"])
def test_update_theme_inside_callback(self, standard_px_chart, template):
mock_callback_context = {
"args_grouping": {
"theme_selector": CallbackTriggerDict(
id="theme_selector",
property="on",
value=template == "vizro_dark",
str_id="theme_selector",
triggered=False,
)
}
}
context_value.set(AttributeDict(**mock_callback_context))
graph = vm.Graph(figure=standard_px_chart).__call__()
assert graph == standard_px_chart.update_layout(margin_t=24, template=template)

def test_set_action_via_validator(self, standard_px_chart, test_action_function):
graph = vm.Graph(figure=standard_px_chart, actions=[Action(function=test_action_function)])
Expand All @@ -124,18 +130,12 @@ def test_set_action_via_validator(self, standard_px_chart, test_action_function)
class TestProcessFigureDataFrame:
def test_process_figure_data_frame_str_df(self, standard_px_chart_with_str_dataframe, gapminder):
data_manager["gapminder"] = gapminder
graph_with_str_df = vm.Graph(
id="text_graph",
figure=standard_px_chart_with_str_dataframe,
)
graph_with_str_df = vm.Graph(id="text_graph", figure=standard_px_chart_with_str_dataframe)
assert data_manager._get_component_data("text_graph").equals(gapminder)
assert graph_with_str_df["data_frame"] == "gapminder"

def test_process_figure_data_frame_df(self, standard_px_chart, gapminder):
graph_with_df = vm.Graph(
id="text_graph",
figure=standard_px_chart,
)
graph_with_df = vm.Graph(id="text_graph", figure=standard_px_chart)
assert data_manager._get_component_data("text_graph").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
graph_with_df.figure["data_frame"]
Expand Down
Loading