Use TypeVars for return types of RedisModel and its subtype's methods (#476)
Co-authored-by: Chayim <chayim@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									89b6c84c4a
								
							
						
					
					
						commit
						c68adacea2
					
				
					 1 changed files with 13 additions and 12 deletions
				
			
		| 
						 | 
				
			
			@ -47,6 +47,7 @@ from .token_escaper import TokenEscaper
 | 
			
		|||
 | 
			
		||||
model_registry = {}
 | 
			
		||||
_T = TypeVar("_T")
 | 
			
		||||
Model = TypeVar("Model", bound="RedisModel")
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
escaper = TokenEscaper()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1310,7 +1311,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
        return await cls._delete(db, cls.make_primary_key(pk))
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def get(cls, pk: Any) -> "RedisModel":
 | 
			
		||||
    async def get(cls: Type["Model"], pk: Any) -> "Model":
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    async def update(self, **field_values):
 | 
			
		||||
| 
						 | 
				
			
			@ -1318,8 +1319,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    async def save(
 | 
			
		||||
        self, pipeline: Optional[redis.client.Pipeline] = None
 | 
			
		||||
    ) -> "RedisModel":
 | 
			
		||||
        self: "Model", pipeline: Optional[redis.client.Pipeline] = None
 | 
			
		||||
    ) -> "Model":
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    async def expire(
 | 
			
		||||
| 
						 | 
				
			
			@ -1423,11 +1424,11 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def add(
 | 
			
		||||
        cls,
 | 
			
		||||
        models: Sequence["RedisModel"],
 | 
			
		||||
        cls: Type["Model"],
 | 
			
		||||
        models: Sequence["Model"],
 | 
			
		||||
        pipeline: Optional[redis.client.Pipeline] = None,
 | 
			
		||||
        pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
 | 
			
		||||
    ) -> Sequence["RedisModel"]:
 | 
			
		||||
    ) -> Sequence["Model"]:
 | 
			
		||||
        db = cls._get_db(pipeline, bulk=True)
 | 
			
		||||
 | 
			
		||||
        for model in models:
 | 
			
		||||
| 
						 | 
				
			
			@ -1502,8 +1503,8 @@ class HashModel(RedisModel, abc.ABC):
 | 
			
		|||
                )
 | 
			
		||||
 | 
			
		||||
    async def save(
 | 
			
		||||
        self, pipeline: Optional[redis.client.Pipeline] = None
 | 
			
		||||
    ) -> "HashModel":
 | 
			
		||||
        self: "Model", pipeline: Optional[redis.client.Pipeline] = None
 | 
			
		||||
    ) -> "Model":
 | 
			
		||||
        self.check()
 | 
			
		||||
        db = self._get_db(pipeline)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1525,7 +1526,7 @@ class HashModel(RedisModel, abc.ABC):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def get(cls, pk: Any) -> "HashModel":
 | 
			
		||||
    async def get(cls: Type["Model"], pk: Any) -> "Model":
 | 
			
		||||
        document = await cls.db().hgetall(cls.make_primary_key(pk))
 | 
			
		||||
        if not document:
 | 
			
		||||
            raise NotFoundError
 | 
			
		||||
| 
						 | 
				
			
			@ -1676,8 +1677,8 @@ class JsonModel(RedisModel, abc.ABC):
 | 
			
		|||
        super().__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    async def save(
 | 
			
		||||
        self, pipeline: Optional[redis.client.Pipeline] = None
 | 
			
		||||
    ) -> "JsonModel":
 | 
			
		||||
        self: "Model", pipeline: Optional[redis.client.Pipeline] = None
 | 
			
		||||
    ) -> "Model":
 | 
			
		||||
        self.check()
 | 
			
		||||
        db = self._get_db(pipeline)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1722,7 +1723,7 @@ class JsonModel(RedisModel, abc.ABC):
 | 
			
		|||
        await self.save()
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def get(cls, pk: Any) -> "JsonModel":
 | 
			
		||||
    async def get(cls: Type["Model"], pk: Any) -> "Model":
 | 
			
		||||
        document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
 | 
			
		||||
        if document == "null":
 | 
			
		||||
            raise NotFoundError
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue