Use TypeVars for return types of RedisModel and its subtype's methods (#476)

Co-authored-by: Chayim <chayim@users.noreply.github.com>
This commit is contained in:
Marián Hlaváč 2023-07-12 10:08:20 +02:00 committed by GitHub
parent 89b6c84c4a
commit c68adacea2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -47,6 +47,7 @@ from .token_escaper import TokenEscaper
model_registry = {}
_T = TypeVar("_T")
Model = TypeVar("Model", bound="RedisModel")
log = logging.getLogger(__name__)
escaper = TokenEscaper()
@ -1310,7 +1311,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
return await cls._delete(db, cls.make_primary_key(pk))
@classmethod
async def get(cls, pk: Any) -> "RedisModel":
async def get(cls: Type["Model"], pk: Any) -> "Model":
raise NotImplementedError
async def update(self, **field_values):
@ -1318,8 +1319,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
raise NotImplementedError
async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "RedisModel":
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "Model":
raise NotImplementedError
async def expire(
@ -1423,11 +1424,11 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
@classmethod
async def add(
cls,
models: Sequence["RedisModel"],
cls: Type["Model"],
models: Sequence["Model"],
pipeline: Optional[redis.client.Pipeline] = None,
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
) -> Sequence["RedisModel"]:
) -> Sequence["Model"]:
db = cls._get_db(pipeline, bulk=True)
for model in models:
@ -1502,8 +1503,8 @@ class HashModel(RedisModel, abc.ABC):
)
async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "HashModel":
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "Model":
self.check()
db = self._get_db(pipeline)
@ -1525,7 +1526,7 @@ class HashModel(RedisModel, abc.ABC):
)
@classmethod
async def get(cls, pk: Any) -> "HashModel":
async def get(cls: Type["Model"], pk: Any) -> "Model":
document = await cls.db().hgetall(cls.make_primary_key(pk))
if not document:
raise NotFoundError
@ -1676,8 +1677,8 @@ class JsonModel(RedisModel, abc.ABC):
super().__init__(*args, **kwargs)
async def save(
self, pipeline: Optional[redis.client.Pipeline] = None
) -> "JsonModel":
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
) -> "Model":
self.check()
db = self._get_db(pipeline)
@ -1722,7 +1723,7 @@ class JsonModel(RedisModel, abc.ABC):
await self.save()
@classmethod
async def get(cls, pk: Any) -> "JsonModel":
async def get(cls: Type["Model"], pk: Any) -> "Model":
document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
if document == "null":
raise NotFoundError