Allow users to define a new primary key. (#347)
* make primary key programmable Signed-off-by: wiseaidev <business@wiseai.dev> * get primary key field using the `key` method Signed-off-by: wiseaidev <business@wiseai.dev> * adjust delete_many & expire methods Signed-off-by: wiseaidev <business@wiseai.dev> * fix query for int primary key Signed-off-by: wiseaidev <business@wiseai.dev> * fix grammar Signed-off-by: wiseaidev <business@wiseai.dev> * add unit tests Signed-off-by: wiseaidev <business@wiseai.dev> Signed-off-by: wiseaidev <business@wiseai.dev> Co-authored-by: Chayim <chayim@users.noreply.github.com> Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									2e09234b68
								
							
						
					
					
						commit
						551429c01a
					
				
					 2 changed files with 74 additions and 13 deletions
				
			
		| 
						 | 
					@ -564,7 +564,10 @@ class FindQuery:
 | 
				
			||||||
                        separator_char,
 | 
					                        separator_char,
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                    return ""
 | 
					                    return ""
 | 
				
			||||||
                if separator_char in value:
 | 
					                if isinstance(value, int):
 | 
				
			||||||
 | 
					                    # This if will hit only if the field is a primary key of type int
 | 
				
			||||||
 | 
					                    result = f"@{field_name}:[{value} {value}]"
 | 
				
			||||||
 | 
					                elif separator_char in value:
 | 
				
			||||||
                    # The value contains the TAG field separator. We can work
 | 
					                    # The value contains the TAG field separator. We can work
 | 
				
			||||||
                    # around this by breaking apart the values and unioning them
 | 
					                    # around this by breaking apart the values and unioning them
 | 
				
			||||||
                    # with multiple field:{} queries.
 | 
					                    # with multiple field:{} queries.
 | 
				
			||||||
| 
						 | 
					@ -1106,12 +1109,12 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
        extra = "allow"
 | 
					        extra = "allow"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(__pydantic_self__, **data: Any) -> None:
 | 
					    def __init__(__pydantic_self__, **data: Any) -> None:
 | 
				
			||||||
        super().__init__(**data)
 | 
					 | 
				
			||||||
        __pydantic_self__.validate_primary_key()
 | 
					        __pydantic_self__.validate_primary_key()
 | 
				
			||||||
 | 
					        super().__init__(**data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __lt__(self, other):
 | 
					    def __lt__(self, other):
 | 
				
			||||||
        """Default sort: compare primary key of models."""
 | 
					        """Default sort: compare primary key of models."""
 | 
				
			||||||
        return self.pk < other.pk
 | 
					        return self.key() < other.key()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def key(self):
 | 
					    def key(self):
 | 
				
			||||||
        """Return the Redis key for this model."""
 | 
					        """Return the Redis key for this model."""
 | 
				
			||||||
| 
						 | 
					@ -1150,7 +1153,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
        db = self._get_db(pipeline)
 | 
					        db = self._get_db(pipeline)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # TODO: Wrap any Redis response errors in a custom exception?
 | 
					        # TODO: Wrap any Redis response errors in a custom exception?
 | 
				
			||||||
        await db.expire(self.make_primary_key(self.pk), num_seconds)
 | 
					        await db.expire(self.key(), num_seconds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @validator("pk", always=True, allow_reuse=True)
 | 
					    @validator("pk", always=True, allow_reuse=True)
 | 
				
			||||||
    def validate_pk(cls, v):
 | 
					    def validate_pk(cls, v):
 | 
				
			||||||
| 
						 | 
					@ -1167,7 +1170,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
                primary_keys += 1
 | 
					                primary_keys += 1
 | 
				
			||||||
        if primary_keys == 0:
 | 
					        if primary_keys == 0:
 | 
				
			||||||
            raise RedisModelError("You must define a primary key for the model")
 | 
					            raise RedisModelError("You must define a primary key for the model")
 | 
				
			||||||
        elif primary_keys > 1:
 | 
					        elif primary_keys == 2:
 | 
				
			||||||
 | 
					            cls.__fields__.pop('pk')
 | 
				
			||||||
 | 
					        elif primary_keys > 2:
 | 
				
			||||||
            raise RedisModelError("You must define only one primary key for a model")
 | 
					            raise RedisModelError("You must define only one primary key for a model")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
| 
						 | 
					@ -1275,7 +1280,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
        db = cls._get_db(pipeline)
 | 
					        db = cls._get_db(pipeline)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for chunk in ichunked(models, 100):
 | 
					        for chunk in ichunked(models, 100):
 | 
				
			||||||
            pks = [cls.make_primary_key(model.pk) for model in chunk]
 | 
					            pks = [model.key() for model in chunk]
 | 
				
			||||||
            await cls._delete(db, *pks)
 | 
					            await cls._delete(db, *pks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return len(models)
 | 
					        return len(models)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -46,7 +46,7 @@ async def m(key_prefix, redis):
 | 
				
			||||||
        created_on: datetime.datetime
 | 
					        created_on: datetime.datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Member(BaseHashModel):
 | 
					    class Member(BaseHashModel):
 | 
				
			||||||
        id: int = Field(index=True)
 | 
					        id: int = Field(index=True, primary_key=True)
 | 
				
			||||||
        first_name: str = Field(index=True)
 | 
					        first_name: str = Field(index=True)
 | 
				
			||||||
        last_name: str = Field(index=True)
 | 
					        last_name: str = Field(index=True)
 | 
				
			||||||
        email: str = Field(index=True)
 | 
					        email: str = Field(index=True)
 | 
				
			||||||
| 
						 | 
					@ -445,7 +445,7 @@ async def test_saves_model_and_creates_pk(m):
 | 
				
			||||||
    # Save a model instance to Redis
 | 
					    # Save a model instance to Redis
 | 
				
			||||||
    await member.save()
 | 
					    await member.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    member2 = await m.Member.get(member.pk)
 | 
					    member2 = await m.Member.get(pk=member.id)
 | 
				
			||||||
    assert member2 == member
 | 
					    assert member2 == member
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -495,7 +495,7 @@ async def test_delete(m):
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    await member.save()
 | 
					    await member.save()
 | 
				
			||||||
    response = await m.Member.delete(member.pk)
 | 
					    response = await m.Member.delete(pk=member.id)
 | 
				
			||||||
    assert response == 1
 | 
					    assert response == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -588,8 +588,8 @@ async def test_saves_many(m):
 | 
				
			||||||
    result = await m.Member.add(members)
 | 
					    result = await m.Member.add(members)
 | 
				
			||||||
    assert result == [member1, member2]
 | 
					    assert result == [member1, member2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert await m.Member.get(pk=member1.pk) == member1
 | 
					    assert await m.Member.get(pk=member1.id) == member1
 | 
				
			||||||
    assert await m.Member.get(pk=member2.pk) == member2
 | 
					    assert await m.Member.get(pk=member2.id) == member2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@py_test_mark_asyncio
 | 
					@py_test_mark_asyncio
 | 
				
			||||||
| 
						 | 
					@ -618,14 +618,14 @@ async def test_delete_many(m):
 | 
				
			||||||
    result = await m.Member.delete_many(members)
 | 
					    result = await m.Member.delete_many(members)
 | 
				
			||||||
    assert result == 2
 | 
					    assert result == 2
 | 
				
			||||||
    with pytest.raises(NotFoundError):
 | 
					    with pytest.raises(NotFoundError):
 | 
				
			||||||
        await m.Member.get(pk=member1.pk)
 | 
					        await m.Member.get(pk=member1.key())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@py_test_mark_asyncio
 | 
					@py_test_mark_asyncio
 | 
				
			||||||
async def test_updates_a_model(members, m):
 | 
					async def test_updates_a_model(members, m):
 | 
				
			||||||
    member1, member2, member3 = members
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
    await member1.update(last_name="Smith")
 | 
					    await member1.update(last_name="Smith")
 | 
				
			||||||
    member = await m.Member.get(member1.pk)
 | 
					    member = await m.Member.get(member1.id)
 | 
				
			||||||
    assert member.last_name == "Smith"
 | 
					    assert member.last_name == "Smith"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -681,3 +681,59 @@ def test_schema(m):
 | 
				
			||||||
        Address.redisearch_schema()
 | 
					        Address.redisearch_schema()
 | 
				
			||||||
        == f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC"
 | 
					        == f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@py_test_mark_asyncio
 | 
				
			||||||
 | 
					async def test_primary_key_model_error(m):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class Customer(m.BaseHashModel):
 | 
				
			||||||
 | 
					        id: int = Field(primary_key=True, index=True)
 | 
				
			||||||
 | 
					        first_name: str = Field(primary_key=True, index=True)
 | 
				
			||||||
 | 
					        last_name: str
 | 
				
			||||||
 | 
					        bio: Optional[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await Migrator().run()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(RedisModelError, match="You must define only one primary key for a model"):
 | 
				
			||||||
 | 
					        _ = Customer(
 | 
				
			||||||
 | 
					            id=0,
 | 
				
			||||||
 | 
					            first_name="Mahmoud",
 | 
				
			||||||
 | 
					            last_name="Harmouch",
 | 
				
			||||||
 | 
					            bio="Python developer, wanna work at Redis, Inc."
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@py_test_mark_asyncio
 | 
				
			||||||
 | 
					async def test_primary_pk_exists(m):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class Customer1(m.BaseHashModel):
 | 
				
			||||||
 | 
					        id: int
 | 
				
			||||||
 | 
					        first_name: str
 | 
				
			||||||
 | 
					        last_name: str
 | 
				
			||||||
 | 
					        bio: Optional[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class Customer2(m.BaseHashModel):
 | 
				
			||||||
 | 
					        id: int = Field(primary_key=True, index=True)
 | 
				
			||||||
 | 
					        first_name: str
 | 
				
			||||||
 | 
					        last_name: str
 | 
				
			||||||
 | 
					        bio: Optional[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await Migrator().run()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    customer = Customer1(
 | 
				
			||||||
 | 
					        id=0,
 | 
				
			||||||
 | 
					        first_name="Mahmoud",
 | 
				
			||||||
 | 
					        last_name="Harmouch",
 | 
				
			||||||
 | 
					        bio="Python developer, wanna work at Redis, Inc."
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert 'pk' in customer.__fields__
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    customer = Customer2(
 | 
				
			||||||
 | 
					        id=1,
 | 
				
			||||||
 | 
					        first_name="Kim",
 | 
				
			||||||
 | 
					        last_name="Brookins",
 | 
				
			||||||
 | 
					        bio="This is member 2 who can be quite anxious until you get to know them.",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert 'pk' not in customer.__fields__
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue