Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tidy] Prepare for dynamic filters, part 1 of 2 #850

Merged
merged 14 commits into from
Nov 7, 2024
Merged
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ repos:
hooks:
- id: ruff
args: [--fix]
exclude: "vizro-core/examples/scratch_dev/app.py"
- id: ruff-format

- repo: https://github.com/PyCQA/bandit
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨

- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Removed

- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Added

- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Changed

- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Deprecated

- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Fixed

- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
<!--
### Security

- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX. ([#1](https://github.com/mckinsey/vizro/pull/1))

-->
4 changes: 3 additions & 1 deletion vizro-core/hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ VIZRO_LOG_LEVEL = "DEBUG"
extra-dependencies = [
"pydantic==1.10.16",
"dash==2.17.1",
"plotly==5.12.0"
"plotly==5.12.0",
"pandas==2.0.0",
"numpy==1.23.0" # Need numpy<2 to work with pandas==2.0.0. See https://stackoverflow.com/questions/78634235/.
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
]
features = ["kedro"]
python = "3.9"
Expand Down
2 changes: 1 addition & 1 deletion vizro-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"dash>=2.17.1", # 2.17.1 needed for no_output fix in clientside_callback
"dash_bootstrap_components",
"dash-ag-grid>=31.0.0",
"pandas",
"pandas>=2",
"plotly>=5.12.0",
"pydantic>=1.10.16", # must be synced with pre-commit mypy hook manually
"dash_mantine_components<0.13.0", # 0.13.0 is not compatible with 0.12,
Expand Down
203 changes: 125 additions & 78 deletions vizro-core/src/vizro/models/_controls/filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from typing import Literal, Union
from typing import Any, Literal, Union

import numpy as np
import pandas as pd
from pandas.api.types import is_datetime64_any_dtype, is_numeric_dtype

Expand Down Expand Up @@ -95,96 +96,43 @@ def check_target_present(cls, target):

@_log_call
def pre_build(self):
self._set_targets()
self._set_column_type()
self._set_selector()
self._validate_disallowed_selector()
self._set_numerical_and_temporal_selectors_values()
self._set_categorical_selectors_options()
self._set_actions()

@_log_call
def build(self):
return self.selector.build()

def _set_targets(self):
if not self.targets:
for component_id in model_manager._get_page_model_ids_with_figure(
page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id)))
):
# TODO: consider making a helper method in data_manager or elsewhere to reduce this operation being
# duplicated across Filter so much, and/or consider storing the result to avoid repeating it.
# Need to think about this in connection with how to update filters on the fly and duplicated calls
# issue outlined in https://github.com/mckinsey/vizro/pull/398#discussion_r1559120849.
data_source_name = model_manager[component_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
if self.column in data_frame.columns:
self.targets.append(component_id)
if not self.targets:
raise ValueError(f"Selected column {self.column} not found in any dataframe on this page.")

def _set_column_type(self):
data_source_name = model_manager[self.targets[0]]["data_frame"]
data_frame = data_manager[data_source_name].load()

if is_numeric_dtype(data_frame[self.column]):
self._column_type = "numerical"
elif is_datetime64_any_dtype(data_frame[self.column]):
self._column_type = "temporal"
if self.targets:
targeted_data = self._validate_targeted_data(targets=self.targets)
else:
self._column_type = "categorical"
# If targets aren't explicitly provided then try to target all figures on the page. In this case we don't
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
# want to raise an error if the column is not found in a figure's data_frame, it will just be ignored.
# Possibly in future this will change (which would be breaking change).
targeted_data = self._validate_targeted_data(
targets=model_manager._get_page_model_ids_with_figure(
page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id)))
),
eagerly_raise_column_not_found_error=False,
)
self.targets = list(targeted_data.columns)

def _set_selector(self):
# Set default selector according to column type.
self._column_type = self._validate_column_type(targeted_data)
self.selector = self.selector or SELECTORS[self._column_type][0]()
self.selector.title = self.selector.title or self.column.title()

def _validate_disallowed_selector(self):
if isinstance(self.selector, DISALLOWED_SELECTORS.get(self._column_type, ())):
raise ValueError(
f"Chosen selector {self.selector.type} is not compatible "
f"with {self._column_type} column '{self.column}'. "
f"Chosen selector {type(self.selector).__name__} is not compatible with {self._column_type} column "
f"'{self.column}'."
)

