Skip to content

Commit

Permalink
Fix ValueError in Filter and add unit test for it (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
huong-li-nguyen authored Sep 26, 2023
1 parent bcbc56a commit 583e951
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 39 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/lint-vizro-core.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ on:
push:
branches: [main]
pull_request:
# TODO: remove "dev/mono_repo_structure" after test
branches:
- "main"
- "dev/mono_repo_structure"

concurrency:
group: lint-${{ github.head_ref }}
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/test-integration-vizro-core.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ on:
push:
branches: [main]
pull_request:
# TODO: remove "dev/mono_repo_structure" after test
branches:
- "main"
- "dev/mono_repo_structure"

concurrency:
group: test-integration-${{ github.head_ref }}
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/test-unit-vizro-core.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ on:
push:
branches: [main]
pull_request:
# TODO: remove "dev/mono_repo_structure" after test
branches:
- "main"
- "dev/mono_repo_structure"

concurrency:
group: test-unit-${{ github.head_ref }}
Expand Down
2 changes: 1 addition & 1 deletion tools/check_for_datafiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"sqlite3",
"orc",
]
whitelist_folders = [] # starting from project root dir
whitelist_folders = ["/venv"] # starting from project root dir


def check_for_data_files():
Expand Down
2 changes: 1 addition & 1 deletion tools/find_forbidden_words_in_repo.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

words=$"colour\|visualisation"
words_finder=$(grep -Irwno --exclude=find_forbidden_words_in_repo.sh --exclude-dir={.git,*cache*,*node_modules*} . -e "$words")
words_finder=$(grep -Irwno --exclude=find_forbidden_words_in_repo.sh --exclude-dir={.git,*cache*,*node_modules*,venv} . -e "$words")

if [[ $words_finder ]]
then
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Removed
- A bullet item for the Removed category.
-->
<!--
### Added
- A bullet item for the Added category.
-->
<!--
### Changed
- A bullet item for the Changed category.
-->
<!--
### Deprecated
- A bullet item for the Deprecated category.
-->

### Fixed

