Use TypeVars for return types of RedisModel and its subtype's methods (#476)
Co-authored-by: Chayim <chayim@users.noreply.github.com>
This commit is contained in:
parent
89b6c84c4a
commit
c68adacea2
1 changed files with 13 additions and 12 deletions
|
@ -47,6 +47,7 @@ from .token_escaper import TokenEscaper
|
||||||
|
|
||||||
model_registry = {}
|
model_registry = {}
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
Model = TypeVar("Model", bound="RedisModel")
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
escaper = TokenEscaper()
|
escaper = TokenEscaper()
|
||||||
|
|
||||||
|
@ -1310,7 +1311,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
return await cls._delete(db, cls.make_primary_key(pk))
|
return await cls._delete(db, cls.make_primary_key(pk))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(cls, pk: Any) -> "RedisModel":
|
async def get(cls: Type["Model"], pk: Any) -> "Model":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def update(self, **field_values):
|
async def update(self, **field_values):
|
||||||
|
@ -1318,8 +1319,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def save(
|
async def save(
|
||||||
self, pipeline: Optional[redis.client.Pipeline] = None
|
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
|
||||||
) -> "RedisModel":
|
) -> "Model":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def expire(
|
async def expire(
|
||||||
|
@ -1423,11 +1424,11 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def add(
|
async def add(
|
||||||
cls,
|
cls: Type["Model"],
|
||||||
models: Sequence["RedisModel"],
|
models: Sequence["Model"],
|
||||||
pipeline: Optional[redis.client.Pipeline] = None,
|
pipeline: Optional[redis.client.Pipeline] = None,
|
||||||
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
|
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
|
||||||
) -> Sequence["RedisModel"]:
|
) -> Sequence["Model"]:
|
||||||
db = cls._get_db(pipeline, bulk=True)
|
db = cls._get_db(pipeline, bulk=True)
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
|
@ -1502,8 +1503,8 @@ class HashModel(RedisModel, abc.ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def save(
|
async def save(
|
||||||
self, pipeline: Optional[redis.client.Pipeline] = None
|
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
|
||||||
) -> "HashModel":
|
) -> "Model":
|
||||||
self.check()
|
self.check()
|
||||||
db = self._get_db(pipeline)
|
db = self._get_db(pipeline)
|
||||||
|
|
||||||
|
@ -1525,7 +1526,7 @@ class HashModel(RedisModel, abc.ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(cls, pk: Any) -> "HashModel":
|
async def get(cls: Type["Model"], pk: Any) -> "Model":
|
||||||
document = await cls.db().hgetall(cls.make_primary_key(pk))
|
document = await cls.db().hgetall(cls.make_primary_key(pk))
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFoundError
|
raise NotFoundError
|
||||||
|
@ -1676,8 +1677,8 @@ class JsonModel(RedisModel, abc.ABC):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
async def save(
|
async def save(
|
||||||
self, pipeline: Optional[redis.client.Pipeline] = None
|
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
|
||||||
) -> "JsonModel":
|
) -> "Model":
|
||||||
self.check()
|
self.check()
|
||||||
db = self._get_db(pipeline)
|
db = self._get_db(pipeline)
|
||||||
|
|
||||||
|
@ -1722,7 +1723,7 @@ class JsonModel(RedisModel, abc.ABC):
|
||||||
await self.save()
|
await self.save()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(cls, pk: Any) -> "JsonModel":
|
async def get(cls: Type["Model"], pk: Any) -> "Model":
|
||||||
document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
|
document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
|
||||||
if document == "null":
|
if document == "null":
|
||||||
raise NotFoundError
|
raise NotFoundError
|
||||||
|
|
Loading…
Reference in a new issue