From 274ff788b1c9addc8f096baf8a5220b8c7f2f26f Mon Sep 17 00:00:00 2001 From: Simon Prickett Date: Tue, 12 Apr 2022 11:31:09 +0100 Subject: [PATCH] Adds all_pks() method and test. --- aredis_om/model/model.py | 16 ++++++++++++++++ tests/test_json_model.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index dafd41a..9f2c3a7 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1483,6 +1483,22 @@ class JsonModel(RedisModel, abc.ABC): await db.execute_command("JSON.SET", self.key(), ".", self.json()) return self + @classmethod + async def all_pks(cls): # type: ignore + key_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk="")) + # TODO: We assume the key ends with the default separator, ":" -- when + # we make the separator configurable, we need to update this as well. + # ... And probably lots of other places ... + # + # TODO: Also, we need to decide how we want to handle the lack of + # decode_responses=True... + return ( + key.split(":")[-1] + if isinstance(key, str) + else key.decode(cls.Meta.encoding).split(":")[-1] + async for key in cls.db().scan_iter(f"{key_prefix}*", _type="ReJSON-RL") + ) + async def update(self, **field_values): validate_model_fields(self.__class__, field_values) for field, value in field_values.items(): diff --git a/tests/test_json_model.py b/tests/test_json_model.py index d920a40..13c2834 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -186,6 +186,35 @@ async def test_saves_model_and_creates_pk(address, m, redis): assert member2 == member assert member2.address == address +@pytest.mark.asyncio +async def test_all_pks(address, m, redis): + member = m.Member( + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today, + age=38, + address=address, + ) + + await member.save() + + member1 = m.Member( + first_name="Simon", + last_name="Prickett", + email="s@example.com", + join_date=today, + age=99, + address=address, + ) + + await member1.save() + + pk_list = [] + async for pk in await m.Member.all_pks(): + pk_list.append(pk) + + assert len(pk_list) == 2 @pytest.mark.asyncio async def test_delete(address, m, redis):