diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index a4c6b9e..3644b63 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -25,6 +25,8 @@ from typing import ( ) from more_itertools import ichunked +from redis import Redis +from redis.asyncio import Redis as RedisAsync from redis.commands.json.path import Path from redis.exceptions import ResponseError from typing_extensions import Protocol, get_args, get_origin @@ -1255,9 +1257,7 @@ class ModelMeta(ModelMetaclass): base_meta, "primary_key_pattern", "{pk}" ) if not getattr(new_class._meta, "database", None): - new_class._meta.database = getattr( - base_meta, "database", get_redis_connection() - ) + new_class._meta.database = getattr(base_meta, "database", None) if not getattr(new_class._meta, "encoding", None): new_class._meta.encoding = getattr(base_meta, "encoding") if not getattr(new_class._meta, "primary_key_creator_cls", None): @@ -1282,6 +1282,7 @@ class ModelMeta(ModelMetaclass): class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): pk: Optional[str] = Field(default=None, primary_key=True) + _conn: Optional[Union[Redis, RedisAsync]] = None Meta = DefaultMeta @@ -1370,7 +1371,19 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): @classmethod def db(cls): - return cls._meta.database + if not cls._conn: + conn = ( + cls._meta.database() + if callable(cls._meta.database) + else cls._meta.database or get_redis_connection() + ) + if not has_redis_json(conn): + log.error( + "Your Redis instance does not have the RedisJson module " + "loaded. JsonModel depends on RedisJson." + ) + cls._conn = conn + return cls._conn @classmethod def find( @@ -1674,14 +1687,6 @@ class JsonModel(RedisModel, abc.ABC): # Generate the RediSearch schema once to validate fields. cls.redisearch_schema() - def __init__(self, *args, **kwargs): - if not has_redis_json(self.db()): - log.error( - "Your Redis instance does not have the RedisJson module " - "loaded. JsonModel depends on RedisJson." - ) - super().__init__(*args, **kwargs) - async def save( self: "Model", pipeline: Optional[redis.client.Pipeline] = None ) -> "Model":