diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 1887476..8271d2e 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -564,7 +564,10 @@ class FindQuery: separator_char, ) return "" - if separator_char in value: + if isinstance(value, int): + # This if will hit only if the field is a primary key of type int + result = f"@{field_name}:[{value} {value}]" + elif separator_char in value: # The value contains the TAG field separator. We can work # around this by breaking apart the values and unioning them # with multiple field:{} queries. @@ -1106,12 +1109,12 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): extra = "allow" def __init__(__pydantic_self__, **data: Any) -> None: - super().__init__(**data) __pydantic_self__.validate_primary_key() + super().__init__(**data) def __lt__(self, other): """Default sort: compare primary key of models.""" - return self.pk < other.pk + return self.key() < other.key() def key(self): """Return the Redis key for this model.""" @@ -1150,7 +1153,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): db = self._get_db(pipeline) # TODO: Wrap any Redis response errors in a custom exception? - await db.expire(self.make_primary_key(self.pk), num_seconds) + await db.expire(self.key(), num_seconds) @validator("pk", always=True, allow_reuse=True) def validate_pk(cls, v): @@ -1167,7 +1170,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): primary_keys += 1 if primary_keys == 0: raise RedisModelError("You must define a primary key for the model") - elif primary_keys > 1: + elif primary_keys == 2: + cls.__fields__.pop('pk') + elif primary_keys > 2: raise RedisModelError("You must define only one primary key for a model") @classmethod @@ -1275,7 +1280,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): db = cls._get_db(pipeline) for chunk in ichunked(models, 100): - pks = [cls.make_primary_key(model.pk) for model in chunk] + pks = [model.key() for model in chunk] await cls._delete(db, *pks) return len(models) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 84a0508..b2bd30a 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -46,7 +46,7 @@ async def m(key_prefix, redis): created_on: datetime.datetime class Member(BaseHashModel): - id: int = Field(index=True) + id: int = Field(index=True, primary_key=True) first_name: str = Field(index=True) last_name: str = Field(index=True) email: str = Field(index=True) @@ -445,7 +445,7 @@ async def test_saves_model_and_creates_pk(m): # Save a model instance to Redis await member.save() - member2 = await m.Member.get(member.pk) + member2 = await m.Member.get(pk=member.id) assert member2 == member @@ -495,7 +495,7 @@ async def test_delete(m): ) await member.save() - response = await m.Member.delete(member.pk) + response = await m.Member.delete(pk=member.id) assert response == 1 @@ -588,8 +588,8 @@ async def test_saves_many(m): result = await m.Member.add(members) assert result == [member1, member2] - assert await m.Member.get(pk=member1.pk) == member1 - assert await m.Member.get(pk=member2.pk) == member2 + assert await m.Member.get(pk=member1.id) == member1 + assert await m.Member.get(pk=member2.id) == member2 @py_test_mark_asyncio @@ -618,14 +618,14 @@ async def test_delete_many(m): result = await m.Member.delete_many(members) assert result == 2 with pytest.raises(NotFoundError): - await m.Member.get(pk=member1.pk) + await m.Member.get(pk=member1.key()) @py_test_mark_asyncio async def test_updates_a_model(members, m): member1, member2, member3 = members await member1.update(last_name="Smith") - member = await m.Member.get(member1.pk) + member = await m.Member.get(member1.id) assert member.last_name == "Smith" @@ -681,3 +681,59 @@ def test_schema(m): Address.redisearch_schema() == f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC" ) + + +@py_test_mark_asyncio +async def test_primary_key_model_error(m): + + class Customer(m.BaseHashModel): + id: int = Field(primary_key=True, index=True) + first_name: str = Field(primary_key=True, index=True) + last_name: str + bio: Optional[str] + + await Migrator().run() + + with pytest.raises(RedisModelError, match="You must define only one primary key for a model"): + _ = Customer( + id=0, + first_name="Mahmoud", + last_name="Harmouch", + bio="Python developer, wanna work at Redis, Inc." + ) + + +@py_test_mark_asyncio +async def test_primary_pk_exists(m): + + class Customer1(m.BaseHashModel): + id: int + first_name: str + last_name: str + bio: Optional[str] + + class Customer2(m.BaseHashModel): + id: int = Field(primary_key=True, index=True) + first_name: str + last_name: str + bio: Optional[str] + + await Migrator().run() + + customer = Customer1( + id=0, + first_name="Mahmoud", + last_name="Harmouch", + bio="Python developer, wanna work at Redis, Inc." + ) + + assert 'pk' in customer.__fields__ + + customer = Customer2( + id=1, + first_name="Kim", + last_name="Brookins", + bio="This is member 2 who can be quite anxious until you get to know them.", + ) + + assert 'pk' not in customer.__fields__