diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py
index 4007640f..53785ebb 100644
--- a/aredis_om/model/encoders.py
+++ b/aredis_om/model/encoders.py
@@ -23,8 +23,11 @@
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
-
+import calendar
+import copy
 import dataclasses
+import datetime
+import math
 from collections import defaultdict
 from enum import Enum
 from pathlib import PurePath
@@ -35,8 +38,36 @@
 from pydantic.json import ENCODERS_BY_TYPE
 
 
+# TODO: check if correct -- dates encode as whole seconds since the Unix epoch
+def date_to_timestamp(t: datetime.date) -> int:
+    return calendar.timegm(t.timetuple())
+
+
+# TODO: check if correct -- datetimes encode as milliseconds since the Unix epoch
+def datetime_to_timestamp(t: datetime.datetime) -> int:
+    return math.floor(t.astimezone(datetime.timezone.utc).timestamp() * 1000)
+
+
+zero_time = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc)
+zero_day = zero_time.date()  # 1970-01-01, independent of the local timezone
+
+
+# TODO: check if correct -- times encode as milliseconds from epoch midnight (UTC)
+def time_to_timestamp(t: datetime.time) -> int:
+    time_point = datetime.datetime.combine(zero_day, t, t.tzinfo)
+    return datetime_to_timestamp(time_point)
+
+
 SetIntStr = Set[Union[int, str]]
 DictIntStrAny = Dict[Union[int, str], Any]
+ENCODERS_BY_TYPE_ENHANCED = copy.copy(ENCODERS_BY_TYPE)
+ENCODERS_BY_TYPE_ENHANCED.update(
+    {
+        datetime.date: date_to_timestamp,
+        datetime.datetime: datetime_to_timestamp,
+        datetime.time: time_to_timestamp,
+    }
+)
 
 
 def generate_encoders_by_class_tuples(
@@ -154,8 +185,8 @@ def jsonable_encoder(
                 if isinstance(obj, encoder_type):
                     return encoder(obj)
-    if type(obj) in ENCODERS_BY_TYPE:
-        return ENCODERS_BY_TYPE[type(obj)](obj)
+    if type(obj) in ENCODERS_BY_TYPE_ENHANCED:
+        return ENCODERS_BY_TYPE_ENHANCED[type(obj)](obj)
     for encoder, classes_tuple in encoders_by_class_tuples.items():
         if isinstance(obj, classes_tuple):
             return encoder(obj)
 
diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py
index 92bb6f9a..b6036622 100644
--- a/aredis_om/model/model.py
+++ b/aredis_om/model/model.py
@@ -1,5 +1,6 @@
 import abc
 import dataclasses
+import datetime
 import decimal
 import json
 import logging
@@ -38,7 +39,12 @@
 from ..checks import has_redis_json, has_redisearch
 from ..connections import get_redis_connection
 from ..unasync_util import ASYNC_MODE
-from .encoders import jsonable_encoder
+from .encoders import (
+    date_to_timestamp,
+    datetime_to_timestamp,
+    jsonable_encoder,
+    time_to_timestamp,
+)
 from .render_tree import render_tree
 from .token_escaper import TokenEscaper
 
@@ -330,7 +336,14 @@ class RediSearchFieldTypes(Enum):
 
 
 # TODO: How to handle Geo fields?
-NUMERIC_TYPES = (float, int, decimal.Decimal)
+NUMERIC_TYPES = (
+    float,
+    int,
+    decimal.Decimal,
+    datetime.date,
+    datetime.datetime,
+    datetime.time,
+)
 
 DEFAULT_PAGE_SIZE = 1000
 
@@ -530,6 +543,7 @@ def resolve_value(
                 f"Docs: {ERRORS_URL}#E5"
             )
         elif field_type is RediSearchFieldTypes.NUMERIC:
+            value = jsonable_encoder(value)
             if op is Operators.EQ:
                 result += f"@{field_name}:[{value} {value}]"
             elif op is Operators.NE:
@@ -1461,6 +1475,13 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
 
 
 class JsonModel(RedisModel, abc.ABC):
+    class Config(RedisModel.Config):
+        json_encoders = {
+            datetime.date: date_to_timestamp,
+            datetime.datetime: datetime_to_timestamp,
+            datetime.time: time_to_timestamp,
+        }
+
     def __init_subclass__(cls, **kwargs):
         # Generate the RediSearch schema once to validate fields.
         cls.redisearch_schema()
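With the three datetime types added to NUMERIC_TYPES and the jsonable_encoder() call added in resolve_value(), a comparison on an indexed datetime field should compile down to an ordinary numeric range over the encoded timestamp. A rough sketch of the intended query path (the Post model here is illustrative, not part of this diff):

    import datetime

    from aredis_om import Field, JsonModel


    class Post(JsonModel):
        created: datetime.datetime = Field(index=True)


    cutoff = datetime.datetime(1980, 1, 1, tzinfo=datetime.timezone.utc)
    # resolve_value() now encodes `cutoff` first, so the generated RediSearch
    # clause should be a numeric range over milliseconds, roughly:
    #   @created:[-inf (315532800000]
    query = Post.find(Post.created < cutoff)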
diff --git a/tests/conftest.py b/tests/conftest.py
index 9f067a38..090d0704 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,6 @@
 import asyncio
 import random
+from typing import Iterable
 
 import pytest
 
@@ -37,6 +38,20 @@ def _delete_test_keys(prefix: str, conn):
         conn.delete(*keys)
 
 
+def _delete_test_indexes(prefix: str, conn):
+    # TODO: move to scan when available
+    # https://redis.io/commands/ft._list/
+    from redis import ResponseError
+
+    try:
+        indexes: Iterable[str] = conn.execute_command("ft._list")
+    except ResponseError:
+        return
+    for index in indexes:
+        if index.startswith(prefix):
+            conn.execute_command("ft.dropindex", index, "dd")
+
+
 @pytest.fixture
 def key_prefix(request, redis):
     key_prefix = f"{TEST_PREFIX}:{random.random()}"
@@ -59,3 +74,4 @@ def cleanup_keys(request):
     # Delete keys only once
     if conn.decr(once_key) == 0:
         _delete_test_keys(TEST_PREFIX, conn)
+        _delete_test_indexes(TEST_PREFIX, conn)
diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py
index 0a79aa6b..b268d4b2 100644
--- a/tests/test_hash_model.py
+++ b/tests/test_hash_model.py
@@ -25,6 +25,7 @@
 from redis_om import has_redisearch
 from tests.conftest import py_test_mark_asyncio
 
+
 if not has_redisearch():
     pytestmark = pytest.mark.skip
 
diff --git a/tests/test_json_model.py b/tests/test_json_model.py
index 8a114f9a..f5aba4b8 100644
--- a/tests/test_json_model.py
+++ b/tests/test_json_model.py
@@ -27,6 +27,7 @@
 from redis_om import has_redis_json
 from tests.conftest import py_test_mark_asyncio
 
+
 if not has_redis_json():
     pytestmark = pytest.mark.skip
 
@@ -437,7 +438,11 @@ async def test_recursive_query_expression_resolution(members, m):
 async def test_recursive_query_field_resolution(members, m):
     member1, _, _ = members
     member1.address.note = m.Note(
-        description="Weird house", created_on=datetime.datetime.now()
+        description="Weird house",
+        created_on=datetime.datetime.now().replace(
+            microsecond=0,
+            tzinfo=datetime.timezone.utc,
+        ),
     )
     await member1.save()
     actual = await m.Member.find(
@@ -449,7 +454,10 @@
         m.Order(
             items=[m.Item(price=10.99, name="Ball")],
             total=10.99,
-            created_on=datetime.datetime.now(),
+            created_on=datetime.datetime.now().replace(
+                microsecond=0,
+                tzinfo=datetime.timezone.utc,
+            ),
         )
     ]
     await member1.save()
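One detail the tests below lean on: the encoders are deliberately asymmetric about units. date_to_timestamp() yields whole seconds, while datetime_to_timestamp() and time_to_timestamp() yield milliseconds; on the way back in, pydantic v1's numeric parsing (the MS_WATERSHED heuristic in pydantic.datetime_parse) treats large values as milliseconds, so both forms round-trip. This is also why the test_json_model.py changes above pin microsecond=0 and an explicit tzinfo: stored precision is one millisecond, and naive datetimes would otherwise be interpreted as local time on encode. A quick sanity check, assuming the UTC-based zero_day above:

    import datetime

    from aredis_om.model.encoders import (
        date_to_timestamp,
        datetime_to_timestamp,
        time_to_timestamp,
    )

    utc = datetime.timezone.utc
    assert date_to_timestamp(datetime.date(1970, 1, 2)) == 86_400  # seconds
    assert datetime_to_timestamp(datetime.datetime(1970, 1, 2, tzinfo=utc)) == 86_400_000  # ms
    assert time_to_timestamp(datetime.time(0, 0, 1, tzinfo=utc)) == 1_000  # ms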
diff --git a/tests/test_time.py b/tests/test_time.py
new file mode 100644
index 00000000..af976703
--- /dev/null
+++ b/tests/test_time.py
@@ -0,0 +1,308 @@
+import abc
+import datetime
+from operator import attrgetter
+
+import pytest
+import pytest_asyncio
+from pydantic import validator
+
+from aredis_om import Field, HashModel, JsonModel, Migrator
+
+from .conftest import py_test_mark_asyncio
+
+
+# TODO: disable tests based on checks
+@pytest_asyncio.fixture(params=[HashModel, JsonModel])
+async def post_model_datetime(request, key_prefix):
+    base_model = request.param
+
+    class BaseModel(base_model, abc.ABC):
+        class Meta:
+            global_key_prefix = key_prefix
+
+    class PostDatetime(BaseModel):
+        created: datetime.datetime = Field(index=True)
+
+    await Migrator().run()
+    return PostDatetime
+
+
+# TODO: code duplication
+@py_test_mark_asyncio
+async def test_datetime(post_model_datetime):
+    now = datetime.datetime(1980, 1, 1, hour=2, second=20, tzinfo=datetime.timezone.utc)
+    now_p10 = now + datetime.timedelta(seconds=10)
+    now_m10 = now - datetime.timedelta(seconds=10)
+
+    next_hour_timezone = datetime.timezone(datetime.timedelta(hours=1))
+    now_01_00 = now.replace(hour=3, tzinfo=next_hour_timezone)
+    # Sanity check
+    assert now == now_01_00
+
+    posts = [
+        post_model_datetime(created=time_point)
+        for time_point in (now, now_p10, now_m10)
+    ]
+    for post in posts:
+        await post.save()
+
+    expected_sorted_posts = sorted(posts, key=attrgetter("created"))
+
+    # Check all
+    assert (
+        await post_model_datetime.find().sort_by("created").all()
+        == expected_sorted_posts
+    )
+    # Check one
+    assert await post_model_datetime.find(post_model_datetime.created == now).all() == [
+        posts[0]
+    ]
+    # Check one using a different timezone but the same point in time
+    assert await post_model_datetime.find(
+        post_model_datetime.created == now_01_00
+    ).all() == [posts[0]]
+
+    # Check first()
+    post = await post_model_datetime.find(post_model_datetime.created == now).first()
+    assert post.created == now == now_01_00
+
+    # Check index comparison
+    assert (
+        await post_model_datetime.find(post_model_datetime.created < now_p10)
+        .sort_by("created")
+        .all()
+        == expected_sorted_posts[:2]
+    )
+    assert (
+        await post_model_datetime.find(post_model_datetime.created < now)
+        .sort_by("created")
+        .all()
+        == expected_sorted_posts[:1]
+    )
+    assert (
+        await post_model_datetime.find(post_model_datetime.created < now_m10)
+        .sort_by("created")
+        .all()
+        == []
+    )
+
+
+# TODO: disable tests based on checks
+@pytest_asyncio.fixture(params=[HashModel, JsonModel])
+async def post_model_date(request, key_prefix):
+    base_model = request.param
+
+    class BaseModel(base_model, abc.ABC):
+        class Meta:
+            global_key_prefix = key_prefix
+
+    class PostDate(BaseModel):
+        created: datetime.date = Field(index=True)
+
+    await Migrator().run()
+    return PostDate
+
+
+# TODO: code duplication
+@py_test_mark_asyncio
+async def test_date(post_model_date):
+    now = datetime.date(1980, 1, 2)
+    now_next = now.replace(day=3)
+    now_prev = now.replace(day=1)
+
+    posts = [
+        post_model_date(created=time_point) for time_point in (now, now_next, now_prev)
+    ]
+    for post in posts:
+        await post.save()
+
+    expected_sorted_posts = sorted(posts, key=attrgetter("created"))
+
+    # Check all
+    assert (
+        await post_model_date.find().sort_by("created").all() == expected_sorted_posts
+    )
+    # Check one
+    assert await post_model_date.find(post_model_date.created == now).all() == [
+        posts[0]
+    ]
+
+    # Check index comparison
+    assert (
+        await post_model_date.find(post_model_date.created < now_next)
+        .sort_by("created")
+        .all()
+        == expected_sorted_posts[:2]
+    )
+    assert (
+        await post_model_date.find(post_model_date.created < now)
+        .sort_by("created")
+        .all()
+        == expected_sorted_posts[:1]
+    )
+    assert (
+        await post_model_date.find(post_model_date.created < now_prev)
+        .sort_by("created")
+        .all()
+        == []
+    )
+
+
+# TODO: disable tests based on checks
+@pytest_asyncio.fixture(params=[HashModel, JsonModel])
+async def post_model_time(request, key_prefix):
+    base_model = request.param
+
+    class BaseModel(base_model, abc.ABC):
+        class Meta:
+            global_key_prefix = key_prefix
+
+    class PostTime(BaseModel):
+        created: datetime.time = Field(index=True)
+
+        # TODO: Provide our own field type instead of datetime.time?
+        # https://pydantic-docs.helpmanual.io/usage/types/#datetime-types
+        # pydantic parses datetime.time only from a time obj or an ISO str
+        @validator("created", pre=True, allow_reuse=True)
+        def time_validator(cls, value):
+            if isinstance(value, str):
+                value = int(value)
+            if isinstance(value, int):
+                # TODO: check if correct
+                return (
+                    datetime.datetime.fromtimestamp(
+                        value / 1000, tz=datetime.timezone.utc
+                    )
+                    .time()
+                    .replace(tzinfo=datetime.timezone.utc)
+                )
+            return value
+
+    await Migrator().run()
+    return PostTime
+
+
+# TODO: code duplication
+@py_test_mark_asyncio
+async def test_time(post_model_time):
+    now = datetime.time(hour=2, second=20, tzinfo=datetime.timezone.utc)
+    now_p10 = now.replace(second=30)
+    now_m10 = now.replace(second=10)
+
+    next_hour_timezone = datetime.timezone(datetime.timedelta(hours=1))
+    now_01_00 = now.replace(hour=3, tzinfo=next_hour_timezone)
+    # Sanity check
+    assert now == now_01_00
+
+    posts = [
+        post_model_time(created=time_point) for time_point in (now, now_p10, now_m10)
+    ]
+    for post in posts:
+        await post.save()
+
+    expected_sorted_posts = sorted(posts, key=attrgetter("created"))
+
+    # Check all
+    assert (
+        await post_model_time.find().sort_by("created").all() == expected_sorted_posts
+    )
+    # Check one
+    assert await post_model_time.find(post_model_time.created == now).all() == [
+        posts[0]
+    ]
+    # Check one using a different timezone but the same point in time
+    assert await post_model_time.find(post_model_time.created == now_01_00).all() == [
+        posts[0]
+    ]
+
+    # Check first()
+    post = await post_model_time.find(post_model_time.created == now).first()
+    assert post.created == now == now_01_00
+
+    # Check index comparison
+    assert (
+        await post_model_time.find(post_model_time.created < now_p10)
+        .sort_by("created")
+        .all()
+        == expected_sorted_posts[:2]
+    )
+    assert (
+        await post_model_time.find(post_model_time.created < now)
+        .sort_by("created")
+        .all()
+        == expected_sorted_posts[:1]
+    )
+    assert (
+        await post_model_time.find(post_model_time.created < now_m10)
+        .sort_by("created")
+        .all()
+        == []
+    )
+
+
+@pytest.fixture(
+    params=[
+        datetime.timezone.utc,
+        datetime.timezone(datetime.timedelta(hours=2)),
+        datetime.timezone(datetime.timedelta(hours=-5)),
+    ],
+    ids=["UTC", "UTC+2", "UTC-5"],
+)
+def timezone(request):
+    return request.param
+
+
+@py_test_mark_asyncio
+async def test_mixing(post_model_time, post_model_date, post_model_datetime, timezone):
+    now = datetime.datetime(1980, 1, 1, hour=2, second=20, tzinfo=timezone)
+    now_date, now_time = now.date(), now.time().replace(tzinfo=timezone)
+
+    # Serialize + Deserialize datetime.datetime
+    await post_model_datetime(created=now).save()
+    obj = await post_model_datetime.find().first()
+    assert obj.created == now
+
+    # Serialize + Deserialize datetime.date
+    await post_model_date(created=now_date).save()
+    obj_date = await post_model_date.find().first()
+    assert obj_date.created == now_date
+
+    # Serialize + Deserialize datetime.time
+    await post_model_time(created=now_time).save()
+    obj_time = await post_model_time.find().first()
+    assert obj_time.created == now_time
+
+    # Combine deserialized and compare to expected
+    restored = datetime.datetime.combine(obj_date.created, obj_time.created)
+    assert restored == now
+
+
+@py_test_mark_asyncio
+async def test_precision(post_model_datetime):
+    now = datetime.datetime(
+        1980, 1, 1, hour=2, second=20, microsecond=123457, tzinfo=datetime.timezone.utc
+    )
+    # Serialize + Deserialize datetime.datetime
+    await post_model_datetime(created=now).save()
+    obj = await post_model_datetime.find().first()
+    obj_now = obj.created
+
+    # Test seconds
+    assert obj_now.replace(microsecond=0) == now.replace(microsecond=0)
+
+    # Test milliseconds
+    assert obj_now.replace(microsecond=obj_now.microsecond // 1000) == now.replace(
+        microsecond=now.microsecond // 1000
+    )
+
+    # Test microseconds
+    # Our precision is milliseconds, so exact equality must fail
+    with pytest.raises(AssertionError):
+        assert obj_now == now
+
+    # We should be within 1000 microseconds
+    assert (
+        datetime.timedelta(microseconds=-1000)
+        <= obj_now - now
+        <= datetime.timedelta(microseconds=1000)
+    )
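Taken together, the patch should make the following end-to-end sketch work against a Redis instance with RediSearch/RedisJSON loaded (the Event model and values are illustrative, mirroring the fixtures above rather than adding a new test):

    import asyncio
    import datetime

    from aredis_om import Field, JsonModel, Migrator


    class Event(JsonModel):
        at: datetime.datetime = Field(index=True)


    async def main():
        await Migrator().run()
        utc = datetime.timezone.utc
        await Event(at=datetime.datetime(2020, 1, 1, tzinfo=utc)).save()
        await Event(at=datetime.datetime(2021, 1, 1, tzinfo=utc)).save()
        cutoff = datetime.datetime(2020, 6, 1, tzinfo=utc)
        early = await Event.find(Event.at < cutoff).all()
        assert len(early) == 1  # only the 2020 event is earlier than the cutoff


    asyncio.run(main())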