Use TypeVars for return types of RedisModel and its subtype's methods (#476)
Co-authored-by: Chayim <chayim@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									89b6c84c4a
								
							
						
					
					
						commit
						c68adacea2
					
				
					 1 changed files with 13 additions and 12 deletions
				
			
		| 
						 | 
					@ -47,6 +47,7 @@ from .token_escaper import TokenEscaper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
model_registry = {}
 | 
					model_registry = {}
 | 
				
			||||||
_T = TypeVar("_T")
 | 
					_T = TypeVar("_T")
 | 
				
			||||||
 | 
					Model = TypeVar("Model", bound="RedisModel")
 | 
				
			||||||
log = logging.getLogger(__name__)
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
escaper = TokenEscaper()
 | 
					escaper = TokenEscaper()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1310,7 +1311,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
        return await cls._delete(db, cls.make_primary_key(pk))
 | 
					        return await cls._delete(db, cls.make_primary_key(pk))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def get(cls, pk: Any) -> "RedisModel":
 | 
					    async def get(cls: Type["Model"], pk: Any) -> "Model":
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def update(self, **field_values):
 | 
					    async def update(self, **field_values):
 | 
				
			||||||
| 
						 | 
					@ -1318,8 +1319,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def save(
 | 
					    async def save(
 | 
				
			||||||
        self, pipeline: Optional[redis.client.Pipeline] = None
 | 
					        self: "Model", pipeline: Optional[redis.client.Pipeline] = None
 | 
				
			||||||
    ) -> "RedisModel":
 | 
					    ) -> "Model":
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def expire(
 | 
					    async def expire(
 | 
				
			||||||
| 
						 | 
					@ -1423,11 +1424,11 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def add(
 | 
					    async def add(
 | 
				
			||||||
        cls,
 | 
					        cls: Type["Model"],
 | 
				
			||||||
        models: Sequence["RedisModel"],
 | 
					        models: Sequence["Model"],
 | 
				
			||||||
        pipeline: Optional[redis.client.Pipeline] = None,
 | 
					        pipeline: Optional[redis.client.Pipeline] = None,
 | 
				
			||||||
        pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
 | 
					        pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
 | 
				
			||||||
    ) -> Sequence["RedisModel"]:
 | 
					    ) -> Sequence["Model"]:
 | 
				
			||||||
        db = cls._get_db(pipeline, bulk=True)
 | 
					        db = cls._get_db(pipeline, bulk=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for model in models:
 | 
					        for model in models:
 | 
				
			||||||
| 
						 | 
					@ -1502,8 +1503,8 @@ class HashModel(RedisModel, abc.ABC):
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def save(
 | 
					    async def save(
 | 
				
			||||||
        self, pipeline: Optional[redis.client.Pipeline] = None
 | 
					        self: "Model", pipeline: Optional[redis.client.Pipeline] = None
 | 
				
			||||||
    ) -> "HashModel":
 | 
					    ) -> "Model":
 | 
				
			||||||
        self.check()
 | 
					        self.check()
 | 
				
			||||||
        db = self._get_db(pipeline)
 | 
					        db = self._get_db(pipeline)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1525,7 +1526,7 @@ class HashModel(RedisModel, abc.ABC):
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def get(cls, pk: Any) -> "HashModel":
 | 
					    async def get(cls: Type["Model"], pk: Any) -> "Model":
 | 
				
			||||||
        document = await cls.db().hgetall(cls.make_primary_key(pk))
 | 
					        document = await cls.db().hgetall(cls.make_primary_key(pk))
 | 
				
			||||||
        if not document:
 | 
					        if not document:
 | 
				
			||||||
            raise NotFoundError
 | 
					            raise NotFoundError
 | 
				
			||||||
| 
						 | 
					@ -1676,8 +1677,8 @@ class JsonModel(RedisModel, abc.ABC):
 | 
				
			||||||
        super().__init__(*args, **kwargs)
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def save(
 | 
					    async def save(
 | 
				
			||||||
        self, pipeline: Optional[redis.client.Pipeline] = None
 | 
					        self: "Model", pipeline: Optional[redis.client.Pipeline] = None
 | 
				
			||||||
    ) -> "JsonModel":
 | 
					    ) -> "Model":
 | 
				
			||||||
        self.check()
 | 
					        self.check()
 | 
				
			||||||
        db = self._get_db(pipeline)
 | 
					        db = self._get_db(pipeline)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1722,7 +1723,7 @@ class JsonModel(RedisModel, abc.ABC):
 | 
				
			||||||
        await self.save()
 | 
					        await self.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    async def get(cls, pk: Any) -> "JsonModel":
 | 
					    async def get(cls: Type["Model"], pk: Any) -> "Model":
 | 
				
			||||||
        document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
 | 
					        document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
 | 
				
			||||||
        if document == "null":
 | 
					        if document == "null":
 | 
				
			||||||
            raise NotFoundError
 | 
					            raise NotFoundError
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue