diff --git a/src/otf_api/filters.py b/src/otf_api/filters.py new file mode 100644 index 0000000..4f5f215 --- /dev/null +++ b/src/otf_api/filters.py @@ -0,0 +1,42 @@ +from datetime import date, time + +from pydantic import BaseModel, field_validator + +from otf_api.models import ClassType, DoW + + +class ClassFilter(BaseModel): + start_date: date | None = None + end_date: date | None = None + class_type: list[ClassType] | None = None + day_of_week: list[DoW] | None = None + start_time: list[time] | None = None + + @field_validator("class_type", "day_of_week", "start_time", mode="before") + @classmethod + def single_item_to_list(cls, v): + if v and not isinstance(v, list): + return [v] + return v + + @field_validator("day_of_week", mode="before") + @classmethod + def day_of_week_str_to_enum(cls, v): + if v and isinstance(v, str): + return [DoW(v.title())] + + if v and isinstance(v, list) and not all(isinstance(i, DoW) for i in v): + return [DoW(i.title()) for i in v] + + return v + + @field_validator("class_type", mode="before") + @classmethod + def class_type_str_to_enum(cls, v): + if v and isinstance(v, str): + return [ClassType.get_case_insensitive(v)] + + if v and isinstance(v, list) and not all(isinstance(i, ClassType) for i in v): + return [ClassType.get_case_insensitive(i) for i in v] + + return v diff --git a/tests/test_classes.py b/tests/test_classes.py new file mode 100644 index 0000000..ac3031d --- /dev/null +++ b/tests/test_classes.py @@ -0,0 +1,27 @@ +import os + +import pytest + +from otf_api import Otf as Otf +from otf_api.models import ClassType, DoW + + +def test_get_classes_filters(): + username = os.getenv("OTF_EMAIL") + password = os.getenv("OTF_PASSWORD") + + if not username or not password: + raise ValueError("Please set OTF_EMAIL and OTF_PASSWORD environment variables") + + otf = Otf(username=username, password=password) + + with otf: + classes = otf.get_classes( + start_date="2024-12-29", + end_date="2025-01-30", + class_type=ClassType.ORANGE_3G, + day_of_week=DoW.SATURDAY, + include_home_studio=False, + ) + + assert len(classes) diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 0000000..77908b5 --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,77 @@ +from datetime import date, datetime, time + +import pytest + +from otf_api.filters import ClassFilter +from otf_api.models import ClassType, DoW + + +def test_class_filter_string_to_date(): + cf = ClassFilter(start_date="2024-12-29", end_date="2025-01-30") + assert cf.start_date == date(2024, 12, 29) + assert cf.end_date == date(2025, 1, 30) + + +def test_class_filter_datetime_str_to_date(): + cf = ClassFilter(start_date="2024-12-29T00:00:00", end_date="2025-01-30T00:00:00") + assert cf.start_date == date(2024, 12, 29) + assert cf.end_date == date(2025, 1, 30) + + +def test_class_filter_datetime_object_to_date(): + cf = ClassFilter(start_date=datetime(2024, 12, 29), end_date=datetime(2025, 1, 30)) + assert cf.start_date == date(2024, 12, 29) + assert cf.end_date == date(2025, 1, 30) + + +def test_class_filter_single_item_to_list(): + cf = ClassFilter(class_type=ClassType.ORANGE_3G) + assert cf.class_type == [ClassType.ORANGE_3G] + + cf = ClassFilter(start_time=time(12, 30)) + assert cf.start_time == [time(12, 30)] + + cf = ClassFilter(day_of_week=DoW.SUNDAY) + assert cf.day_of_week == [DoW.SUNDAY] + + +@pytest.mark.parametrize( + ["provided", "expected"], + [ + ("ORANGE 3G", [ClassType.ORANGE_3G]), + ("orange 3g", [ClassType.ORANGE_3G]), + (["ORANGE 3G"], [ClassType.ORANGE_3G]), + ([ClassType.ORANGE_3G, "ORANGE Tornado"], [ClassType.ORANGE_3G, ClassType.ORANGE_TORNADO]), + ], +) +def test_class_type_str_to_enum(provided, expected): + cf = ClassFilter(class_type=provided) + assert cf.class_type == expected + + +@pytest.mark.parametrize( + ["provided", "expected"], + [ + ("Sunday", [DoW.SUNDAY]), + ("sunday", [DoW.SUNDAY]), + (["Sunday"], [DoW.SUNDAY]), + ([DoW.SUNDAY, "Monday"], [DoW.SUNDAY, DoW.MONDAY]), + ], +) +def test_day_of_week_str_to_enum(provided, expected): + cf = ClassFilter(day_of_week=provided) + assert cf.day_of_week == expected + + +@pytest.mark.parametrize( + ["provided", "expected"], + [ + ("12:30", [time(12, 30)]), + ("12:30", [time(12, 30)]), + (["12:30"], [time(12, 30)]), + ([time(12, 30), "13:45"], [time(12, 30), time(13, 45)]), + ], +) +def test_class_filter_string_time(provided, expected): + cf = ClassFilter(start_time=provided) + assert cf.start_time == expected