Skip to content

Commit

Permalink
Surprisingly big refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
antonymilne committed Nov 27, 2024
1 parent 2b90dcf commit 3b77712
Show file tree
Hide file tree
Showing 16 changed files with 188 additions and 209 deletions.
100 changes: 54 additions & 46 deletions vizro-core/examples/scratch_dev/app.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,68 @@
"""Dev app to try things out."""
from typing import List, Literal

from dash import html

import vizro.models as vm
import vizro.plotly.express as px
from vizro import Vizro
from vizro.models.types import ControlType

df_gapminder = px.data.gapminder()


class ControlGroup(vm.VizroBaseModel):
"""Container to group controls."""

type: Literal["control_group"] = "control_group"
title: str
controls: List[ControlType] = []

def build(self):
return html.Div(
[html.H4(self.title), html.Hr()] + [control.build() for control in self.controls],
className="control_group_container",
)

gapminder_2007 = px.data.gapminder().query("year == 2007")

page = vm.Page(
title="Tabs",
vm.Page.add_type("controls", ControlGroup)

page1 = vm.Page(
title="Relationship Analysis",
components=[
vm.Tabs(
tabs=[
vm.Container(
title="Tab I",
components=[
vm.Graph(
title="Graph I",
figure=px.bar(
gapminder_2007,
x="continent",
y="lifeExp",
color="continent",
),
),
vm.Graph(
title="Graph II",
figure=px.box(
gapminder_2007,
x="continent",
y="lifeExp",
color="continent",
),
),
],
vm.Graph(id="scatter", figure=px.scatter(df_gapminder, x="gdpPercap", y="lifeExp", size="pop")),
],
controls=[
ControlGroup(
title="Group A",
controls=[
vm.Parameter(
id="this",
targets=["scatter.x"],
selector=vm.Dropdown(
options=["lifeExp", "gdpPercap", "pop"], multi=False, value="gdpPercap", title="Choose x-axis"
),
),
vm.Container(
title="Tab II",
components=[
vm.Graph(
title="Graph III",
figure=px.scatter(
gapminder_2007,
x="gdpPercap",
y="lifeExp",
size="pop",
color="continent",
),
),
],
vm.Parameter(
targets=["scatter.y"],
selector=vm.Dropdown(
options=["lifeExp", "gdpPercap", "pop"], multi=False, value="lifeExp", title="Choose y-axis"
),
),
],
),
ControlGroup(
title="Group B",
controls=[
vm.Parameter(
targets=["scatter.size"],
selector=vm.Dropdown(
options=["lifeExp", "gdpPercap", "pop"], multi=False, value="pop", title="Choose bubble size"
),
)
],
),
],
)

dashboard = vm.Dashboard(pages=[page])

if __name__ == "__main__":
Vizro().build(dashboard).run()
dashboard = vm.Dashboard(pages=[page1])
Vizro().build(dashboard).run()
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,21 @@


def _get_actions_chains_on_all_pages() -> list[ActionsChain]:
from vizro.models._action._actions_chain import ActionsChain

"""Gets list of ActionsChain models for registered pages."""
actions_chains: list[ActionsChain] = []
# TODO: once dash.page_registry matches up with model_manager, change this to use purely model_manager.
# Making the change now leads to problems since there can be Action models defined that aren't used in the
# dashboard.
# See https://github.com/mckinsey/vizro/pull/366.
# TODO NOW: try to change this
for registered_page in dash.page_registry.values():
try:
page: Page = model_manager[registered_page["module"]]
except KeyError:
continue
actions_chains.extend(model_manager._get_page_actions_chains(page_id=ModelID(str(page.id))))
actions_chains.extend(model_manager._get_models(ActionsChain, page))
return actions_chains


Expand Down
13 changes: 4 additions & 9 deletions vizro-core/src/vizro/actions/_actions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,10 @@ def _apply_filter_controls(
return data_frame


def _get_parent_vizro_model(_underlying_callable_object_id: str) -> VizroBaseModel:
from vizro.models import VizroBaseModel

for _, vizro_base_model in model_manager._items_with_type(VizroBaseModel):
if (
hasattr(vizro_base_model, "_input_component_id")
and vizro_base_model._input_component_id == _underlying_callable_object_id
):
return vizro_base_model
def _get_parent_model(_underlying_callable_object_id: str) -> VizroBaseModel:
for model in model_manager._get_models():
if hasattr(model, "_input_component_id") and model._input_component_id == _underlying_callable_object_id:
return model
raise KeyError(
f"No parent Vizro model found for underlying callable object with id: {_underlying_callable_object_id}."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,20 @@
from vizro.actions import _parameter, filter_interaction
from vizro.managers import model_manager
from vizro.managers._model_manager import ModelID
from vizro.models import Action, Page
from vizro.models import Action, Page, Graph, AgGrid, Table, Figure, VizroBaseModel
from vizro.models._action._actions_chain import ActionsChain
from vizro.models._controls import Filter, Parameter
from vizro.models.types import ControlType


# This function can also be reused for all other inputs (filters, parameters).
# Potentially this could be a way to reconcile predefined with custom actions,
# and make that predefined actions see and add into account custom actions.
def _get_matching_actions_by_function(
page_id: ModelID, action_function: Callable[[Any], dict[str, Any]]
) -> list[Action]:
def _get_matching_actions_by_function(page: Page, action_function: Callable[[Any], dict[str, Any]]) -> list[Action]:
"""Gets list of `Actions` on triggered `Page` that match the provided `action_function`."""
return [
action
for actions_chain in model_manager._get_page_actions_chains(page_id=page_id)
for actions_chain in model_manager._get_models(ActionsChain, page)
for action in actions_chain.actions
if action.function._function == action_function
]
Expand All @@ -32,21 +31,27 @@ def _get_inputs_of_controls(page: Page, control_type: ControlType) -> list[State
"""Gets list of `States` for selected `control_type` of triggered `Page`."""
return [
State(component_id=control.selector.id, component_property=control.selector._input_property)
for control in page.controls
if isinstance(control, control_type)
for control in model_manager._get_models(control_type, page)
]


def _get_action_trigger(action: Action) -> VizroBaseModel: # type: ignore[return]
"""Gets the model that triggers the action with "action_id"."""
from vizro.models._action._actions_chain import ActionsChain

for actions_chain in model_manager._get_models(ActionsChain):
if action in actions_chain.actions:
return model_manager[ModelID(str(actions_chain.trigger.component_id))]


def _get_inputs_of_figure_interactions(
page: Page, action_function: Callable[[Any], dict[str, Any]]
) -> list[dict[str, State]]:
"""Gets list of `States` for selected chart interaction `action_function` of triggered `Page`."""
figure_interactions_on_page = _get_matching_actions_by_function(
page_id=ModelID(str(page.id)), action_function=action_function
)
figure_interactions_on_page = _get_matching_actions_by_function(page=page, action_function=action_function)
inputs = []
for action in figure_interactions_on_page:
triggered_model = model_manager._get_action_trigger(action_id=ModelID(str(action.id)))
triggered_model = _get_action_trigger(action)
required_attributes = ["_filter_interaction_input", "_filter_interaction"]
for attribute in required_attributes:
if not hasattr(triggered_model, attribute):
Expand All @@ -60,9 +65,9 @@ def _get_inputs_of_figure_interactions(


# TODO: Refactor this and util functions once we implement "_get_input_property" method in VizroBaseModel models
def _get_action_callback_inputs(action_id: ModelID) -> dict[str, list[Union[State, dict[str, State]]]]:
def _get_action_callback_inputs(action: Action) -> dict[str, list[Union[State, dict[str, State]]]]:
"""Creates mapping of pre-defined action names and a list of `States`."""
page: Page = model_manager[model_manager._get_model_page_id(model_id=action_id)]
page = model_manager._get_model_page(action)

action_input_mapping = {
"filters": _get_inputs_of_controls(page=page, control_type=Filter),
Expand All @@ -76,17 +81,17 @@ def _get_action_callback_inputs(action_id: ModelID) -> dict[str, list[Union[Stat


# CALLBACK OUTPUTS --------------
def _get_action_callback_outputs(action_id: ModelID) -> dict[str, Output]:
def _get_action_callback_outputs(action: Action) -> dict[str, Output]:
"""Creates mapping of target names and their `Output`."""
action_function = model_manager[action_id].function._function
action_function = action.function._function

# The right solution for mypy here is to not e.g. define new attributes on the base but instead to get mypy to
# recognize that model_manager[action_id] is of type Action and hence has the function attribute.
# Ideally model_manager.__getitem__ would handle this itself, possibly with suitable use of a cast.
# If not then we can do the cast to Action at the point of consumption here to avoid needing mypy ignores.

try:
targets = model_manager[action_id].function["targets"]
targets = action.function["targets"]
except KeyError:
targets = []

Expand All @@ -103,45 +108,41 @@ def _get_action_callback_outputs(action_id: ModelID) -> dict[str, Output]:
}


def _get_export_data_callback_outputs(action_id: ModelID) -> dict[str, Output]:
def _get_export_data_callback_outputs(action: Action) -> dict[str, Output]:
"""Gets mapping of relevant output target name and `Outputs` for `export_data` action."""
action = model_manager[action_id]

try:
targets = action.function["targets"]
except KeyError:
targets = None

if not targets:
targets = model_manager._get_page_model_ids_with_figure(
page_id=model_manager._get_model_page_id(model_id=action_id)
)
targets = targets or [
model.id
for model in model_manager._get_models((Graph, AgGrid, Table, Figure), model_manager._get_model_page(action))
]

return {
f"download_dataframe_{target}": Output(
component_id={"type": "download_dataframe", "action_id": action_id, "target_id": target},
component_id={"type": "download_dataframe", "action_id": action.id, "target_id": target},
component_property="data",
)
for target in targets
}


# CALLBACK COMPONENTS --------------
def _get_export_data_callback_components(action_id: ModelID) -> list[dcc.Download]:
def _get_export_data_callback_components(action: Action) -> list[dcc.Download]:
"""Creates dcc.Downloads for target components of the `export_data` action."""
action = model_manager[action_id]

try:
targets = action.function["targets"]
except KeyError:
targets = None

if not targets:
targets = model_manager._get_page_model_ids_with_figure(
page_id=model_manager._get_model_page_id(model_id=action_id)
)
targets = targets or [
model.id
for model in model_manager._get_models((Graph, AgGrid, Table, Figure), model_manager._get_model_page(action))
]

return [
dcc.Download(id={"type": "download_dataframe", "action_id": action_id, "target_id": target})
dcc.Download(id={"type": "download_dataframe", "action_id": action.id, "target_id": target})
for target in targets
]
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
from vizro.actions._parameter_action import _parameter
from vizro.managers import model_manager
from vizro.managers._model_manager import ModelID
from vizro.models import Action


def _get_action_callback_mapping(
action_id: ModelID, argument: str
) -> Union[list[dcc.Download], dict[str, DashDependency]]:
def _get_action_callback_mapping(action: Action, argument: str) -> Union[list[dcc.Download], dict[str, DashDependency]]:
"""Creates mapping of action name and required callback input/output."""
action_function = model_manager[action_id].function._function
action_function = action.function._function

action_callback_mapping: dict[str, Any] = {
export_data.__wrapped__: {
Expand All @@ -50,4 +49,4 @@ def _get_action_callback_mapping(
}
action_call = action_callback_mapping.get(action_function, {}).get(argument)
default_value: Union[list[dcc.Download], dict[str, DashDependency]] = [] if argument == "components" else {}
return default_value if not action_call else action_call(action_id=action_id)
return default_value if not action_call else action_call(action=action)
Loading

0 comments on commit 3b77712

Please sign in to comment.