Skip to content

Commit

Permalink
fix: all_pks() for complex keys (#471)
Browse files Browse the repository at this point in the history
* fix: all_pks for complex keys

* fix tests

* more fixes

* support Python below 3.9+

* black

* linter again
  • Loading branch information
YaraslauZhylko authored Feb 12, 2023
1 parent 250d29d commit 412bdd6
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 16 deletions.
28 changes: 14 additions & 14 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ def decode_redis_value(
return obj.decode(encoding)


# TODO: replace with `str.removeprefix()` when only Python 3.9+ is supported
def remove_prefix(value: str, prefix: str) -> str:
"""Remove a prefix from a string."""
if value.startswith(prefix):
value = value[len(prefix) :] # noqa: E203
return value


class PipelineError(Exception):
"""A Redis pipeline error."""

Expand Down Expand Up @@ -1350,16 +1358,12 @@ async def save(
@classmethod
async def all_pks(cls): # type: ignore
key_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
# TODO: We assume the key ends with the default separator, ":" -- when
# we make the separator configurable, we need to update this as well.
# ... And probably lots of other places ...
#
# TODO: Also, we need to decide how we want to handle the lack of
# TODO: We need to decide how we want to handle the lack of
# decode_responses=True...
return (
key.split(":")[-1]
remove_prefix(key, key_prefix)
if isinstance(key, str)
else key.decode(cls.Meta.encoding).split(":")[-1]
else remove_prefix(key.decode(cls.Meta.encoding), key_prefix)
async for key in cls.db().scan_iter(f"{key_prefix}*", _type="HASH")
)

Expand Down Expand Up @@ -1521,16 +1525,12 @@ async def save(
@classmethod
async def all_pks(cls): # type: ignore
key_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
# TODO: We assume the key ends with the default separator, ":" -- when
# we make the separator configurable, we need to update this as well.
# ... And probably lots of other places ...
#
# TODO: Also, we need to decide how we want to handle the lack of
# TODO: We need to decide how we want to handle the lack of
# decode_responses=True...
return (
key.split(":")[-1]
remove_prefix(key, key_prefix)
if isinstance(key, str)
else key.decode(cls.Meta.encoding).split(":")[-1]
else remove_prefix(key.decode(cls.Meta.encoding), key_prefix)
async for key in cls.db().scan_iter(f"{key_prefix}*", _type="ReJSON-RL")
)

Expand Down
32 changes: 31 additions & 1 deletion tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,37 @@ async def test_all_pks(m):
async for pk in await m.Member.all_pks():
pk_list.append(pk)

assert len(pk_list) == 2
assert sorted(pk_list) == ["0", "1"]


@py_test_mark_asyncio
async def test_all_pks_with_complex_pks(key_prefix):
class City(HashModel):
name: str

class Meta:
global_key_prefix = key_prefix
model_key_prefix = "city"

city1 = City(
pk="ca:on:toronto",
name="Toronto",
)

await city1.save()

city2 = City(
pk="ca:qc:montreal",
name="Montreal",
)

await city2.save()

pk_list = []
async for pk in await City.all_pks():
pk_list.append(pk)

assert sorted(pk_list) == ["ca:on:toronto", "ca:qc:montreal"]


@py_test_mark_asyncio
Expand Down
32 changes: 31 additions & 1 deletion tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,37 @@ async def test_all_pks(address, m, redis):
async for pk in await m.Member.all_pks():
pk_list.append(pk)

assert len(pk_list) == 2
assert sorted(pk_list) == sorted([member.pk, member1.pk])


@py_test_mark_asyncio
async def test_all_pks_with_complex_pks(key_prefix):
class City(JsonModel):
name: str

class Meta:
global_key_prefix = key_prefix
model_key_prefix = "city"

city1 = City(
pk="ca:on:toronto",
name="Toronto",
)

await city1.save()

city2 = City(
pk="ca:qc:montreal",
name="Montreal",
)

await city2.save()

pk_list = []
async for pk in await City.all_pks():
pk_list.append(pk)

assert sorted(pk_list) == ["ca:on:toronto", "ca:qc:montreal"]


@py_test_mark_asyncio
Expand Down

0 comments on commit 412bdd6

Please sign in to comment.