def _set_numerical_and_temporal_selectors_values(self):
# If the selector is a numerical or temporal selector, and the min and max values are not set, then set them
# N.B. All custom selectors that inherit from a numerical or temporal selector should also pass this check
# Set appropriate properties for the selector.
if isinstance(self.selector, SELECTORS["numerical"] + SELECTORS["temporal"]):
min_values = []
max_values = []
for target_id in self.targets:
data_source_name = model_manager[target_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
min_values.append(data_frame[self.column].min())
max_values.append(data_frame[self.column].max())

if not (
(is_numeric_dtype(pd.Series(min_values)) and is_numeric_dtype(pd.Series(max_values)))
or (is_datetime64_any_dtype(pd.Series(min_values)) and is_datetime64_any_dtype(pd.Series(max_values)))
):
raise ValueError(
f"Inconsistent types detected in the shared data column '{self.column}' for targeted charts "
f"{self.targets}. Please ensure that the data column contains the same data type across all "
f"targeted charts."
)

_min, _max = self._get_min_max(targeted_data)
# Note that manually set self.selector.min/max = 0 are Falsey but should not be overwritten.
if self.selector.min is None:
self.selector.min = min(min_values)
self.selector.min = _min
if self.selector.max is None:
self.selector.max = max(max_values)

def _set_categorical_selectors_options(self):
# If the selector is a categorical selector, and the options are not set, then set them
# N.B. All custom selectors that inherit from a categorical selector should also pass this check
if isinstance(self.selector, SELECTORS["categorical"]) and not self.selector.options:
options = set()
for target_id in self.targets:
data_source_name = model_manager[target_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
options |= set(data_frame[self.column])

self.selector.options = sorted(options)
self.selector.max = _max
else:
# Categorical selector.
self.selector.options = self.selector.options or self._get_options(targeted_data)

def _set_actions(self):
if not self.selector.actions:
if isinstance(self.selector, RangeSlider) or (
isinstance(self.selector, DatePicker) and self.selector.range
Expand All @@ -199,3 +147,102 @@ def _set_actions(self):
function=_filter(filter_column=self.column, targets=self.targets, filter_function=filter_function),
)
]

def __call__(self, **kwargs):
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
# Only relevant for a dynamic filter.
# TODO: this will need to pass parametrised data_frame arguments through to _validate_targeted_data.
# Although targets are fixed at build time, the validation logic is repeated during runtime, so if a column
# is missing then it will raise an error. We could change this if we wanted.
targeted_data = self._validate_targeted_data(targets=self.targets)

if (column_type := self._validate_column_type(targeted_data)) != self._column_type:
raise ValueError(
f"{self.column} has changed type from {self._column_type} to {column_type}. A filtered column cannot "
"change type while the dashboard is running."
)

# TODO: when implement dynamic, will need to do something with this e.g. pass to selector.__call__.
# if isinstance(self.selector, SELECTORS["numerical"] + SELECTORS["temporal"]):
# options = self._get_options(targeted_data)
# else:
# # Categorical selector.
# _min, _max = self._get_min_max(targeted_data)

petar-qb marked this conversation as resolved.
Show resolved Hide resolved
@_log_call
def build(self):
    """Build the Dash component for this filter by delegating entirely to its selector."""
    return self.selector.build()

def _validate_targeted_data(
    self, targets: list[ModelID], eagerly_raise_column_not_found_error=True
) -> pd.DataFrame:
    """Load the data for `targets` and return `self.column`'s values, one column per target.

    Each data source is loaded once even if several targets share it. Targets whose data_frame
    lacks `self.column` either raise immediately (`eagerly_raise_column_not_found_error=True`)
    or are silently dropped from the result.

    Args:
        targets: model IDs whose `data_frame` should contain `self.column`.
        eagerly_raise_column_not_found_error: if True, raise as soon as any single target's
            data_frame is missing the column.

    Returns:
        DataFrame with one column per (retained) target, each holding that target's
        `self.column` series.

    Raises:
        ValueError: if no target contains the column, or if the resulting data is empty.
    """
    # TODO: consider moving some of this logic to data_manager when implement dynamic filter. Make sure
    # get_modified_figures and stuff in _actions_utils.py is as efficient as code here.

    # When loading data_frame there are possible keys:
    # 1. target. In worst case scenario this is needed but can lead to unnecessary repeated data loading.
    # 2. data_source_name. No repeated data loading but won't work when applying data_frame parameters at runtime.
    # 3. target + data_frame parameters keyword-argument pairs. This is the correct key to use at runtime.
    # For now we follow scheme 2 for data loading (due to set() below) and 1 for the returned targeted_data
    # pd.DataFrame, i.e. a separate column for each target even if some data is repeated.
    # TODO: when this works with data_frame parameters load() will need to take arguments and the structures here
    # might change a bit.
    target_to_data_source_name = {target: model_manager[target]["data_frame"] for target in targets}
    data_source_name_to_data = {
        data_source_name: data_manager[data_source_name].load()
        for data_source_name in set(target_to_data_source_name.values())
    }
    target_to_series = {}

    for target, data_source_name in target_to_data_source_name.items():
        data_frame = data_source_name_to_data[data_source_name]

        if self.column in data_frame.columns:
            # reset_index so that when we make a DataFrame out of all these pd.Series pandas doesn't try to align
            # the columns by index.
            target_to_series[target] = data_frame[self.column].reset_index(drop=True)
        elif eagerly_raise_column_not_found_error:
            raise ValueError(f"Selected column {self.column} not found in dataframe for {target}.")

    targeted_data = pd.DataFrame(target_to_series)
    if targeted_data.columns.empty:
        # Still raised when eagerly_raise_column_not_found_error=False.
        raise ValueError(f"Selected column {self.column} not found in any dataframe for {', '.join(targets)}.")
    if targeted_data.empty:
        raise ValueError(
            f"Selected column {self.column} does not contain anything in any dataframe for {', '.join(targets)}."
        )

    return targeted_data

def _validate_column_type(self, targeted_data: pd.DataFrame) -> Literal["numerical", "categorical", "temporal"]:
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
is_numerical = targeted_data.apply(is_numeric_dtype)
is_temporal = targeted_data.apply(is_datetime64_any_dtype)
is_categorical = ~is_numerical & ~is_temporal

if is_numerical.all():
return "numerical"
elif is_temporal.all():
return "temporal"
elif is_categorical.all():
return "categorical"
else:
raise ValueError(
f"Inconsistent types detected in column {self.column}. This column must have the same type for all "
"targets."
)
antonymilne marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def _get_min_max(targeted_data: pd.DataFrame) -> tuple[float, float]:
    """Return the overall (min, max) across all columns of `targeted_data` as Python scalars."""
    # Use item() to convert scalar from numpy to Python type. This isn't needed during pre_build because
    # pydantic will coerce the type, but it is necessary in __call__ where we don't update model field values
    # and instead just pass straight to the Dash component.
    return targeted_data.min(axis=None).item(), targeted_data.max(axis=None).item()

@staticmethod
def _get_options(targeted_data: pd.DataFrame) -> list[Any]:
# Use tolist() to convert to convert scalar from numpy to Python type. This isn't needed during pre_build
# because pydantic will coerce the type, but it is necessary in __call__ where we don't update model field
# values and instead just pass straight to the Dash component.
# The dropna() isn't strictly required here but will be in future pandas versions when the behavior of stack
# changes. See https://pandas.pydata.org/docs/whatsnew/v2.1.0.html#whatsnew-210-enhancements-new-stack.
return np.unique(targeted_data.stack().dropna()).tolist() # noqa: PD013
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def managers_one_page_without_graphs_one_button():
vm.Page(
id="test_page",
title="Test page",
components=[vm.Graph(figure=px.scatter(data_frame=pd.DataFrame(columns=["A"]), x="A", y="A"))],
components=[vm.Graph(figure=px.scatter(data_frame=pd.DataFrame(data={"A": [1], "B": [2]}), x="A", y="B"))],
antonymilne marked this conversation as resolved.
Show resolved Hide resolved
controls=[vm.Filter(id="test_filter", column="A")],
)
Vizro._pre_build()
Expand Down
36 changes: 0 additions & 36 deletions vizro-core/tests/unit/vizro/models/_controls/conftest.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,10 @@
import datetime
import random

import numpy as np
import pandas as pd
import pytest

import vizro.models as vm
import vizro.plotly.express as px
from vizro import Vizro


@pytest.fixture
def dfs_with_shared_column():
df1 = pd.DataFrame()
df1["x"] = np.random.uniform(0, 10, 100)
df1["y"] = np.random.uniform(0, 10, 100)
df2 = df1.copy()
df3 = df1.copy()

df1["shared_column"] = np.random.uniform(0, 10, 100)
df2["shared_column"] = [datetime.datetime(2024, 1, 1) + datetime.timedelta(days=i) for i in range(100)]
df3["shared_column"] = random.choices(["CATEGORY 1", "CATEGORY 2"], k=100)

return df1, df2, df3


@pytest.fixture
def managers_one_page_two_graphs(gapminder):
"""Instantiates a simple model_manager and data_manager with a page, and two graph models and gapminder data."""
Expand All @@ -37,19 +17,3 @@ def managers_one_page_two_graphs(gapminder):
],
)
Vizro._pre_build()


@pytest.fixture
def managers_shared_column_different_dtype(dfs_with_shared_column):
"""Instantiates the managers with a page and two graphs sharing the same column but of different data types."""
df1, df2, df3 = dfs_with_shared_column
vm.Page(
id="graphs_with_shared_column",
title="Page Title",
components=[
vm.Graph(id="id_shared_column_numerical", figure=px.scatter(df1, x="x", y="y", color="shared_column")),
vm.Graph(id="id_shared_column_temporal", figure=px.scatter(df2, x="x", y="y", color="shared_column")),
vm.Graph(id="id_shared_column_categorical", figure=px.scatter(df3, x="x", y="y", color="shared_column")),
],
)
Vizro._pre_build()
Loading