Use TypeVars for return types of RedisModel and its subtype's methods (#476)

Co-authored-by: Chayim <chayim@users.noreply.github.com>
This commit is contained in:
Marián Hlaváč 2023-07-12 10:08:20 +02:00 committed by GitHub
parent 89b6c84c4a
commit c68adacea2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -47,6 +47,7 @@ from .token_escaper import TokenEscaper
model_registry = {} model_registry = {}
_T = TypeVar("_T") _T = TypeVar("_T")
Model = TypeVar("Model", bound="RedisModel")
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
escaper = TokenEscaper() escaper = TokenEscaper()
@ -1310,7 +1311,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
return await cls._delete(db, cls.make_primary_key(pk)) return await cls._delete(db, cls.make_primary_key(pk))
@classmethod @classmethod
async def get(cls, pk: Any) -> "RedisModel": async def get(cls: Type["Model"], pk: Any) -> "Model":
raise NotImplementedError raise NotImplementedError
async def update(self, **field_values): async def update(self, **field_values):
@ -1318,8 +1319,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
raise NotImplementedError raise NotImplementedError
async def save( async def save(
self, pipeline: Optional[redis.client.Pipeline] = None self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "RedisModel": ) -> "Model":
raise NotImplementedError raise NotImplementedError
async def expire( async def expire(
@ -1423,11 +1424,11 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
@classmethod @classmethod
async def add( async def add(
cls, cls: Type["Model"],
models: Sequence["RedisModel"], models: Sequence["Model"],
pipeline: Optional[redis.client.Pipeline] = None, pipeline: Optional[redis.client.Pipeline] = None,
pipeline_verifier: Callable[..., Any] = verify_pipeline_response, pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
) -> Sequence["RedisModel"]: ) -> Sequence["Model"]:
db = cls._get_db(pipeline, bulk=True) db = cls._get_db(pipeline, bulk=True)
for model in models: for model in models:
@ -1502,8 +1503,8 @@ class HashModel(RedisModel, abc.ABC):
) )
async def save( async def save(
self, pipeline: Optional[redis.client.Pipeline] = None self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "HashModel": ) -> "Model":
self.check() self.check()
db = self._get_db(pipeline) db = self._get_db(pipeline)
@ -1525,7 +1526,7 @@ class HashModel(RedisModel, abc.ABC):
) )
@classmethod @classmethod
async def get(cls, pk: Any) -> "HashModel": async def get(cls: Type["Model"], pk: Any) -> "Model":
document = await cls.db().hgetall(cls.make_primary_key(pk)) document = await cls.db().hgetall(cls.make_primary_key(pk))
if not document: if not document:
raise NotFoundError raise NotFoundError
@ -1676,8 +1677,8 @@ class JsonModel(RedisModel, abc.ABC):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
async def save( async def save(
self, pipeline: Optional[redis.client.Pipeline] = None self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "JsonModel": ) -> "Model":
self.check() self.check()
db = self._get_db(pipeline) db = self._get_db(pipeline)
@ -1722,7 +1723,7 @@ class JsonModel(RedisModel, abc.ABC):
await self.save() await self.save()
@classmethod @classmethod
async def get(cls, pk: Any) -> "JsonModel": async def get(cls: Type["Model"], pk: Any) -> "Model":
document = json.dumps(await cls.db().json().get(cls.make_key(pk))) document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
if document == "null": if document == "null":
raise NotFoundError raise NotFoundError