Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tidy] Prepare for dynamic filters, part 1 of 2 #850

Merged
merged 14 commits into from
Nov 7, 2024
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Added

- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
4 changes: 3 additions & 1 deletion vizro-core/hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ VIZRO_LOG_LEVEL = "DEBUG"
extra-dependencies = [
"pydantic==1.10.16",
"dash==2.17.1",
"plotly==5.12.0"
"plotly==5.12.0",
"pandas==2.0.0",
"numpy==1.23.0"
]
features = ["kedro"]
python = "3.9"
Expand Down
2 changes: 1 addition & 1 deletion vizro-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"dash>=2.17.1", # 2.17.1 needed for no_output fix in clientside_callback
"dash_bootstrap_components",
"dash-ag-grid>=31.0.0",
"pandas",
"pandas>=2",
"plotly>=5.12.0",
"pydantic>=1.10.16", # must be synced with pre-commit mypy hook manually
"dash_mantine_components<0.13.0", # 0.13.0 is not compatible with 0.12,
Expand Down
198 changes: 119 additions & 79 deletions vizro-core/src/vizro/models/_controls/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Literal, Union

import numpy as np
import pandas as pd
from pandas.api.types import is_datetime64_any_dtype, is_numeric_dtype

Expand Down Expand Up @@ -95,96 +96,40 @@ def check_target_present(cls, target):

@_log_call
def pre_build(self):
self._set_targets()
self._set_column_type()
self._set_selector()
self._validate_disallowed_selector()
self._set_numerical_and_temporal_selectors_values()
self._set_categorical_selectors_options()
self._set_actions()

@_log_call
def build(self):
return self.selector.build()

def _set_targets(self):
if not self.targets:
for component_id in model_manager._get_page_model_ids_with_figure(
page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id)))
):
# TODO: consider making a helper method in data_manager or elsewhere to reduce this operation being
# duplicated across Filter so much, and/or consider storing the result to avoid repeating it.
# Need to think about this in connection with how to update filters on the fly and duplicated calls
# issue outlined in https://github.com/mckinsey/vizro/pull/398#discussion_r1559120849.
data_source_name = model_manager[component_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
if self.column in data_frame.columns:
self.targets.append(component_id)
if not self.targets:
raise ValueError(f"Selected column {self.column} not found in any dataframe on this page.")

def _set_column_type(self):
data_source_name = model_manager[self.targets[0]]["data_frame"]
data_frame = data_manager[data_source_name].load()

if is_numeric_dtype(data_frame[self.column]):
self._column_type = "numerical"
elif is_datetime64_any_dtype(data_frame[self.column]):
self._column_type = "temporal"
if self.targets:
targeted_data = self._validate_targeted_data(targets=self.targets)
else:
self._column_type = "categorical"
# If targets aren't explicitly provided then try to target all figures on the page. In this case we don't
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
# want to raise an error if the column is not found in a figure's data_frame, it will just be ignored.
# Possibly in future this will change (which would be breaking change).
targeted_data = self._validate_targeted_data(
targets=model_manager._get_page_model_ids_with_figure(
page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id)))
),
eagerly_raise_column_not_found_error=False,
)
self.targets = list(targeted_data.columns)

def _set_selector(self):
# Set default selector according to column type.
self._column_type = self._validate_column_type(targeted_data)
self.selector = self.selector or SELECTORS[self._column_type][0]()
self.selector.title = self.selector.title or self.column.title()

def _validate_disallowed_selector(self):
if isinstance(self.selector, DISALLOWED_SELECTORS.get(self._column_type, ())):
raise ValueError(
f"Chosen selector {self.selector.type} is not compatible "
f"with {self._column_type} column '{self.column}'. "
f"Chosen selector {type(self.selector).__name__} is not compatible with {self._column_type} column "
f"'{self.column}'."
)

