Skip to content

Commit

Permalink
[Tidy] Prepare for dynamic filters, part 1 of 2 (#850)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
antonymilne and pre-commit-ci[bot] authored Nov 7, 2024
1 parent c7df0c5 commit ae68bd6
Show file tree
Hide file tree
Showing 8 changed files with 303 additions and 153 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ repos:
hooks:
- id: ruff
args: [--fix]
exclude: "vizro-core/examples/scratch_dev/app.py"
- id: ruff-format

- repo: https://github.com/PyCQA/bandit
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
4 changes: 3 additions & 1 deletion vizro-core/hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ VIZRO_LOG_LEVEL = "DEBUG"
extra-dependencies = [
"pydantic==1.10.16",
"dash==2.17.1",
"plotly==5.12.0"
"plotly==5.12.0",
"pandas==2.0.0",
"numpy==1.23.0" # Need numpy<2 to work with pandas==2.0.0. See https://stackoverflow.com/questions/78634235/.
]
features = ["kedro"]
python = "3.9"
Expand Down
2 changes: 1 addition & 1 deletion vizro-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"dash>=2.17.1", # 2.17.1 needed for no_output fix in clientside_callback
"dash_bootstrap_components",
"dash-ag-grid>=31.0.0",
"pandas",
"pandas>=2",
"plotly>=5.12.0",
"pydantic>=1.10.16", # must be synced with pre-commit mypy hook manually
"dash_mantine_components<0.13.0", # 0.13.0 is not compatible with 0.12,
Expand Down
203 changes: 125 additions & 78 deletions vizro-core/src/vizro/models/_controls/filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from typing import Literal, Union
from typing import Any, Literal, Union

import numpy as np
import pandas as pd
from pandas.api.types import is_datetime64_any_dtype, is_numeric_dtype

Expand Down Expand Up @@ -95,96 +96,43 @@ def check_target_present(cls, target):

@_log_call
def pre_build(self):
self._set_targets()
self._set_column_type()
self._set_selector()
self._validate_disallowed_selector()
self._set_numerical_and_temporal_selectors_values()
self._set_categorical_selectors_options()
self._set_actions()

@_log_call
def build(self):
return self.selector.build()

def _set_targets(self):
if not self.targets:
for component_id in model_manager._get_page_model_ids_with_figure(
page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id)))
):
# TODO: consider making a helper method in data_manager or elsewhere to reduce this operation being
# duplicated across Filter so much, and/or consider storing the result to avoid repeating it.
# Need to think about this in connection with how to update filters on the fly and duplicated calls
# issue outlined in https://github.com/mckinsey/vizro/pull/398#discussion_r1559120849.
data_source_name = model_manager[component_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
if self.column in data_frame.columns:
self.targets.append(component_id)
if not self.targets:
raise ValueError(f"Selected column {self.column} not found in any dataframe on this page.")

def _set_column_type(self):
data_source_name = model_manager[self.targets[0]]["data_frame"]
data_frame = data_manager[data_source_name].load()

if is_numeric_dtype(data_frame[self.column]):
self._column_type = "numerical"
elif is_datetime64_any_dtype(data_frame[self.column]):
self._column_type = "temporal"
if self.targets:
targeted_data = self._validate_targeted_data(targets=self.targets)
else:
self._column_type = "categorical"
# If targets aren't explicitly provided then try to target all figures on the page. In this case we don't
# want to raise an error if the column is not found in a figure's data_frame, it will just be ignored.
# Possibly in future this will change (which would be breaking change).
targeted_data = self._validate_targeted_data(
targets=model_manager._get_page_model_ids_with_figure(
page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id)))
),
eagerly_raise_column_not_found_error=False,
)
self.targets = list(targeted_data.columns)

def _set_selector(self):
# Set default selector according to column type.
self._column_type = self._validate_column_type(targeted_data)
self.selector = self.selector or SELECTORS[self._column_type][0]()
self.selector.title = self.selector.title or self.column.title()

def _validate_disallowed_selector(self):
if isinstance(self.selector, DISALLOWED_SELECTORS.get(self._column_type, ())):
raise ValueError(
f"Chosen selector {self.selector.type} is not compatible "
f"with {self._column_type} column '{self.column}'. "
f"Chosen selector {type(self.selector).__name__} is not compatible with {self._column_type} column "
f"'{self.column}'."
)

