Use TypeVars for return types of RedisModel and its subtype's methods (#476)
Co-authored-by: Chayim <chayim@users.noreply.github.com>
This commit is contained in:
parent
89b6c84c4a
commit
c68adacea2
1 changed files with 13 additions and 12 deletions
|
@ -47,6 +47,7 @@ from .token_escaper import TokenEscaper
|
|||
|
||||
model_registry = {}
|
||||
_T = TypeVar("_T")
|
||||
Model = TypeVar("Model", bound="RedisModel")
|
||||
log = logging.getLogger(__name__)
|
||||
escaper = TokenEscaper()
|
||||
|
||||
|
@ -1310,7 +1311,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
return await cls._delete(db, cls.make_primary_key(pk))
|
||||
|
||||
@classmethod
|
||||
async def get(cls, pk: Any) -> "RedisModel":
|
||||
async def get(cls: Type["Model"], pk: Any) -> "Model":
|
||||
raise NotImplementedError
|
||||
|
||||
async def update(self, **field_values):
|
||||
|
@ -1318,8 +1319,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
async def save(
|
||||
self, pipeline: Optional[redis.client.Pipeline] = None
|
||||
) -> "RedisModel":
|
||||
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
|
||||
) -> "Model":
|
||||
raise NotImplementedError
|
||||
|
||||
async def expire(
|
||||
|
@ -1423,11 +1424,11 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
|
||||
@classmethod
|
||||
async def add(
|
||||
cls,
|
||||
models: Sequence["RedisModel"],
|
||||
cls: Type["Model"],
|
||||
models: Sequence["Model"],
|
||||
pipeline: Optional[redis.client.Pipeline] = None,
|
||||
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
|
||||
) -> Sequence["RedisModel"]:
|
||||
) -> Sequence["Model"]:
|
||||
db = cls._get_db(pipeline, bulk=True)
|
||||
|
||||
for model in models:
|
||||
|
@ -1502,8 +1503,8 @@ class HashModel(RedisModel, abc.ABC):
|
|||
)
|
||||
|
||||
async def save(
|
||||
self, pipeline: Optional[redis.client.Pipeline] = None
|
||||
) -> "HashModel":
|
||||
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
|
||||
) -> "Model":
|
||||
self.check()
|
||||
db = self._get_db(pipeline)
|
||||
|
||||
|
@ -1525,7 +1526,7 @@ class HashModel(RedisModel, abc.ABC):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
async def get(cls, pk: Any) -> "HashModel":
|
||||
async def get(cls: Type["Model"], pk: Any) -> "Model":
|
||||
document = await cls.db().hgetall(cls.make_primary_key(pk))
|
||||
if not document:
|
||||
raise NotFoundError
|
||||
|
@ -1676,8 +1677,8 @@ class JsonModel(RedisModel, abc.ABC):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def save(
|
||||
self, pipeline: Optional[redis.client.Pipeline] = None
|
||||
) -> "JsonModel":
|
||||
self: "Model", pipeline: Optional[redis.client.Pipeline] = None
|
||||
) -> "Model":
|
||||
self.check()
|
||||
db = self._get_db(pipeline)
|
||||
|
||||
|
@ -1722,7 +1723,7 @@ class JsonModel(RedisModel, abc.ABC):
|
|||
await self.save()
|
||||
|
||||
@classmethod
|
||||
async def get(cls, pk: Any) -> "JsonModel":
|
||||
async def get(cls: Type["Model"], pk: Any) -> "Model":
|
||||
document = json.dumps(await cls.db().json().get(cls.make_key(pk)))
|
||||
if document == "null":
|
||||
raise NotFoundError
|
||||
|
|
Loading…
Reference in a new issue