diff --git a/vizro-core/src/vizro/actions/_action_loop/_action_loop.py b/vizro-core/src/vizro/actions/_action_loop/_action_loop.py index d9caac0ff..2d8edefa9 100644 --- a/vizro-core/src/vizro/actions/_action_loop/_action_loop.py +++ b/vizro-core/src/vizro/actions/_action_loop/_action_loop.py @@ -1,10 +1,14 @@ """The action loop creates all the required action callbacks and its components.""" +from collections.abc import Iterable +from typing import cast + from dash import html -from vizro.actions._action_loop._action_loop_utils import _get_actions_on_registered_pages from vizro.actions._action_loop._build_action_loop_callbacks import _build_action_loop_callbacks from vizro.actions._action_loop._get_action_loop_components import _get_action_loop_components +from vizro.managers import model_manager +from vizro.models import Action class ActionLoop: @@ -37,5 +41,8 @@ def _build_actions_models(): List of required components for each `Action` in the `Dashboard` e.g. list[dcc.Download] """ - actions = _get_actions_on_registered_pages() - return html.Div([action.build() for action in actions], id="app_action_models_components_div", hidden=True) + return html.Div( + [action.build() for action in cast(Iterable[Action], model_manager._get_models(Action))], + id="app_action_models_components_div", + hidden=True, + ) diff --git a/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py b/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py deleted file mode 100644 index eefc22f6f..000000000 --- a/vizro-core/src/vizro/actions/_action_loop/_action_loop_utils.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Contains utilities to extract the Action and ActionsChain models from registered pages only.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import dash - -from vizro.managers import model_manager - -if TYPE_CHECKING: - from vizro.models import Action, Page - from vizro.models._action._actions_chain import ActionsChain - - -def _get_actions_chains_on_all_pages() -> list[ActionsChain]: - from vizro.models._action._actions_chain import ActionsChain - - """Gets list of ActionsChain models for registered pages.""" - actions_chains: list[ActionsChain] = [] - # TODO: once dash.page_registry matches up with model_manager, change this to use purely model_manager. - # Making the change now leads to problems since there can be Action models defined that aren't used in the - # dashboard. - # See https://github.com/mckinsey/vizro/pull/366. - # TODO NOW: try to change this - for registered_page in dash.page_registry.values(): - try: - page: Page = model_manager[registered_page["module"]] - except KeyError: - continue - actions_chains.extend(model_manager._get_models(ActionsChain, page)) - return actions_chains - - -def _get_actions_on_registered_pages() -> list[Action]: - """Gets list of Action models for registered pages.""" - return [action for action_chain in _get_actions_chains_on_all_pages() for action in action_chain.actions] diff --git a/vizro-core/src/vizro/actions/_action_loop/_build_action_loop_callbacks.py b/vizro-core/src/vizro/actions/_action_loop/_build_action_loop_callbacks.py index bc76a146d..439c443e1 100644 --- a/vizro-core/src/vizro/actions/_action_loop/_build_action_loop_callbacks.py +++ b/vizro-core/src/vizro/actions/_action_loop/_build_action_loop_callbacks.py @@ -4,20 +4,20 @@ from dash import ClientsideFunction, Input, Output, State, clientside_callback -from vizro.actions._action_loop._action_loop_utils import ( - _get_actions_chains_on_all_pages, - _get_actions_on_registered_pages, -) from vizro.managers import model_manager from vizro.managers._model_manager import ModelID +from vizro.models import Action +from vizro.models._action._actions_chain import ActionsChain logger = logging.getLogger(__name__) def _build_action_loop_callbacks() -> None: """Creates all required dash callbacks for the action loop.""" - actions_chains = _get_actions_chains_on_all_pages() - actions = _get_actions_on_registered_pages() + # actions_chain and actions are not iterated over multiple times so conversion to list is not technically needed, + # but it prevents future bugs and matches _get_action_loop_components. + actions_chains: list[ActionsChain] = list(model_manager._get_models(ActionsChain)) + actions: list[Action] = list(model_manager._get_models(Action)) if not actions_chains: return diff --git a/vizro-core/src/vizro/actions/_action_loop/_get_action_loop_components.py b/vizro-core/src/vizro/actions/_action_loop/_get_action_loop_components.py index 7d34c2a4a..2d18c18df 100644 --- a/vizro-core/src/vizro/actions/_action_loop/_get_action_loop_components.py +++ b/vizro-core/src/vizro/actions/_action_loop/_get_action_loop_components.py @@ -2,10 +2,9 @@ from dash import dcc, html -from vizro.actions._action_loop._action_loop_utils import ( - _get_actions_chains_on_all_pages, - _get_actions_on_registered_pages, -) +from vizro.managers import model_manager +from vizro.models import Action +from vizro.models._action._actions_chain import ActionsChain def _get_action_loop_components() -> html.Div: @@ -15,8 +14,9 @@ def _get_action_loop_components() -> html.Div: List of dcc or html components. """ - actions_chains = _get_actions_chains_on_all_pages() - actions = _get_actions_on_registered_pages() + # actions_chain and actions are iterated over multiple times so must be realized into a list. + actions_chains: list[ActionsChain] = list(model_manager._get_models(ActionsChain)) + actions: list[Action] = list(model_manager._get_models(Action)) if not actions_chains: return html.Div(id="action_loop_components_div") diff --git a/vizro-core/src/vizro/actions/_actions_utils.py b/vizro-core/src/vizro/actions/_actions_utils.py index 7e84be8ee..984359fc9 100644 --- a/vizro-core/src/vizro/actions/_actions_utils.py +++ b/vizro-core/src/vizro/actions/_actions_utils.py @@ -82,6 +82,8 @@ def _apply_filter_controls( def _get_parent_model(_underlying_callable_object_id: str) -> VizroBaseModel: + from vizro.models import VizroBaseModel + for model in cast(Iterable[VizroBaseModel], model_manager._get_models()): if hasattr(model, "_input_component_id") and model._input_component_id == _underlying_callable_object_id: return model diff --git a/vizro-core/src/vizro/managers/_model_manager.py b/vizro-core/src/vizro/managers/_model_manager.py index fe140e98c..d9682a521 100644 --- a/vizro-core/src/vizro/managers/_model_manager.py +++ b/vizro-core/src/vizro/managers/_model_manager.py @@ -59,7 +59,7 @@ def _get_models( ) -> Generator[Model, None, None]: """Iterates through all models of type `model_type` (including subclasses). - If `model_type` not given then look at all models. If `page_id` specified then only give models from that page. + If `model_type` not given then look at all models. If `page` specified then only give models from that page. """ models = self._get_model_children(page) if page is not None else self.__models.values() @@ -67,8 +67,6 @@ def _get_models( if model_type is None or isinstance(model, model_type): yield model - # TODO: Consider returning with yield - # TODO NOW: Make brief comments on how in future it should work to use Page (or maybe Dashboard) as key primitive. def _get_model_children(self, model: Model) -> Generator[Model, None, None]: from vizro.models import VizroBaseModel @@ -91,7 +89,8 @@ def _get_model_children(self, model: Model) -> Generator[Model, None, None]: # TODO: Add navigation, accordions and other page objects. Won't be needed once have made whole model # manager work better recursively and have better ways to navigate the hierarchy. In pydantic v2 this would use - # model_fields. + # model_fields. Maybe we'd also use Page (or sometimes Dashboard) as the central model for navigating the + # hierarchy rather than it being so generic. def _get_model_page(self, model: Model) -> Page: # type: ignore[return] """Gets the id of the page containing the model with "model_id".""" diff --git a/vizro-core/tests/unit/vizro/models/_controls/test_filter.py b/vizro-core/tests/unit/vizro/models/_controls/test_filter.py index 5ec6e9749..c645ebd45 100644 --- a/vizro-core/tests/unit/vizro/models/_controls/test_filter.py +++ b/vizro-core/tests/unit/vizro/models/_controls/test_filter.py @@ -445,6 +445,8 @@ def test_categorical_options_specific(self, selector, managers_one_page_two_grap filter.pre_build() assert filter.selector.options == ["Africa", "Europe"] + # Use lambda to create test_selector only in the test itself rather than in the parameters, so that it's in the + # model_manager for the test. @pytest.mark.parametrize( "filtered_column, selector, filter_function", [ @@ -504,6 +506,8 @@ def build(self): class TestFilterBuild: """Tests filter build method.""" + # Use lambda to create test_selector only in the test itself rather than in the parameters, so that it's in the + # model_manager for the test. @pytest.mark.parametrize( "test_column,test_selector", [