From c68adacea2ce059982deae2a5c10f111e99055eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mari=C3=A1n=20Hlav=C3=A1=C4=8D?= Date: Wed, 12 Jul 2023 10:08:20 +0200 Subject: [PATCH] Use TypeVars for return types of RedisModel and its subtype's methods (#476) Co-authored-by: Chayim --- aredis_om/model/model.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 609ce1d..0ccc5b0 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -47,6 +47,7 @@ from .token_escaper import TokenEscaper model_registry = {} _T = TypeVar("_T") +Model = TypeVar("Model", bound="RedisModel") log = logging.getLogger(__name__) escaper = TokenEscaper() @@ -1310,7 +1311,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): return await cls._delete(db, cls.make_primary_key(pk)) @classmethod - async def get(cls, pk: Any) -> "RedisModel": + async def get(cls: Type["Model"], pk: Any) -> "Model": raise NotImplementedError async def update(self, **field_values): @@ -1318,8 +1319,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): raise NotImplementedError async def save( - self, pipeline: Optional[redis.client.Pipeline] = None - ) -> "RedisModel": + self: "Model", pipeline: Optional[redis.client.Pipeline] = None + ) -> "Model": raise NotImplementedError async def expire( @@ -1423,11 +1424,11 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): @classmethod async def add( - cls, - models: Sequence["RedisModel"], + cls: Type["Model"], + models: Sequence["Model"], pipeline: Optional[redis.client.Pipeline] = None, pipeline_verifier: Callable[..., Any] = verify_pipeline_response, - ) -> Sequence["RedisModel"]: + ) -> Sequence["Model"]: db = cls._get_db(pipeline, bulk=True) for model in models: @@ -1502,8 +1503,8 @@ class HashModel(RedisModel, abc.ABC): ) async def save( - self, pipeline: Optional[redis.client.Pipeline] = None - ) -> "HashModel": + self: "Model", pipeline: Optional[redis.client.Pipeline] = None + ) -> "Model": self.check() db = self._get_db(pipeline) @@ -1525,7 +1526,7 @@ class HashModel(RedisModel, abc.ABC): ) @classmethod - async def get(cls, pk: Any) -> "HashModel": + async def get(cls: Type["Model"], pk: Any) -> "Model": document = await cls.db().hgetall(cls.make_primary_key(pk)) if not document: raise NotFoundError @@ -1676,8 +1677,8 @@ class JsonModel(RedisModel, abc.ABC): super().__init__(*args, **kwargs) async def save( - self, pipeline: Optional[redis.client.Pipeline] = None - ) -> "JsonModel": + self: "Model", pipeline: Optional[redis.client.Pipeline] = None + ) -> "Model": self.check() db = self._get_db(pipeline) @@ -1722,7 +1723,7 @@ class JsonModel(RedisModel, abc.ABC): await self.save() @classmethod - async def get(cls, pk: Any) -> "JsonModel": + async def get(cls: Type["Model"], pk: Any) -> "Model": document = json.dumps(await cls.db().json().get(cls.make_key(pk))) if document == "null": raise NotFoundError