Allow users to define a new primary key. (#347)
* make primary key programmable Signed-off-by: wiseaidev <business@wiseai.dev> * get primary key field using the `key` method Signed-off-by: wiseaidev <business@wiseai.dev> * adjust delete_many & expire methods Signed-off-by: wiseaidev <business@wiseai.dev> * fix query for int primary key Signed-off-by: wiseaidev <business@wiseai.dev> * fix grammar Signed-off-by: wiseaidev <business@wiseai.dev> * add unit tests Signed-off-by: wiseaidev <business@wiseai.dev> Signed-off-by: wiseaidev <business@wiseai.dev> Co-authored-by: Chayim <chayim@users.noreply.github.com> Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
This commit is contained in:
parent
2e09234b68
commit
551429c01a
2 changed files with 74 additions and 13 deletions
|
@ -564,7 +564,10 @@ class FindQuery:
|
||||||
separator_char,
|
separator_char,
|
||||||
)
|
)
|
||||||
return ""
|
return ""
|
||||||
if separator_char in value:
|
if isinstance(value, int):
|
||||||
|
# This if will hit only if the field is a primary key of type int
|
||||||
|
result = f"@{field_name}:[{value} {value}]"
|
||||||
|
elif separator_char in value:
|
||||||
# The value contains the TAG field separator. We can work
|
# The value contains the TAG field separator. We can work
|
||||||
# around this by breaking apart the values and unioning them
|
# around this by breaking apart the values and unioning them
|
||||||
# with multiple field:{} queries.
|
# with multiple field:{} queries.
|
||||||
|
@ -1106,12 +1109,12 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
|
||||||
def __init__(__pydantic_self__, **data: Any) -> None:
|
def __init__(__pydantic_self__, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
|
||||||
__pydantic_self__.validate_primary_key()
|
__pydantic_self__.validate_primary_key()
|
||||||
|
super().__init__(**data)
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
"""Default sort: compare primary key of models."""
|
"""Default sort: compare primary key of models."""
|
||||||
return self.pk < other.pk
|
return self.key() < other.key()
|
||||||
|
|
||||||
def key(self):
|
def key(self):
|
||||||
"""Return the Redis key for this model."""
|
"""Return the Redis key for this model."""
|
||||||
|
@ -1150,7 +1153,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
db = self._get_db(pipeline)
|
db = self._get_db(pipeline)
|
||||||
|
|
||||||
# TODO: Wrap any Redis response errors in a custom exception?
|
# TODO: Wrap any Redis response errors in a custom exception?
|
||||||
await db.expire(self.make_primary_key(self.pk), num_seconds)
|
await db.expire(self.key(), num_seconds)
|
||||||
|
|
||||||
@validator("pk", always=True, allow_reuse=True)
|
@validator("pk", always=True, allow_reuse=True)
|
||||||
def validate_pk(cls, v):
|
def validate_pk(cls, v):
|
||||||
|
@ -1167,7 +1170,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
primary_keys += 1
|
primary_keys += 1
|
||||||
if primary_keys == 0:
|
if primary_keys == 0:
|
||||||
raise RedisModelError("You must define a primary key for the model")
|
raise RedisModelError("You must define a primary key for the model")
|
||||||
elif primary_keys > 1:
|
elif primary_keys == 2:
|
||||||
|
cls.__fields__.pop('pk')
|
||||||
|
elif primary_keys > 2:
|
||||||
raise RedisModelError("You must define only one primary key for a model")
|
raise RedisModelError("You must define only one primary key for a model")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1275,7 +1280,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
db = cls._get_db(pipeline)
|
db = cls._get_db(pipeline)
|
||||||
|
|
||||||
for chunk in ichunked(models, 100):
|
for chunk in ichunked(models, 100):
|
||||||
pks = [cls.make_primary_key(model.pk) for model in chunk]
|
pks = [model.key() for model in chunk]
|
||||||
await cls._delete(db, *pks)
|
await cls._delete(db, *pks)
|
||||||
|
|
||||||
return len(models)
|
return len(models)
|
||||||
|
|
|
@ -46,7 +46,7 @@ async def m(key_prefix, redis):
|
||||||
created_on: datetime.datetime
|
created_on: datetime.datetime
|
||||||
|
|
||||||
class Member(BaseHashModel):
|
class Member(BaseHashModel):
|
||||||
id: int = Field(index=True)
|
id: int = Field(index=True, primary_key=True)
|
||||||
first_name: str = Field(index=True)
|
first_name: str = Field(index=True)
|
||||||
last_name: str = Field(index=True)
|
last_name: str = Field(index=True)
|
||||||
email: str = Field(index=True)
|
email: str = Field(index=True)
|
||||||
|
@ -445,7 +445,7 @@ async def test_saves_model_and_creates_pk(m):
|
||||||
# Save a model instance to Redis
|
# Save a model instance to Redis
|
||||||
await member.save()
|
await member.save()
|
||||||
|
|
||||||
member2 = await m.Member.get(member.pk)
|
member2 = await m.Member.get(pk=member.id)
|
||||||
assert member2 == member
|
assert member2 == member
|
||||||
|
|
||||||
|
|
||||||
|
@ -495,7 +495,7 @@ async def test_delete(m):
|
||||||
)
|
)
|
||||||
|
|
||||||
await member.save()
|
await member.save()
|
||||||
response = await m.Member.delete(member.pk)
|
response = await m.Member.delete(pk=member.id)
|
||||||
assert response == 1
|
assert response == 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -588,8 +588,8 @@ async def test_saves_many(m):
|
||||||
result = await m.Member.add(members)
|
result = await m.Member.add(members)
|
||||||
assert result == [member1, member2]
|
assert result == [member1, member2]
|
||||||
|
|
||||||
assert await m.Member.get(pk=member1.pk) == member1
|
assert await m.Member.get(pk=member1.id) == member1
|
||||||
assert await m.Member.get(pk=member2.pk) == member2
|
assert await m.Member.get(pk=member2.id) == member2
|
||||||
|
|
||||||
|
|
||||||
@py_test_mark_asyncio
|
@py_test_mark_asyncio
|
||||||
|
@ -618,14 +618,14 @@ async def test_delete_many(m):
|
||||||
result = await m.Member.delete_many(members)
|
result = await m.Member.delete_many(members)
|
||||||
assert result == 2
|
assert result == 2
|
||||||
with pytest.raises(NotFoundError):
|
with pytest.raises(NotFoundError):
|
||||||
await m.Member.get(pk=member1.pk)
|
await m.Member.get(pk=member1.key())
|
||||||
|
|
||||||
|
|
||||||
@py_test_mark_asyncio
|
@py_test_mark_asyncio
|
||||||
async def test_updates_a_model(members, m):
|
async def test_updates_a_model(members, m):
|
||||||
member1, member2, member3 = members
|
member1, member2, member3 = members
|
||||||
await member1.update(last_name="Smith")
|
await member1.update(last_name="Smith")
|
||||||
member = await m.Member.get(member1.pk)
|
member = await m.Member.get(member1.id)
|
||||||
assert member.last_name == "Smith"
|
assert member.last_name == "Smith"
|
||||||
|
|
||||||
|
|
||||||
|
@ -681,3 +681,59 @@ def test_schema(m):
|
||||||
Address.redisearch_schema()
|
Address.redisearch_schema()
|
||||||
== f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC"
|
== f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@py_test_mark_asyncio
|
||||||
|
async def test_primary_key_model_error(m):
|
||||||
|
|
||||||
|
class Customer(m.BaseHashModel):
|
||||||
|
id: int = Field(primary_key=True, index=True)
|
||||||
|
first_name: str = Field(primary_key=True, index=True)
|
||||||
|
last_name: str
|
||||||
|
bio: Optional[str]
|
||||||
|
|
||||||
|
await Migrator().run()
|
||||||
|
|
||||||
|
with pytest.raises(RedisModelError, match="You must define only one primary key for a model"):
|
||||||
|
_ = Customer(
|
||||||
|
id=0,
|
||||||
|
first_name="Mahmoud",
|
||||||
|
last_name="Harmouch",
|
||||||
|
bio="Python developer, wanna work at Redis, Inc."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@py_test_mark_asyncio
|
||||||
|
async def test_primary_pk_exists(m):
|
||||||
|
|
||||||
|
class Customer1(m.BaseHashModel):
|
||||||
|
id: int
|
||||||
|
first_name: str
|
||||||
|
last_name: str
|
||||||
|
bio: Optional[str]
|
||||||
|
|
||||||
|
class Customer2(m.BaseHashModel):
|
||||||
|
id: int = Field(primary_key=True, index=True)
|
||||||
|
first_name: str
|
||||||
|
last_name: str
|
||||||
|
bio: Optional[str]
|
||||||
|
|
||||||
|
await Migrator().run()
|
||||||
|
|
||||||
|
customer = Customer1(
|
||||||
|
id=0,
|
||||||
|
first_name="Mahmoud",
|
||||||
|
last_name="Harmouch",
|
||||||
|
bio="Python developer, wanna work at Redis, Inc."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert 'pk' in customer.__fields__
|
||||||
|
|
||||||
|
customer = Customer2(
|
||||||
|
id=1,
|
||||||
|
first_name="Kim",
|
||||||
|
last_name="Brookins",
|
||||||
|
bio="This is member 2 who can be quite anxious until you get to know them.",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert 'pk' not in customer.__fields__
|
||||||
|
|
Loading…
Reference in a new issue