- Raise `ValueError` of shared column with inconsistent dtypes properly ([#64](https://github.com/mckinsey/vizro/pull/64))

<!--
### Security
- A bullet item for the Security category.
-->
14 changes: 9 additions & 5 deletions vizro-core/src/vizro/models/_controls/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, List, Literal, Optional

import pandas as pd
from pandas.api.types import is_numeric_dtype, is_period_dtype
from pandas.api.types import is_numeric_dtype
from pydantic import Field, PrivateAttr, validator

from vizro._constants import FILTER_ACTION_PREFIX
Expand Down Expand Up @@ -103,7 +103,7 @@ def _set_targets(self):

def _set_column_type(self):
data_frame = data_manager._get_component_data(self.targets[0])
if is_period_dtype(data_frame[self.column]) or is_numeric_dtype(data_frame[self.column]):
if isinstance(data_frame[self.column], pd.PeriodDtype) or is_numeric_dtype(data_frame[self.column]):
self._column_type = "numerical"
else:
self._column_type = "categorical"
Expand All @@ -119,16 +119,20 @@ def _set_slider_values(self):
if isinstance(self.selector, SELECTORS["numerical"]):
if self._column_type != "numerical":
raise ValueError(
f"Chosen selector {self.selector.type} is not compatible with column_type {self._column_type}."
f"Chosen selector {self.selector.type} is not compatible "
f"with {self._column_type} column '{self.column}'."
)
min_values = []
max_values = []
for target_id in self.targets:
data_frame = data_manager._get_component_data(target_id)
min_values.append(data_frame[self.column].min())
max_values.append(data_frame[self.column].max())
if not is_numeric_dtype(min(min_values)) or not is_numeric_dtype(max(max_values)):
raise ValueError(f"No numeric value detected in chosen column {self.column} for numerical selector.")
if not is_numeric_dtype(pd.Series(min_values)) or not is_numeric_dtype(pd.Series(max_values)):
raise ValueError(
f"Non-numeric values detected in the shared data column '{self.column}' for targeted charts. "
f"Please ensure that the data column contains the same data type across all targeted charts."
)
if self.selector.min is None:
self.selector.min = min(min_values)
if self.selector.max is None:
Expand Down
2 changes: 1 addition & 1 deletion vizro-core/src/vizro/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _parse_json(
else:
raise ValueError(f"_target_={function_name} must be wrapped in the @capture decorator.")

# TODO-actions: Find the way how to compare CapturedCallable and function
# TODO-actions: Find a way how to compare CapturedCallable and function
@property
def _function(self):
return self.__function
Expand Down
32 changes: 32 additions & 0 deletions vizro-core/tests/unit/vizro/models/_controls/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
import random

import numpy as np
import pandas as pd
import pytest

import vizro.models as vm
import vizro.plotly.express as px
from vizro import Vizro


@pytest.fixture
def dfs_with_shared_column():
df1 = pd.DataFrame()
df1["x"] = np.random.uniform(0, 10, 100)
df1["y"] = np.random.uniform(0, 10, 100)
df2 = df1.copy()

df1["shared_column"] = np.random.uniform(0, 10, 100)
df2["shared_column"] = random.choices(["CATEGORY 1", "CATEGORY 2"], k=100)

return df1, df2


@pytest.fixture
def managers_one_page_two_graphs(gapminder):
"""Instantiates a simple model_manager and data_manager with a page, and two graph models and gapminder data."""
Expand All @@ -17,3 +34,18 @@ def managers_one_page_two_graphs(gapminder):
],
)
Vizro._pre_build()


@pytest.fixture
def managers_shared_column_different_dtype(dfs_with_shared_column):
"""Instantiates the managers with a page and two graphs sharing the same column but of different data types."""
df1, df2 = dfs_with_shared_column
vm.Page(
id="graphs_with_shared_column",
title="Page Title",
components=[
vm.Graph(figure=px.scatter(df1, x="x", y="y", color="shared_column")),
vm.Graph(figure=px.scatter(df2, x="x", y="y", color="shared_column")),
],
)
Vizro._pre_build()
60 changes: 35 additions & 25 deletions vizro-core/tests/unit/vizro/models/_controls/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,16 @@ def test_check_target_present_invalid(self):
Filter(column="foo", targets=["invalid_target"])


@pytest.mark.usefixtures("managers_one_page_two_graphs")
class TestPreBuildMethod:
def test_set_targets_valid(self):
def test_set_targets_valid(self, managers_one_page_two_graphs):
# Core of tests is still interface level
filter = vm.Filter(column="country")
# Special case - need filter in the context of page in order to run filter.pre_build
model_manager["test_page"].controls = [filter]
filter.pre_build()
assert set(filter.targets) == {"scatter_chart", "bar_chart"}

def test_set_targets_invalid(self):
def test_set_targets_invalid(self, managers_one_page_two_graphs):
filter = vm.Filter(column="invalid_choice")
model_manager["test_page"].controls = [filter]

Expand All @@ -95,7 +94,7 @@ def test_set_targets_invalid(self):
@pytest.mark.parametrize(
"test_input,expected", [("country", "categorical"), ("year", "numerical"), ("lifeExp", "numerical")]
)
def test_set_column_type(self, test_input, expected):
def test_set_column_type(self, test_input, expected, managers_one_page_two_graphs):
filter = vm.Filter(column=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
Expand All @@ -104,40 +103,56 @@ def test_set_column_type(self, test_input, expected):
@pytest.mark.parametrize(
"test_input,expected", [("country", vm.Dropdown), ("year", vm.RangeSlider), ("lifeExp", vm.RangeSlider)]
)
def test_set_selector(self, test_input, expected):
def test_set_selector(self, test_input, expected, managers_one_page_two_graphs):
filter = vm.Filter(column=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
assert isinstance(filter.selector, expected)
assert filter.selector.title == test_input.title()

@pytest.mark.parametrize("test_input", [vm.Slider(), vm.RangeSlider()])
def test_determine_slider_defaults_invalid_selector(self, test_input):
def test_set_slider_values_incompatible_column_type(self, test_input, managers_one_page_two_graphs):
filter = vm.Filter(column="country", selector=test_input)
model_manager["test_page"].controls = [filter]
with pytest.raises(
ValueError, match=f"Chosen selector {test_input.type} is not compatible with column_type categorical."
ValueError,
match=f"Chosen selector {test_input.type} is not compatible with categorical column '{filter.column}'.",
):
filter.pre_build()

@pytest.mark.parametrize("test_input", [vm.Slider(), vm.RangeSlider()])
def test_set_slider_values_shared_column_inconsistent_dtype(
self, test_input, managers_shared_column_different_dtype
):
filter = vm.Filter(column="shared_column", selector=test_input)
model_manager["graphs_with_shared_column"].controls = [filter]
with pytest.raises(
ValueError,
match=f"Non-numeric values detected in the shared data column '{filter.column}' for targeted charts. "
f"Please ensure that the data column contains the same data type across all targeted charts.",
):
filter.pre_build()

@pytest.mark.parametrize("test_input", [vm.Slider(), vm.RangeSlider()])
def test_set_slider_values_defaults_min_max_none(self, test_input, gapminder):
def test_set_slider_values_defaults_min_max_none(self, test_input, gapminder, managers_one_page_two_graphs):
filter = vm.Filter(column="lifeExp", selector=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
assert filter.selector.min == gapminder.lifeExp.min()
assert filter.selector.max == gapminder.lifeExp.max()

@pytest.mark.parametrize("test_input", [vm.Slider(min=3, max=5), vm.RangeSlider(min=3, max=5)])
def test_set_slider_values_defaults_min_max_fix(self, test_input):
def test_set_slider_values_defaults_min_max_fix(self, test_input, managers_one_page_two_graphs):
filter = vm.Filter(column="lifeExp", selector=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
assert filter.selector.min == 3
assert filter.selector.max == 5

@pytest.mark.parametrize("test_input", [vm.Checklist(), vm.Dropdown(), vm.RadioItems()])
def test_set_categorical_selectors_options_defaults_options_none(self, test_input, gapminder):
def test_set_categorical_selectors_options_defaults_options_none(
self, test_input, gapminder, managers_one_page_two_graphs
):
filter = vm.Filter(column="continent", selector=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
Expand All @@ -151,16 +166,21 @@ def test_set_categorical_selectors_options_defaults_options_none(self, test_inpu
vm.RadioItems(options=["Africa", "Europe"]),
],
)
def test_set_categorical_selectors_options_defaults_options_fix(self, test_input):
def test_set_categorical_selectors_options_defaults_options_fix(self, test_input, managers_one_page_two_graphs):
filter = vm.Filter(column="continent", selector=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
assert filter.selector.options == ["Africa", "Europe"]


# TODO: split out pre_build method, and test only the units
# TODO: write test for: "No numeric value detected in chosen column lifeExp for numerical selector.")
# TODO: write tests for where there are columns shared, but the content is different
@pytest.mark.parametrize("test_input", ["country", "year", "lifeExp"])
def test_set_actions(self, test_input, managers_one_page_two_graphs):
filter = vm.Filter(column=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
default_action = filter.selector.actions[0]
assert isinstance(default_action, ActionsChain)
assert isinstance(default_action.actions[0].function, CapturedCallable)
assert default_action.actions[0].id == f"filter_action_{filter.id}"


@pytest.mark.usefixtures("managers_one_page_two_graphs")
Expand All @@ -184,13 +204,3 @@ def test_filter_build(self, test_column, test_selector):
result = str(filter.build())
expected = str(test_selector.build())
assert result == expected

@pytest.mark.parametrize("test_input", ["country", "year", "lifeExp"])
def test_set_actions(self, test_input):
filter = vm.Filter(column=test_input)
model_manager["test_page"].controls = [filter]
filter.pre_build()
default_action = filter.selector.actions[0]
assert isinstance(default_action, ActionsChain)
assert isinstance(default_action.actions[0].function, CapturedCallable)
assert default_action.actions[0].id == f"filter_action_{filter.id}"

0 comments on commit 583e951

Please sign in to comment.