def _set_numerical_and_temporal_selectors_values(self):
# If the selector is a numerical or temporal selector, and the min and max values are not set, then set them
# N.B. All custom selectors inherit from numerical or temporal selector should also pass this check
# Set appropriate properties for the selector.
if isinstance(self.selector, SELECTORS["numerical"] + SELECTORS["temporal"]):
min_values = []
max_values = []
for target_id in self.targets:
data_source_name = model_manager[target_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
min_values.append(data_frame[self.column].min())
max_values.append(data_frame[self.column].max())

if not (
(is_numeric_dtype(pd.Series(min_values)) and is_numeric_dtype(pd.Series(max_values)))
or (is_datetime64_any_dtype(pd.Series(min_values)) and is_datetime64_any_dtype(pd.Series(max_values)))
):
raise ValueError(
f"Inconsistent types detected in the shared data column '{self.column}' for targeted charts "
f"{self.targets}. Please ensure that the data column contains the same data type across all "
f"targeted charts."
)

_min, _max = self._get_min_max(targeted_data)
# Note that manually set self.selector.min/max = 0 are Falsey but should not be overwritten.
if self.selector.min is None:
self.selector.min = min(min_values)
self.selector.min = _min
if self.selector.max is None:
self.selector.max = max(max_values)

def _set_categorical_selectors_options(self):
# If the selector is a categorical selector, and the options are not set, then set them
# N.B. All custom selectors inherit from categorical selector should also pass this check
if isinstance(self.selector, SELECTORS["categorical"]) and not self.selector.options:
options = set()
for target_id in self.targets:
data_source_name = model_manager[target_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
options |= set(data_frame[self.column])

self.selector.options = sorted(options)
self.selector.max = _max
else:
# Categorical selector.
self.selector.options = self.selector.options or self._get_options(targeted_data)

def _set_actions(self):
if not self.selector.actions:
if isinstance(self.selector, RangeSlider) or (
isinstance(self.selector, DatePicker) and self.selector.range
Expand All @@ -199,3 +147,102 @@ def _set_actions(self):
function=_filter(filter_column=self.column, targets=self.targets, filter_function=filter_function),
)
]

def __call__(self, **kwargs):
    """Re-validate the filter's targeted data while the dashboard is running.

    NOTE(review): **kwargs is currently unused; presumably it will carry parametrised
    data_frame arguments in future (see TODO below) — confirm when dynamic filters land.
    """
    # Only relevant for a dynamic filter.
    # TODO: this will need to pass parametrised data_frame arguments through to _validate_targeted_data.
    # Although targets are fixed at build time, the validation logic is repeated during runtime, so if a column
    # is missing then it will raise an error. We could change this if we wanted.
    targeted_data = self._validate_targeted_data(targets=self.targets)

    # The selector chosen at pre_build time is tied to the column type, so a type change at runtime
    # cannot be honored and must raise.
    if (column_type := self._validate_column_type(targeted_data)) != self._column_type:
        raise ValueError(
            f"{self.column} has changed type from {self._column_type} to {column_type}. A filtered column cannot "
            "change type while the dashboard is running."
        )

    # TODO: when implement dynamic, will need to do something with this e.g. pass to selector.__call__.
    # NOTE(review): in the sketch below the branches look swapped — numerical/temporal selectors should
    # take _get_min_max and categorical selectors should take _get_options, mirroring pre_build.
    # if isinstance(self.selector, SELECTORS["numerical"] + SELECTORS["temporal"]):
    #     options = self._get_options(targeted_data)
    # else:
    #     # Categorical selector.
    #     _min, _max = self._get_min_max(targeted_data)

@_log_call
def build(self):
    """Build the Dash component tree for this filter by delegating to the underlying selector."""
    return self.selector.build()

def _validate_targeted_data(
    self, targets: list[ModelID], eagerly_raise_column_not_found_error=True
) -> pd.DataFrame:
    """Load each target's data and assemble the values of self.column into one DataFrame.

    Args:
        targets: model IDs of the figures this filter could apply to.
        eagerly_raise_column_not_found_error: when True, raise as soon as any single target's
            data_frame lacks self.column. When False, such targets are silently skipped and an
            error is raised only if no target at all has the column.

    Returns:
        pd.DataFrame with one column per matching target (keyed by target ID) containing that
        target's self.column values. Indexes are reset, so rows across columns are aligned
        positionally, not by the original index.

    Raises:
        ValueError: if self.column is missing from a target (eager mode), missing from every
            target, or present but empty in every target.
    """
    # TODO: consider moving some of this logic to data_manager when implement dynamic filter. Make sure
    # get_modified_figures and stuff in _actions_utils.py is as efficient as code here.

    # When loading data_frame there are possible keys:
    #  1. target. In worst case scenario this is needed but can lead to unnecessary repeated data loading.
    #  2. data_source_name. No repeated data loading but won't work when applying data_frame parameters at runtime.
    #  3. target + data_frame parameters keyword-argument pairs. This is the correct key to use at runtime.
    # For now we follow scheme 2 for data loading (due to set() below) and 1 for the returned targeted_data
    # pd.DataFrame, i.e. a separate column for each target even if some data is repeated.
    # TODO: when this works with data_frame parameters load() will need to take arguments and the structures here
    # might change a bit.
    target_to_data_source_name = {target: model_manager[target]["data_frame"] for target in targets}
    # De-duplicate via set() so each data source is loaded only once even if several targets share it.
    data_source_name_to_data = {
        data_source_name: data_manager[data_source_name].load()
        for data_source_name in set(target_to_data_source_name.values())
    }
    target_to_series = {}

    for target, data_source_name in target_to_data_source_name.items():
        data_frame = data_source_name_to_data[data_source_name]

        if self.column in data_frame.columns:
            # reset_index so that when we make a DataFrame out of all these pd.Series pandas doesn't try to align
            # the columns by index.
            target_to_series[target] = data_frame[self.column].reset_index(drop=True)
        elif eagerly_raise_column_not_found_error:
            raise ValueError(f"Selected column {self.column} not found in dataframe for {target}.")

    targeted_data = pd.DataFrame(target_to_series)
    if targeted_data.columns.empty:
        # Still raised when eagerly_raise_column_not_found_error=False.
        raise ValueError(f"Selected column {self.column} not found in any dataframe for {', '.join(targets)}.")
    if targeted_data.empty:
        raise ValueError(
            f"Selected column {self.column} does not contain anything in any dataframe for {', '.join(targets)}."
        )

    return targeted_data

def _validate_column_type(self, targeted_data: pd.DataFrame) -> Literal["numerical", "categorical", "temporal"]:
is_numerical = targeted_data.apply(is_numeric_dtype)
is_temporal = targeted_data.apply(is_datetime64_any_dtype)
is_categorical = ~is_numerical & ~is_temporal

if is_numerical.all():
return "numerical"
elif is_temporal.all():
return "temporal"
elif is_categorical.all():
return "categorical"
else:
raise ValueError(
f"Inconsistent types detected in column {self.column}. This column must have the same type for all "
"targets."
)

@staticmethod
def _get_min_max(targeted_data: pd.DataFrame) -> tuple[float, float]:
# Use item() to convert to convert scalar from numpy to Python type. This isn't needed during pre_build because
# pydantic will coerce the type, but it is necessary in __call__ where we don't update model field values
# and instead just pass straight to the Dash component.
return targeted_data.min(axis=None).item(), targeted_data.max(axis=None).item()

@staticmethod
def _get_options(targeted_data: pd.DataFrame) -> list[Any]:
# Use tolist() to convert to convert scalar from numpy to Python type. This isn't needed during pre_build
# because pydantic will coerce the type, but it is necessary in __call__ where we don't update model field
# values and instead just pass straight to the Dash component.
# The dropna() isn't strictly required here but will be in future pandas versions when the behavior of stack
# changes. See https://pandas.pydata.org/docs/whatsnew/v2.1.0.html#whatsnew-210-enhancements-new-stack.
return np.unique(targeted_data.stack().dropna()).tolist() # noqa: PD013
2 changes: 1 addition & 1 deletion vizro-core/tests/unit/vizro/models/_action/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def managers_one_page_without_graphs_one_button():
vm.Page(
id="test_page",
title="Test page",
components=[vm.Graph(figure=px.scatter(data_frame=pd.DataFrame(columns=["A"]), x="A", y="A"))],
components=[vm.Graph(figure=px.scatter(data_frame=pd.DataFrame(data={"A": [1], "B": [2]}), x="A", y="B"))],
controls=[vm.Filter(id="test_filter", column="A")],
)
Vizro._pre_build()
Expand Down
36 changes: 0 additions & 36 deletions vizro-core/tests/unit/vizro/models/_controls/conftest.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,10 @@
import datetime
import random

import numpy as np
import pandas as pd
import pytest

import vizro.models as vm
import vizro.plotly.express as px
from vizro import Vizro


@pytest.fixture
def dfs_with_shared_column():
    """Return three DataFrames sharing x/y columns plus a "shared_column" of a different dtype in each.

    df1's shared_column is numerical, df2's is temporal and df3's is categorical — used to
    exercise the check that a shared filtered column has a consistent type across figures.
    """
    df1 = pd.DataFrame()
    df1["x"] = np.random.uniform(0, 10, 100)
    df1["y"] = np.random.uniform(0, 10, 100)
    df2 = df1.copy()
    df3 = df1.copy()

    # Same column name, deliberately different dtype per frame.
    df1["shared_column"] = np.random.uniform(0, 10, 100)
    df2["shared_column"] = [datetime.datetime(2024, 1, 1) + datetime.timedelta(days=i) for i in range(100)]
    df3["shared_column"] = random.choices(["CATEGORY 1", "CATEGORY 2"], k=100)

    return df1, df2, df3


@pytest.fixture
def managers_one_page_two_graphs(gapminder):
"""Instantiates a simple model_manager and data_manager with a page, and two graph models and gapminder data."""
Expand All @@ -37,19 +17,3 @@ def managers_one_page_two_graphs(gapminder):
],
)
Vizro._pre_build()


@pytest.fixture
def managers_shared_column_different_dtype(dfs_with_shared_column):
    """Instantiates the managers with a page and three graphs sharing the same column but of different data types."""
    df1, df2, df3 = dfs_with_shared_column
    # Instantiating the page registers it (and its graphs) with the model_manager.
    vm.Page(
        id="graphs_with_shared_column",
        title="Page Title",
        components=[
            vm.Graph(id="id_shared_column_numerical", figure=px.scatter(df1, x="x", y="y", color="shared_column")),
            vm.Graph(id="id_shared_column_temporal", figure=px.scatter(df2, x="x", y="y", color="shared_column")),
            vm.Graph(id="id_shared_column_categorical", figure=px.scatter(df3, x="x", y="y", color="shared_column")),
        ],
    )
    Vizro._pre_build()
Loading

0 comments on commit ae68bd6

Please sign in to comment.