diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9061d1b59..acccb5e13 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,6 +48,7 @@ repos: hooks: - id: ruff args: [--fix] + exclude: "vizro-core/examples/scratch_dev/app.py" - id: ruff-format - repo: https://github.com/PyCQA/bandit diff --git a/vizro-core/changelog.d/20241030_170000_antony.milne_vvv_link_targets.md b/vizro-core/changelog.d/20241030_170000_antony.milne_vvv_link_targets.md new file mode 100644 index 000000000..7c0d58d4f --- /dev/null +++ b/vizro-core/changelog.d/20241030_170000_antony.milne_vvv_link_targets.md @@ -0,0 +1,48 @@ + + + + + + + + + diff --git a/vizro-core/hatch.toml b/vizro-core/hatch.toml index b532cb558..a8165f25e 100644 --- a/vizro-core/hatch.toml +++ b/vizro-core/hatch.toml @@ -111,7 +111,9 @@ VIZRO_LOG_LEVEL = "DEBUG" extra-dependencies = [ "pydantic==1.10.16", "dash==2.17.1", - "plotly==5.12.0" + "plotly==5.12.0", + "pandas==2.0.0", + "numpy==1.23.0" # Need numpy<2 to work with pandas==2.0.0. See https://stackoverflow.com/questions/78634235/. ] features = ["kedro"] python = "3.9" diff --git a/vizro-core/pyproject.toml b/vizro-core/pyproject.toml index aa511adc6..a1992ddc8 100644 --- a/vizro-core/pyproject.toml +++ b/vizro-core/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "dash>=2.17.1", # 2.17.1 needed for no_output fix in clientside_callback "dash_bootstrap_components", "dash-ag-grid>=31.0.0", - "pandas", + "pandas>=2", "plotly>=5.12.0", "pydantic>=1.10.16", # must be synced with pre-commit mypy hook manually "dash_mantine_components<0.13.0", # 0.13.0 is not compatible with 0.12, diff --git a/vizro-core/src/vizro/models/_controls/filter.py b/vizro-core/src/vizro/models/_controls/filter.py index a66e8d62d..e70fad98c 100644 --- a/vizro-core/src/vizro/models/_controls/filter.py +++ b/vizro-core/src/vizro/models/_controls/filter.py @@ -1,7 +1,8 @@ from __future__ import annotations -from typing import Literal, Union +from typing import Any, Literal, Union +import numpy as np import pandas as pd from pandas.api.types import is_datetime64_any_dtype, is_numeric_dtype @@ -95,96 +96,43 @@ def check_target_present(cls, target): @_log_call def pre_build(self): - self._set_targets() - self._set_column_type() - self._set_selector() - self._validate_disallowed_selector() - self._set_numerical_and_temporal_selectors_values() - self._set_categorical_selectors_options() - self._set_actions() - - @_log_call - def build(self): - return self.selector.build() - - def _set_targets(self): - if not self.targets: - for component_id in model_manager._get_page_model_ids_with_figure( - page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id))) - ): - # TODO: consider making a helper method in data_manager or elsewhere to reduce this operation being - # duplicated across Filter so much, and/or consider storing the result to avoid repeating it. - # Need to think about this in connection with how to update filters on the fly and duplicated calls - # issue outlined in https://github.com/mckinsey/vizro/pull/398#discussion_r1559120849. - data_source_name = model_manager[component_id]["data_frame"] - data_frame = data_manager[data_source_name].load() - if self.column in data_frame.columns: - self.targets.append(component_id) - if not self.targets: - raise ValueError(f"Selected column {self.column} not found in any dataframe on this page.") - - def _set_column_type(self): - data_source_name = model_manager[self.targets[0]]["data_frame"] - data_frame = data_manager[data_source_name].load() - - if is_numeric_dtype(data_frame[self.column]): - self._column_type = "numerical" - elif is_datetime64_any_dtype(data_frame[self.column]): - self._column_type = "temporal" + if self.targets: + targeted_data = self._validate_targeted_data(targets=self.targets) else: - self._column_type = "categorical" + # If targets aren't explicitly provided then try to target all figures on the page. In this case we don't + # want to raise an error if the column is not found in a figure's data_frame, it will just be ignored. + # Possibly in future this will change (which would be breaking change). + targeted_data = self._validate_targeted_data( + targets=model_manager._get_page_model_ids_with_figure( + page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id))) + ), + eagerly_raise_column_not_found_error=False, + ) + self.targets = list(targeted_data.columns) - def _set_selector(self): + # Set default selector according to column type. + self._column_type = self._validate_column_type(targeted_data) self.selector = self.selector or SELECTORS[self._column_type][0]() self.selector.title = self.selector.title or self.column.title() - def _validate_disallowed_selector(self): if isinstance(self.selector, DISALLOWED_SELECTORS.get(self._column_type, ())): raise ValueError( - f"Chosen selector {self.selector.type} is not compatible " - f"with {self._column_type} column '{self.column}'. " + f"Chosen selector {type(self.selector).__name__} is not compatible with {self._column_type} column " + f"'{self.column}'." ) - def _set_numerical_and_temporal_selectors_values(self): - # If the selector is a numerical or temporal selector, and the min and max values are not set, then set them - # N.B. All custom selectors inherit from numerical or temporal selector should also pass this check + # Set appropriate properties for the selector. if isinstance(self.selector, SELECTORS["numerical"] + SELECTORS["temporal"]): - min_values = [] - max_values = [] - for target_id in self.targets: - data_source_name = model_manager[target_id]["data_frame"] - data_frame = data_manager[data_source_name].load() - min_values.append(data_frame[self.column].min()) - max_values.append(data_frame[self.column].max()) - - if not ( - (is_numeric_dtype(pd.Series(min_values)) and is_numeric_dtype(pd.Series(max_values))) - or (is_datetime64_any_dtype(pd.Series(min_values)) and is_datetime64_any_dtype(pd.Series(max_values))) - ): - raise ValueError( - f"Inconsistent types detected in the shared data column '{self.column}' for targeted charts " - f"{self.targets}. Please ensure that the data column contains the same data type across all " - f"targeted charts." - ) - + _min, _max = self._get_min_max(targeted_data) + # Note that manually set self.selector.min/max = 0 are Falsey but should not be overwritten. if self.selector.min is None: - self.selector.min = min(min_values) + self.selector.min = _min if self.selector.max is None: - self.selector.max = max(max_values) - - def _set_categorical_selectors_options(self): - # If the selector is a categorical selector, and the options are not set, then set them - # N.B. All custom selectors inherit from categorical selector should also pass this check - if isinstance(self.selector, SELECTORS["categorical"]) and not self.selector.options: - options = set() - for target_id in self.targets: - data_source_name = model_manager[target_id]["data_frame"] - data_frame = data_manager[data_source_name].load() - options |= set(data_frame[self.column]) - - self.selector.options = sorted(options) + self.selector.max = _max + else: + # Categorical selector. + self.selector.options = self.selector.options or self._get_options(targeted_data) - def _set_actions(self): if not self.selector.actions: if isinstance(self.selector, RangeSlider) or ( isinstance(self.selector, DatePicker) and self.selector.range @@ -199,3 +147,102 @@ def _set_actions(self): function=_filter(filter_column=self.column, targets=self.targets, filter_function=filter_function), ) ] + + def __call__(self, **kwargs): + # Only relevant for a dynamic filter. + # TODO: this will need to pass parametrised data_frame arguments through to _validate_targeted_data. + # Although targets are fixed at build time, the validation logic is repeated during runtime, so if a column + # is missing then it will raise an error. We could change this if we wanted. + targeted_data = self._validate_targeted_data(targets=self.targets) + + if (column_type := self._validate_column_type(targeted_data)) != self._column_type: + raise ValueError( + f"{self.column} has changed type from {self._column_type} to {column_type}. A filtered column cannot " + "change type while the dashboard is running." + ) + + # TODO: when implement dynamic, will need to do something with this e.g. pass to selector.__call__. + # if isinstance(self.selector, SELECTORS["numerical"] + SELECTORS["temporal"]): + # options = self._get_options(targeted_data) + # else: + # # Categorical selector. + # _min, _max = self._get_min_max(targeted_data) + + @_log_call + def build(self): + return self.selector.build() + + def _validate_targeted_data( + self, targets: list[ModelID], eagerly_raise_column_not_found_error=True + ) -> pd.DataFrame: + # TODO: consider moving some of this logic to data_manager when implement dynamic filter. Make sure + # get_modified_figures and stuff in _actions_utils.py is as efficient as code here. + + # When loading data_frame there are possible keys: + # 1. target. In worst case scenario this is needed but can lead to unnecessary repeated data loading. + # 2. data_source_name. No repeated data loading but won't work when applying data_frame parameters at runtime. + # 3. target + data_frame parameters keyword-argument pairs. This is the correct key to use at runtime. + # For now we follow scheme 2 for data loading (due to set() below) and 1 for the returned targeted_data + # pd.DataFrame, i.e. a separate column for each target even if some data is repeated. + # TODO: when this works with data_frame parameters load() will need to take arguments and the structures here + # might change a bit. + target_to_data_source_name = {target: model_manager[target]["data_frame"] for target in targets} + data_source_name_to_data = { + data_source_name: data_manager[data_source_name].load() + for data_source_name in set(target_to_data_source_name.values()) + } + target_to_series = {} + + for target, data_source_name in target_to_data_source_name.items(): + data_frame = data_source_name_to_data[data_source_name] + + if self.column in data_frame.columns: + # reset_index so that when we make a DataFrame out of all these pd.Series pandas doesn't try to align + # the columns by index. + target_to_series[target] = data_frame[self.column].reset_index(drop=True) + elif eagerly_raise_column_not_found_error: + raise ValueError(f"Selected column {self.column} not found in dataframe for {target}.") + + targeted_data = pd.DataFrame(target_to_series) + if targeted_data.columns.empty: + # Still raised when eagerly_raise_column_not_found_error=False. + raise ValueError(f"Selected column {self.column} not found in any dataframe for {', '.join(targets)}.") + if targeted_data.empty: + raise ValueError( + f"Selected column {self.column} does not contain anything in any dataframe for {', '.join(targets)}." + ) + + return targeted_data + + def _validate_column_type(self, targeted_data: pd.DataFrame) -> Literal["numerical", "categorical", "temporal"]: + is_numerical = targeted_data.apply(is_numeric_dtype) + is_temporal = targeted_data.apply(is_datetime64_any_dtype) + is_categorical = ~is_numerical & ~is_temporal + + if is_numerical.all(): + return "numerical" + elif is_temporal.all(): + return "temporal" + elif is_categorical.all(): + return "categorical" + else: + raise ValueError( + f"Inconsistent types detected in column {self.column}. This column must have the same type for all " + "targets." + ) + + @staticmethod + def _get_min_max(targeted_data: pd.DataFrame) -> tuple[float, float]: + # Use item() to convert to convert scalar from numpy to Python type. This isn't needed during pre_build because + # pydantic will coerce the type, but it is necessary in __call__ where we don't update model field values + # and instead just pass straight to the Dash component. + return targeted_data.min(axis=None).item(), targeted_data.max(axis=None).item() + + @staticmethod + def _get_options(targeted_data: pd.DataFrame) -> list[Any]: + # Use tolist() to convert to convert scalar from numpy to Python type. This isn't needed during pre_build + # because pydantic will coerce the type, but it is necessary in __call__ where we don't update model field + # values and instead just pass straight to the Dash component. + # The dropna() isn't strictly required here but will be in future pandas versions when the behavior of stack + # changes. See https://pandas.pydata.org/docs/whatsnew/v2.1.0.html#whatsnew-210-enhancements-new-stack. + return np.unique(targeted_data.stack().dropna()).tolist() # noqa: PD013 diff --git a/vizro-core/tests/unit/vizro/models/_action/test_action.py b/vizro-core/tests/unit/vizro/models/_action/test_action.py index 24b14f04b..48ed759d8 100644 --- a/vizro-core/tests/unit/vizro/models/_action/test_action.py +++ b/vizro-core/tests/unit/vizro/models/_action/test_action.py @@ -150,7 +150,7 @@ def managers_one_page_without_graphs_one_button(): vm.Page( id="test_page", title="Test page", - components=[vm.Graph(figure=px.scatter(data_frame=pd.DataFrame(columns=["A"]), x="A", y="A"))], + components=[vm.Graph(figure=px.scatter(data_frame=pd.DataFrame(data={"A": [1], "B": [2]}), x="A", y="B"))], controls=[vm.Filter(id="test_filter", column="A")], ) Vizro._pre_build() diff --git a/vizro-core/tests/unit/vizro/models/_controls/conftest.py b/vizro-core/tests/unit/vizro/models/_controls/conftest.py index 1d376439e..8887876ff 100644 --- a/vizro-core/tests/unit/vizro/models/_controls/conftest.py +++ b/vizro-core/tests/unit/vizro/models/_controls/conftest.py @@ -1,8 +1,3 @@ -import datetime -import random - -import numpy as np -import pandas as pd import pytest import vizro.models as vm @@ -10,21 +5,6 @@ from vizro import Vizro -@pytest.fixture -def dfs_with_shared_column(): - df1 = pd.DataFrame() - df1["x"] = np.random.uniform(0, 10, 100) - df1["y"] = np.random.uniform(0, 10, 100) - df2 = df1.copy() - df3 = df1.copy() - - df1["shared_column"] = np.random.uniform(0, 10, 100) - df2["shared_column"] = [datetime.datetime(2024, 1, 1) + datetime.timedelta(days=i) for i in range(100)] - df3["shared_column"] = random.choices(["CATEGORY 1", "CATEGORY 2"], k=100) - - return df1, df2, df3 - - @pytest.fixture def managers_one_page_two_graphs(gapminder): """Instantiates a simple model_manager and data_manager with a page, and two graph models and gapminder data.""" @@ -37,19 +17,3 @@ def managers_one_page_two_graphs(gapminder): ], ) Vizro._pre_build() - - -@pytest.fixture -def managers_shared_column_different_dtype(dfs_with_shared_column): - """Instantiates the managers with a page and two graphs sharing the same column but of different data types.""" - df1, df2, df3 = dfs_with_shared_column - vm.Page( - id="graphs_with_shared_column", - title="Page Title", - components=[ - vm.Graph(id="id_shared_column_numerical", figure=px.scatter(df1, x="x", y="y", color="shared_column")), - vm.Graph(id="id_shared_column_temporal", figure=px.scatter(df2, x="x", y="y", color="shared_column")), - vm.Graph(id="id_shared_column_categorical", figure=px.scatter(df3, x="x", y="y", color="shared_column")), - ], - ) - Vizro._pre_build() diff --git a/vizro-core/tests/unit/vizro/models/_controls/test_filter.py b/vizro-core/tests/unit/vizro/models/_controls/test_filter.py index 6a87e8c83..3f3d909e9 100644 --- a/vizro-core/tests/unit/vizro/models/_controls/test_filter.py +++ b/vizro-core/tests/unit/vizro/models/_controls/test_filter.py @@ -1,4 +1,3 @@ -import re from datetime import date, datetime from typing import Literal @@ -7,12 +6,52 @@ from asserts import assert_component_equal import vizro.models as vm +import vizro.plotly.express as px +from vizro import Vizro from vizro.managers import model_manager from vizro.models._action._actions_chain import ActionsChain from vizro.models._controls.filter import Filter, _filter_between, _filter_isin from vizro.models.types import CapturedCallable +@pytest.fixture +def managers_column_different_type(): + """Instantiates the managers with a page and two graphs sharing the same column but of different data types.""" + df_numerical = pd.DataFrame({"shared_column": [1]}) + df_temporal = pd.DataFrame({"shared_column": [datetime(2024, 1, 1)]}) + df_categorical = pd.DataFrame({"shared_column": ["a"]}) + + vm.Page( + id="test_page", + title="Page Title", + components=[ + vm.Graph(id="column_numerical", figure=px.scatter(df_numerical)), + vm.Graph(id="column_temporal", figure=px.scatter(df_temporal)), + vm.Graph(id="column_categorical", figure=px.scatter(df_categorical)), + ], + ) + Vizro._pre_build() + + +@pytest.fixture +def managers_column_only_exists_in_some(): + """Dataframes with column_numerical and column_categorical, which can be different lengths.""" + vm.Page( + id="test_page", + title="Page Title", + components=[ + vm.Graph(id="column_numerical_exists_1", figure=px.scatter(pd.DataFrame({"column_numerical": [1]}))), + vm.Graph(id="column_numerical_exists_2", figure=px.scatter(pd.DataFrame({"column_numerical": [1, 2]}))), + vm.Graph(id="column_numerical_exists_empty", figure=px.scatter(pd.DataFrame({"column_numerical": []}))), + vm.Graph(id="column_categorical_exists_1", figure=px.scatter(pd.DataFrame({"column_categorical": ["a"]}))), + vm.Graph( + id="column_categorical_exists_2", figure=px.scatter(pd.DataFrame({"column_categorical": ["a", "b"]})) + ), + ], + ) + Vizro._pre_build() + + class TestFilterFunctions: @pytest.mark.parametrize( "data, value, expected", @@ -206,26 +245,62 @@ def test_check_target_present_invalid(self): class TestPreBuildMethod: - def test_set_targets_valid(self, managers_one_page_two_graphs): + def test_targets_default_valid(self, managers_column_only_exists_in_some): # Core of tests is still interface level - filter = vm.Filter(column="country") + filter = vm.Filter(column="column_numerical") # Special case - need filter in the context of page in order to run filter.pre_build model_manager["test_page"].controls = [filter] filter.pre_build() - assert set(filter.targets) == {"scatter_chart", "bar_chart"} + assert filter.targets == [ + "column_numerical_exists_1", + "column_numerical_exists_2", + "column_numerical_exists_empty", + ] + + def test_targets_specific_valid(self, managers_column_only_exists_in_some): + filter = vm.Filter(column="column_numerical", targets=["column_numerical_exists_1"]) + model_manager["test_page"].controls = [filter] + filter.pre_build() + assert filter.targets == ["column_numerical_exists_1"] - def test_set_targets_invalid(self, managers_one_page_two_graphs): + def test_targets_default_invalid(self, managers_column_only_exists_in_some): filter = vm.Filter(column="invalid_choice") model_manager["test_page"].controls = [filter] - with pytest.raises(ValueError, match="Selected column invalid_choice not found in any dataframe on this page."): + with pytest.raises( + ValueError, + match="Selected column invalid_choice not found in any dataframe for column_numerical_exists_1, " + "column_numerical_exists_2, column_numerical_exists_empty, column_categorical_exists_1, " + "column_categorical_exists_2.", + ): + filter.pre_build() + + def test_targets_specific_invalid(self, managers_column_only_exists_in_some): + filter = vm.Filter(column="column_numerical", targets=["column_categorical_exists_1"]) + model_manager["test_page"].controls = [filter] + + with pytest.raises( + ValueError, + match="Selected column column_numerical not found in dataframe for column_categorical_exists_1.", + ): + filter.pre_build() + + def test_targets_empty(self, managers_column_only_exists_in_some): + filter = vm.Filter(column="column_numerical", targets=["column_numerical_exists_empty"]) + model_manager["test_page"].controls = [filter] + + with pytest.raises( + ValueError, + match="Selected column column_numerical does not contain anything in any dataframe for " + "column_numerical_exists_empty.", + ): filter.pre_build() @pytest.mark.parametrize( "filtered_column, expected_column_type", [("country", "categorical"), ("year", "temporal"), ("lifeExp", "numerical")], ) - def test_set_column_type(self, filtered_column, expected_column_type, managers_one_page_two_graphs): + def test_column_type(self, filtered_column, expected_column_type, managers_one_page_two_graphs): filter = vm.Filter(column=filtered_column) model_manager["test_page"].controls = [filter] filter.pre_build() @@ -235,7 +310,7 @@ def test_set_column_type(self, filtered_column, expected_column_type, managers_o "filtered_column, expected_selector", [("country", vm.Dropdown), ("year", vm.DatePicker), ("lifeExp", vm.RangeSlider)], ) - def test_set_selector_default_selector(self, filtered_column, expected_selector, managers_one_page_two_graphs): + def test_selector_default_selector(self, filtered_column, expected_selector, managers_one_page_two_graphs): filter = vm.Filter(column=filtered_column) model_manager["test_page"].controls = [filter] filter.pre_build() @@ -243,7 +318,7 @@ def test_set_selector_default_selector(self, filtered_column, expected_selector, assert filter.selector.title == filtered_column.title() @pytest.mark.parametrize("filtered_column", ["country", "year", "lifeExp"]) - def test_set_selector_specific_selector(self, filtered_column, managers_one_page_two_graphs): + def test_selector_specific_selector(self, filtered_column, managers_one_page_two_graphs): filter = vm.Filter(column=filtered_column, selector=vm.RadioItems(title="Title")) model_manager["test_page"].controls = [filter] filter.pre_build() @@ -274,54 +349,60 @@ def test_allowed_selectors_per_column_type(self, filtered_column, selector, mana assert isinstance(filter.selector, selector) @pytest.mark.parametrize( - "filtered_column, selector", + "filtered_column, selector, selector_name, column_type", [ - ("country", vm.Slider), - ("country", vm.RangeSlider), - ("country", vm.DatePicker), - ("lifeExp", vm.DatePicker), - ("year", vm.Slider), - ("year", vm.RangeSlider), + ("country", vm.Slider, "Slider", "categorical"), + ("country", vm.RangeSlider, "RangeSlider", "categorical"), + ("country", vm.DatePicker, "DatePicker", "categorical"), + ("lifeExp", vm.DatePicker, "DatePicker", "numerical"), + ("year", vm.Slider, "Slider", "temporal"), + ("year", vm.RangeSlider, "RangeSlider", "temporal"), ], ) - def test_disallowed_selectors_per_column_type(self, filtered_column, selector, managers_one_page_two_graphs): + def test_disallowed_selectors_per_column_type( + self, filtered_column, selector, selector_name, column_type, managers_one_page_two_graphs + ): filter = vm.Filter(column=filtered_column, selector=selector()) model_manager["test_page"].controls = [filter] with pytest.raises( ValueError, - match=f"Chosen selector {selector().type} is not compatible with .* column '{filtered_column}'. ", + match=f"Chosen selector {selector_name} is not compatible with {column_type} column '{filtered_column}'.", ): filter.pre_build() @pytest.mark.parametrize( "targets", [ - ["id_shared_column_numerical", "id_shared_column_temporal"], - ["id_shared_column_numerical", "id_shared_column_categorical"], - ["id_shared_column_temporal", "id_shared_column_categorical"], + ["column_numerical", "column_temporal"], + ["column_numerical", "column_categorical"], + ["column_temporal", "column_categorical"], ], ) - def test_set_slider_values_shared_column_inconsistent_dtype(self, targets, managers_shared_column_different_dtype): + def test_validate_column_type(self, targets, managers_column_different_type): filter = vm.Filter(column="shared_column", targets=targets) - model_manager["graphs_with_shared_column"].controls = [filter] + model_manager["test_page"].controls = [filter] with pytest.raises( ValueError, - match=re.escape( - f"Inconsistent types detected in the shared data column 'shared_column' for targeted charts {targets}. " - f"Please ensure that the data column contains the same data type across all targeted charts." - ), + match="Inconsistent types detected in column shared_column.", ): filter.pre_build() @pytest.mark.parametrize("selector", [vm.Slider, vm.RangeSlider]) - def test_set_numerical_selectors_values_min_max_default(self, selector, gapminder, managers_one_page_two_graphs): + def test_numerical_min_max_default(self, selector, gapminder, managers_one_page_two_graphs): filter = vm.Filter(column="lifeExp", selector=selector()) model_manager["test_page"].controls = [filter] filter.pre_build() assert filter.selector.min == gapminder.lifeExp.min() assert filter.selector.max == gapminder.lifeExp.max() - def test_set_temporal_selectors_values_min_max_default(self, gapminder, managers_one_page_two_graphs): + def test_numerical_min_max_different_column_lengths(self, gapminder, managers_column_only_exists_in_some): + filter = vm.Filter(column="column_numerical", selector=vm.Slider()) + model_manager["test_page"].controls = [filter] + filter.pre_build() + assert filter.selector.min == 1 + assert filter.selector.max == 2 + + def test_temporal_min_max_default(self, gapminder, managers_one_page_two_graphs): filter = vm.Filter(column="year", selector=vm.DatePicker()) model_manager["test_page"].controls = [filter] filter.pre_build() @@ -329,14 +410,15 @@ def test_set_temporal_selectors_values_min_max_default(self, gapminder, managers assert filter.selector.max == gapminder.year.max().to_pydatetime().date() @pytest.mark.parametrize("selector", [vm.Slider, vm.RangeSlider]) - def test_set_numerical_selectors_values_min_max_specific(self, selector, managers_one_page_two_graphs): - filter = vm.Filter(column="lifeExp", selector=selector(min=3, max=5)) + @pytest.mark.parametrize("min, max", [(3, 5), (0, 5), (-5, 0)]) + def test_numerical_min_max_specific(self, selector, min, max, managers_one_page_two_graphs): + filter = vm.Filter(column="lifeExp", selector=selector(min=min, max=max)) model_manager["test_page"].controls = [filter] filter.pre_build() - assert filter.selector.min == 3 - assert filter.selector.max == 5 + assert filter.selector.min == min + assert filter.selector.max == max - def test_set_temporal_selectors_values_min_max_specific(self, managers_one_page_two_graphs): + def test_temporal_min_max_specific(self, managers_one_page_two_graphs): filter = vm.Filter(column="year", selector=vm.DatePicker(min="1952-01-01", max="2007-01-01")) model_manager["test_page"].controls = [filter] filter.pre_build() @@ -344,14 +426,20 @@ def test_set_temporal_selectors_values_min_max_specific(self, managers_one_page_ assert filter.selector.max == date(2007, 1, 1) @pytest.mark.parametrize("selector", [vm.Checklist, vm.Dropdown, vm.RadioItems]) - def test_set_categorical_selectors_options_default(self, selector, gapminder, managers_one_page_two_graphs): + def test_categorical_options_default(self, selector, gapminder, managers_one_page_two_graphs): filter = vm.Filter(column="continent", selector=selector()) model_manager["test_page"].controls = [filter] filter.pre_build() assert filter.selector.options == sorted(set(gapminder["continent"])) + def test_categorical_options_different_column_lengths(self, gapminder, managers_column_only_exists_in_some): + filter = vm.Filter(column="column_categorical", selector=vm.Checklist()) + model_manager["test_page"].controls = [filter] + filter.pre_build() + assert filter.selector.options == ["a", "b"] + @pytest.mark.parametrize("selector", [vm.Checklist, vm.Dropdown, vm.RadioItems]) - def test_set_categorical_selectors_options_specific(self, selector, managers_one_page_two_graphs): + def test_categorical_options_specific(self, selector, managers_one_page_two_graphs): filter = vm.Filter(column="continent", selector=selector(options=["Africa", "Europe"])) model_manager["test_page"].controls = [filter] filter.pre_build()