def _set_numerical_and_temporal_selectors_values(self):
# If the selector is a numerical or temporal selector, and the min and max values are not set, then set them
# N.B. All custom selectors inherit from numerical or temporal selector should also pass this check
# Set appropriate properties for the selector.
if isinstance(self.selector, SELECTORS["numerical"] + SELECTORS["temporal"]):
min_values = []
max_values = []
for target_id in self.targets:
data_source_name = model_manager[target_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
min_values.append(data_frame[self.column].min())
max_values.append(data_frame[self.column].max())

if not (
(is_numeric_dtype(pd.Series(min_values)) and is_numeric_dtype(pd.Series(max_values)))
or (is_datetime64_any_dtype(pd.Series(min_values)) and is_datetime64_any_dtype(pd.Series(max_values)))
):
raise ValueError(
f"Inconsistent types detected in the shared data column '{self.column}' for targeted charts "
f"{self.targets}. Please ensure that the data column contains the same data type across all "
f"targeted charts."
)

if self.selector.min is None:
self.selector.min = min(min_values)
if self.selector.max is None:
self.selector.max = max(max_values)

def _set_categorical_selectors_options(self):
# If the selector is a categorical selector, and the options are not set, then set them
# N.B. All custom selectors inherit from categorical selector should also pass this check
if isinstance(self.selector, SELECTORS["categorical"]) and not self.selector.options:
options = set()
for target_id in self.targets:
data_source_name = model_manager[target_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
options |= set(data_frame[self.column])

self.selector.options = sorted(options)
_min, _max = self._get_min_max(targeted_data)
self.selector.min = self.selector.min or _min
self.selector.max = self.selector.max or _max
petar-qb marked this conversation as resolved.
Show resolved Hide resolved
else:
# Categorical selector.
self.selector.options = self.selector.options or self._get_options(targeted_data)

def _set_actions(self):
if not self.selector.actions:
if isinstance(self.selector, RangeSlider) or (
isinstance(self.selector, DatePicker) and self.selector.range
Expand All @@ -199,3 +144,98 @@ def _set_actions(self):
function=_filter(filter_column=self.column, targets=self.targets, filter_function=filter_function),
)
]

def __call__(self, **kwargs):
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
# Only relevant for a dynamic filter.
# TODO: this will need to pass parametrised data_frame arguments through to _validate_targeted_data.
# Although targets are fixed at build time, the validation logic is repeated during runtime, so if a column
# is missing then it will raise an error. We could change this if we wanted.
targeted_data = self._validate_targeted_data(targets=self.targets)

if (column_type := self._validate_column_type(targeted_data)) != self._column_type:
raise ValueError(
f"{self.column} has changed type from {self._column_type} to {column_type}. A filtered column cannot "
"change type while the dashboard is running."
)

# TODO: when implement dynamic, will need to do something with this e.g. pass to selector.__call__.
# if isinstance(self.selector, SELECTORS["numerical"] + SELECTORS["temporal"]):
# options = self._get_options(targeted_data)
# else:
# # Categorical selector.
# _min, _max = self._get_min_max(targeted_data)

petar-qb marked this conversation as resolved.
Show resolved Hide resolved
@_log_call
def build(self):
return self.selector.build()

def _validate_targeted_data(
self, targets: list[ModelID], eagerly_raise_column_not_found_error=True
) -> pd.DataFrame:
# TODO: consider moving some of this logic to data_manager when implement dynamic filter. Make sure
# get_modified_figures and stuff in _actions_utils.py is as efficient as code here.

# When loading data_frame there are possible keys:
# 1. target. In worst case scenario this is needed but can lead to unnecessary repeated data loading.
# 2. data_source_name. No repeated data loading but won't work when applying data_frame parameters at runtime.
# 3. target + data_frame parameters keyword-argument pairs. This is the correct key to use at runtime.
# For now we follow scheme 2 for data loading (due to set() below) and 1 for the returned targeted_data
# pd.DataFrame, i.e. a separate column for each target even if some data is repeated.
# TODO: when this works with data_frame parameters load() will need to take arguments and the structures here
# might change a bit.
target_to_data_source_name = {target: model_manager[target]["data_frame"] for target in targets}
data_source_name_to_data = {
data_source_name: data_manager[data_source_name].load()
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
for data_source_name in set(target_to_data_source_name.values())
}
target_to_series = dict()

for target, data_source_name in target_to_data_source_name.items():
data_frame = data_source_name_to_data[data_source_name]

if self.column in data_frame.columns:
# reset_index so that when we make a DataFrame out of all these pd.Series pandas doesn't try to align
# the columns by index.
target_to_series[target] = data_frame[self.column].reset_index(drop=True)
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
elif eagerly_raise_column_not_found_error:
raise ValueError(f"Selected column {self.column} not found in dataframe for {target}.")

targeted_data = pd.DataFrame(target_to_series)
if targeted_data.columns.empty:
# Still raised when eagerly_raise_column_not_found_error=False.
raise ValueError(f"Selected column {self.column} not found in any dataframe for {", ".join(targets)}.")
if targeted_data.empty:
raise ValueError(f"Selected column {self.column} does not contain any data.")

return targeted_data

def _validate_column_type(self, targeted_data: pd.DataFrame) -> Literal["numerical", "categorical", "temporal"]:
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
is_numerical = targeted_data.apply(is_numeric_dtype)
is_temporal = targeted_data.apply(is_datetime64_any_dtype)
is_categorical = ~is_numerical & ~is_temporal

if is_numerical.all():
return "numerical"
elif is_temporal.all():
return "temporal"
elif is_categorical.all():
return "categorical"
else:
raise ValueError(
f"Inconsistent types detected in the shared data column {self.column}. This column must "
"have the same type for all targets."
)

# TODO: write tests. Include N/A
# TODO: block all update of models during runtime
def _get_min_max(self, targeted_data: pd.DataFrame) -> tuple[float, float]:
# Use item() to convert to convert scalar from numpy to Python type. This isn't needed during pre_build because
# pydantic will coerce the type, but it is necessary in __call__ where we don't update model field values
# and instead just pass straight to the Dash component.
return targeted_data.min(axis=None).item(), targeted_data.max(axis=None).item()

def _get_options(self, targeted_data: pd.DataFrame) -> list:
# Use tolist() to convert to convert scalar from numpy to Python type. This isn't needed during pre_build
# because pydantic will coerce the type, but it is necessary in __call__ where we don't update model field values
# and instead just pass straight to the Dash component.
return np.unique(targeted_data.stack().dropna()).tolist()
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def managers_one_page_without_graphs_one_button():
vm.Page(
id="test_page",
title="Test page",
components=[vm.Graph(figure=px.scatter(data_frame=pd.DataFrame(columns=["A"]), x="A", y="A"))],
components=[vm.Graph(figure=px.scatter(data_frame=pd.DataFrame(data={"A": [1], "B": [2]}), x="A", y="B"))],
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
controls=[vm.Filter(id="test_filter", column="A")],
)
Vizro._pre_build()
Expand Down
30 changes: 15 additions & 15 deletions vizro-core/tests/unit/vizro/models/_controls/test_filter.py
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import re
from datetime import date, datetime
from typing import Literal

Expand Down Expand Up @@ -218,7 +217,9 @@ def test_set_targets_invalid(self, managers_one_page_two_graphs):
filter = vm.Filter(column="invalid_choice")
model_manager["test_page"].controls = [filter]

with pytest.raises(ValueError, match="Selected column invalid_choice not found in any dataframe on this page."):
with pytest.raises(
ValueError, match="Selected column invalid_choice not found in any dataframe for scatter_chart, bar_chart."
):
filter.pre_build()

@pytest.mark.parametrize(
Expand Down Expand Up @@ -274,22 +275,24 @@ def test_allowed_selectors_per_column_type(self, filtered_column, selector, mana
assert isinstance(filter.selector, selector)

@pytest.mark.parametrize(
"filtered_column, selector",
"filtered_column, selector, selector_name, column_type",
[
("country", vm.Slider),
("country", vm.RangeSlider),
("country", vm.DatePicker),
("lifeExp", vm.DatePicker),
("year", vm.Slider),
("year", vm.RangeSlider),
("country", vm.Slider, "Slider", "categorical"),
("country", vm.RangeSlider, "RangeSlider", "categorical"),
("country", vm.DatePicker, "DatePicker", "categorical"),
("lifeExp", vm.DatePicker, "DatePicker", "numerical"),
("year", vm.Slider, "Slider", "temporal"),
("year", vm.RangeSlider, "RangeSlider", "temporal"),
],
)
def test_disallowed_selectors_per_column_type(self, filtered_column, selector, managers_one_page_two_graphs):
def test_disallowed_selectors_per_column_type(
self, filtered_column, selector, selector_name, column_type, managers_one_page_two_graphs
):
filter = vm.Filter(column=filtered_column, selector=selector())
model_manager["test_page"].controls = [filter]
with pytest.raises(
ValueError,
match=f"Chosen selector {selector().type} is not compatible with .* column '{filtered_column}'. ",
match=f"Chosen selector {selector_name} is not compatible with {column_type} column '{filtered_column}'.",
):
filter.pre_build()

Expand All @@ -306,10 +309,7 @@ def test_set_slider_values_shared_column_inconsistent_dtype(self, targets, manag
model_manager["graphs_with_shared_column"].controls = [filter]
with pytest.raises(
ValueError,
match=re.escape(
f"Inconsistent types detected in the shared data column 'shared_column' for targeted charts {targets}. "
f"Please ensure that the data column contains the same data type across all targeted charts."
),
match="Inconsistent types detected in the shared data column shared_column.",
):
filter.pre_build()

Expand Down
Loading