diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index c99e556be..e67308004 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -127,7 +127,6 @@ NamedData, ParameterName, PointSelectionConfig, - Predicate, PredicateComposition, ProjectionType, RepeatMapping, @@ -542,12 +541,19 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: """ -_FieldEqualType: TypeAlias = Union[PrimitiveValue_T, Map, Parameter, SchemaBase] -"""Permitted types for equality checks on field values: +_FieldEqualType: TypeAlias = Union["IntoExpression", Parameter, SchemaBase] +""" +Permitted types for equality checks on field values. + +Applies to the following context(s): + + import altair as alt -- `datum.field == ...` -- `FieldEqualPredicate(equal=...)` -- `when(**constraints=...)` + alt.datum.field == ... + alt.FieldEqualPredicate(field="field", equal=...) + alt.when(field=...) + alt.when().then().when(field=...) + alt.Chart.transform_filter(field=...) """ @@ -2986,45 +2992,113 @@ def transform_extent( """ return self._add_transform(core.ExtentTransform(extent=extent, param=param)) - # TODO: Update docstring - # # E.g. {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])} def transform_filter( self, - filter: str - | Expr - | Expression - | Predicate - | Parameter - | PredicateComposition - | dict[str, Predicate | str | list | bool], - **kwargs: Any, + predicate: Optional[_PredicateType] = Undefined, + *more_predicates: _ComposablePredicateType, + empty: Optional[bool] = Undefined, + **constraints: _FieldEqualType, ) -> Self: """ - Add a :class:`FilterTransform` to the schema. + Add a :class:`FilterTransform` to the spec. + + The resulting predicate is an ``&`` reduction over ``predicate`` and optional ``*``, ``**``, arguments. Parameters ---------- - filter : a filter expression or :class:`PredicateComposition` - The `filter` property must be one of the predicate definitions: - (1) a string or alt.expr expression - (2) a range predicate - (3) a selection predicate - (4) a logical operand combining (1)-(3) - (5) a Selection object + predicate + A selection or test predicate. ``str`` input will be treated as a test operand. + *more_predicates + Additional predicates, restricted to types supporting ``&``. + empty + For selection parameters, the predicate of empty selections returns ``True`` by default. + Override this behavior, with ``empty=False``. - Returns - ------- - self : Chart object - returns chart to allow for chaining + .. note:: + When ``predicate`` is a ``Parameter`` that is used more than once, + ``self.transform_filter(..., empty=...)`` provides granular control for each occurrence. + **constraints + Specify `Field Equal Predicate`_'s. + Shortcut for ``alt.datum.field_name == value``, see examples for usage. + + Warns + ----- + AltairDeprecationWarning + If called using ``filter`` as a keyword argument. + + See Also + -------- + alt.when : Uses a similar syntax for defining conditional values. + + Notes + ----- + - Directly inspired by the syntax used in `polars.DataFrame.filter`_. + + .. _Field Equal Predicate: + https://vega.github.io/vega-lite/docs/predicate.html#equal-predicate + .. _polars.DataFrame.filter: + https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.filter.html + + Examples + -------- + Setting up a common chart:: + + import altair as alt + from altair import datum + from vega_datasets import data + + source = data.population.url + chart = ( + alt.Chart(source) + .mark_line() + .encode( + x="age:O", + y="sum(people):Q", + color=alt.Color("year:O").legend(symbolType="square"), + ) + ) + chart + + Singular predicates can be expressed via ``datum``:: + + chart.transform_filter(datum.year <= 1980) + + We can also use selection parameters directly:: + + selection = alt.selection_point(encodings=["color"], bind="legend") + chart.transform_filter(selection).add_params(selection) + + Or a field predicate:: + + between_1950_60 = alt.FieldRangePredicate(field="year", range=[1950, 1960]) + chart.transform_filter(between_1950_60) | chart.transform_filter(~between_1950_60) + + Predicates can be composed together using logical operands:: + + chart.transform_filter(between_1950_60 | (datum.year == 1850)) + + Predicates passed as positional arguments will be reduced with ``&``:: + + chart.transform_filter(datum.year > 1980, datum.age != 90) + + Using keyword-argument ``constraints`` can simplify compositions like:: + + verbose_composition = chart.transform_filter((datum.year == 2000) & (datum.sex == 1)) + chart.transform_filter(year=2000, sex=1) """ - if isinstance(filter, Parameter): - new_filter: dict[str, Any] = {"param": filter.name} - if "empty" in kwargs: - new_filter["empty"] = kwargs.pop("empty") - elif isinstance(filter.empty, bool): - new_filter["empty"] = filter.empty - filter = new_filter - return self._add_transform(core.FilterTransform(filter=filter, **kwargs)) + if depr_filter := t.cast(Any, constraints.pop("filter", None)): + utils.deprecated_warn( + "Passing `filter` as a keyword is ambiguous.\n\n" + "Use a positional argument for `<5.5.0` behavior.\n" + "Or, `alt.datum['filter'] == ...` if referring to a column named 'filter'.", + version="5.5.0", + ) + if utils.is_undefined(predicate): + predicate = depr_filter + else: + more_predicates = *more_predicates, depr_filter + cond = _parse_when(predicate, *more_predicates, empty=empty, **constraints) + return self._add_transform(core.FilterTransform(filter=cond.get("test", cond))) def transform_flatten( self, diff --git a/doc/user_guide/transform/filter.rst b/doc/user_guide/transform/filter.rst index 39c268210..62ee6e334 100644 --- a/doc/user_guide/transform/filter.rst +++ b/doc/user_guide/transform/filter.rst @@ -20,6 +20,8 @@ expressions and objects: We'll show a brief example of each of these in the following sections +.. _filter-expression: + Filter Expression ^^^^^^^^^^^^^^^^^ A filter expression uses the `Vega expression`_ language, either specified @@ -189,12 +191,26 @@ Then, we can *invert* this selection using ``~``: chart.transform_filter(~between_1950_60) We can further refine our filter by *composing* multiple predicates together. -In this case, using ``alt.datum``: +In this case, using ``datum``: + +.. altair-plot:: + + chart.transform_filter(~between_1950_60 & (datum.age <= 70)) + +When passing multiple predicates they will be reduced with ``&``: .. altair-plot:: - chart.transform_filter(~between_1950_60 & (alt.datum.age <= 70)) + chart.transform_filter(datum.year > 1980, datum.age != 90) +Using keyword-argument ``constraints`` can simplify our first example in :ref:`filter-expression`: + +.. altair-plot:: + + alt.Chart(source).mark_area().encode( + x="age:O", + y="people:Q", + ).transform_filter(year=2000, sex=1) Transform Options ^^^^^^^^^^^^^^^^^ diff --git a/tests/examples_arguments_syntax/line_chart_with_cumsum_faceted.py b/tests/examples_arguments_syntax/line_chart_with_cumsum_faceted.py index 5a0fdb743..d33df06ad 100644 --- a/tests/examples_arguments_syntax/line_chart_with_cumsum_faceted.py +++ b/tests/examples_arguments_syntax/line_chart_with_cumsum_faceted.py @@ -12,10 +12,8 @@ columns_sorted = ['Drought', 'Epidemic', 'Earthquake', 'Flood'] alt.Chart(source).transform_filter( - {'and': [ - alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), # Filter data to show only disasters in columns_sorted - alt.FieldRangePredicate(field='Year', range=[1900, 2000]) # Filter data to show only 20th century - ]} + alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), + alt.FieldRangePredicate(field='Year', range=[1900, 2000]) ).transform_window( cumulative_deaths='sum(Deaths)', groupby=['Entity'] # Calculate cumulative sum of Deaths by Entity ).mark_line().encode( diff --git a/tests/examples_methods_syntax/line_chart_with_cumsum_faceted.py b/tests/examples_methods_syntax/line_chart_with_cumsum_faceted.py index d9d887ba5..56dcdb931 100644 --- a/tests/examples_methods_syntax/line_chart_with_cumsum_faceted.py +++ b/tests/examples_methods_syntax/line_chart_with_cumsum_faceted.py @@ -12,10 +12,8 @@ columns_sorted = ['Drought', 'Epidemic', 'Earthquake', 'Flood'] alt.Chart(source).transform_filter( - {'and': [ - alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), # Filter data to show only disasters in columns_sorted - alt.FieldRangePredicate(field='Year', range=[1900, 2000]) # Filter data to show only 20th century - ]} + alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), + alt.FieldRangePredicate(field='Year', range=[1900, 2000]) ).transform_window( cumulative_deaths='sum(Deaths)', groupby=['Entity'] # Calculate cumulative sum of Deaths by Entity ).mark_line().encode( diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 0f60d185a..6bb4ac9ef 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -10,6 +10,7 @@ import re import sys import tempfile +import warnings from collections.abc import Mapping from datetime import date, datetime from importlib.metadata import version as importlib_version @@ -85,7 +86,7 @@ def _make_chart_type(chart_type): @pytest.fixture -def basic_chart(): +def basic_chart() -> alt.Chart: data = pd.DataFrame( { "a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"], @@ -1247,6 +1248,64 @@ def test_predicate_composition() -> None: assert actual_multi == expected_multi +def test_filter_transform_predicates(basic_chart) -> None: + lhs, rhs = alt.datum["b"] >= 30, alt.datum["b"] < 60 + expected = [{"filter": lhs & rhs}] + actual = basic_chart.transform_filter(lhs, rhs).to_dict()["transform"] + assert actual == expected + + +def test_filter_transform_constraints(basic_chart) -> None: + lhs, rhs = alt.datum["a"] == "A", alt.datum["b"] == 30 + expected = [{"filter": lhs & rhs}] + actual = basic_chart.transform_filter(a="A", b=30).to_dict()["transform"] + assert actual == expected + + +def test_filter_transform_predicates_constraints(basic_chart) -> None: + from functools import reduce + from operator import and_ + + predicates = ( + alt.datum["a"] != "A", + alt.datum["a"] != "B", + alt.datum["a"] != "C", + alt.datum["b"] > 1, + alt.datum["b"] < 99, + ) + constraints = {"b": 30, "a": "D"} + pred_constraints = *predicates, alt.datum["b"] == 30, alt.datum["a"] != "D" + expected = [{"filter": reduce(and_, pred_constraints)}] + actual = basic_chart.transform_filter(*predicates, **constraints).to_dict()[ + "transform" + ] + assert actual == expected + + +def test_filter_transform_errors(basic_chart) -> None: + NO_ARGS = r"At least one.+Undefined" + FILTER_KWARGS = r"ambiguous" + + depr_filter = {"field": "year", "oneOf": [1955, 2000]} + expected = [{"filter": depr_filter}] + + with pytest.raises(TypeError, match=NO_ARGS): + basic_chart.transform_filter() + with pytest.raises(TypeError, match=NO_ARGS): + basic_chart.transform_filter(empty=True) + with pytest.raises(TypeError, match=NO_ARGS): + basic_chart.transform_filter(empty=False) + + with pytest.warns(alt.AltairDeprecationWarning, match=FILTER_KWARGS): + basic_chart.transform_filter(filter=depr_filter) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=alt.AltairDeprecationWarning) + actual = basic_chart.transform_filter(filter=depr_filter).to_dict()["transform"] + + assert actual == expected + + def test_resolve_methods(): chart = alt.LayerChart().resolve_axis(x="shared", y="independent") assert chart.resolve == alt.Resolve(