Migrate from aioredis to redis-py with asyncio support (#233)
* Migrate from aioredis to redis with asyncio support Add test for redis type Fix imports from wrong module (for tests_sync) * fixing merge conflicts and up to dating the lock file Co-authored-by: Chayim I. Kirshen <c@kirshen.com>
This commit is contained in:
parent
e5e887229a
commit
4661459ddd
15 changed files with 484 additions and 494 deletions
|
@ -1,3 +1,4 @@
|
||||||
|
from .async_redis import redis # isort:skip
|
||||||
from .checks import has_redis_json, has_redisearch
|
from .checks import has_redis_json, has_redisearch
|
||||||
from .connections import get_redis_connection
|
from .connections import get_redis_connection
|
||||||
from .model.migrations.migrator import MigrationError, Migrator
|
from .model.migrations.migrator import MigrationError, Migrator
|
||||||
|
|
1
aredis_om/async_redis.py
Normal file
1
aredis_om/async_redis.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from redis import asyncio as redis
|
|
@ -1,19 +1,19 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import aioredis
|
from . import redis
|
||||||
|
|
||||||
|
|
||||||
URL = os.environ.get("REDIS_OM_URL", None)
|
URL = os.environ.get("REDIS_OM_URL", None)
|
||||||
|
|
||||||
|
|
||||||
def get_redis_connection(**kwargs) -> aioredis.Redis:
|
def get_redis_connection(**kwargs) -> redis.Redis:
|
||||||
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
|
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
|
||||||
# environment variable, we'll create the Redis client from the URL.
|
# environment variable, we'll create the Redis client from the URL.
|
||||||
url = kwargs.pop("url", URL)
|
url = kwargs.pop("url", URL)
|
||||||
if url:
|
if url:
|
||||||
return aioredis.Redis.from_url(url, **kwargs)
|
return redis.Redis.from_url(url, **kwargs)
|
||||||
|
|
||||||
# Decode from UTF-8 by default
|
# Decode from UTF-8 by default
|
||||||
if "decode_responses" not in kwargs:
|
if "decode_responses" not in kwargs:
|
||||||
kwargs["decode_responses"] = True
|
kwargs["decode_responses"] = True
|
||||||
return aioredis.Redis(**kwargs)
|
return redis.Redis(**kwargs)
|
||||||
|
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from aioredis import Redis, ResponseError
|
from ... import redis
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -39,18 +39,19 @@ def schema_hash_key(index_name):
|
||||||
return f"{index_name}:hash"
|
return f"{index_name}:hash"
|
||||||
|
|
||||||
|
|
||||||
async def create_index(redis: Redis, index_name, schema, current_hash):
|
async def create_index(conn: redis.Redis, index_name, schema, current_hash):
|
||||||
db_number = redis.connection_pool.connection_kwargs.get("db")
|
db_number = conn.connection_pool.connection_kwargs.get("db")
|
||||||
if db_number and db_number > 0:
|
if db_number and db_number > 0:
|
||||||
raise MigrationError(
|
raise MigrationError(
|
||||||
"Creating search indexes is only supported in database 0. "
|
"Creating search indexes is only supported in database 0. "
|
||||||
f"You attempted to create an index in database {db_number}"
|
f"You attempted to create an index in database {db_number}"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await redis.execute_command(f"ft.info {index_name}")
|
await conn.execute_command(f"ft.info {index_name}")
|
||||||
except ResponseError:
|
except redis.ResponseError:
|
||||||
await redis.execute_command(f"ft.create {index_name} {schema}")
|
await conn.execute_command(f"ft.create {index_name} {schema}")
|
||||||
await redis.set(schema_hash_key(index_name), current_hash)
|
# TODO: remove "type: ignore" when type stubs will be fixed
|
||||||
|
await conn.set(schema_hash_key(index_name), current_hash) # type: ignore
|
||||||
else:
|
else:
|
||||||
log.info("Index already exists, skipping. Index hash: %s", index_name)
|
log.info("Index already exists, skipping. Index hash: %s", index_name)
|
||||||
|
|
||||||
|
@ -67,7 +68,7 @@ class IndexMigration:
|
||||||
schema: str
|
schema: str
|
||||||
hash: str
|
hash: str
|
||||||
action: MigrationAction
|
action: MigrationAction
|
||||||
redis: Redis
|
conn: redis.Redis
|
||||||
previous_hash: Optional[str] = None
|
previous_hash: Optional[str] = None
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
@ -78,14 +79,14 @@ class IndexMigration:
|
||||||
|
|
||||||
async def create(self):
|
async def create(self):
|
||||||
try:
|
try:
|
||||||
await create_index(self.redis, self.index_name, self.schema, self.hash)
|
await create_index(self.conn, self.index_name, self.schema, self.hash)
|
||||||
except ResponseError:
|
except redis.ResponseError:
|
||||||
log.info("Index already exists: %s", self.index_name)
|
log.info("Index already exists: %s", self.index_name)
|
||||||
|
|
||||||
async def drop(self):
|
async def drop(self):
|
||||||
try:
|
try:
|
||||||
await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
|
await self.conn.execute_command(f"FT.DROPINDEX {self.index_name}")
|
||||||
except ResponseError:
|
except redis.ResponseError:
|
||||||
log.info("Index does not exist: %s", self.index_name)
|
log.info("Index does not exist: %s", self.index_name)
|
||||||
|
|
||||||
|
|
||||||
|
@ -105,7 +106,7 @@ class Migrator:
|
||||||
|
|
||||||
for name, cls in model_registry.items():
|
for name, cls in model_registry.items():
|
||||||
hash_key = schema_hash_key(cls.Meta.index_name)
|
hash_key = schema_hash_key(cls.Meta.index_name)
|
||||||
redis = cls.db()
|
conn = cls.db()
|
||||||
try:
|
try:
|
||||||
schema = cls.redisearch_schema()
|
schema = cls.redisearch_schema()
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
|
@ -114,8 +115,8 @@ class Migrator:
|
||||||
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec
|
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await redis.execute_command("ft.info", cls.Meta.index_name)
|
await conn.execute_command("ft.info", cls.Meta.index_name)
|
||||||
except ResponseError:
|
except redis.ResponseError:
|
||||||
self.migrations.append(
|
self.migrations.append(
|
||||||
IndexMigration(
|
IndexMigration(
|
||||||
name,
|
name,
|
||||||
|
@ -123,12 +124,12 @@ class Migrator:
|
||||||
schema,
|
schema,
|
||||||
current_hash,
|
current_hash,
|
||||||
MigrationAction.CREATE,
|
MigrationAction.CREATE,
|
||||||
redis,
|
conn,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
stored_hash = await redis.get(hash_key)
|
stored_hash = await conn.get(hash_key)
|
||||||
schema_out_of_date = current_hash != stored_hash
|
schema_out_of_date = current_hash != stored_hash
|
||||||
|
|
||||||
if schema_out_of_date:
|
if schema_out_of_date:
|
||||||
|
@ -140,7 +141,7 @@ class Migrator:
|
||||||
schema,
|
schema,
|
||||||
current_hash,
|
current_hash,
|
||||||
MigrationAction.DROP,
|
MigrationAction.DROP,
|
||||||
redis,
|
conn,
|
||||||
stored_hash,
|
stored_hash,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -151,7 +152,7 @@ class Migrator:
|
||||||
schema,
|
schema,
|
||||||
current_hash,
|
current_hash,
|
||||||
MigrationAction.CREATE,
|
MigrationAction.CREATE,
|
||||||
redis,
|
conn,
|
||||||
stored_hash,
|
stored_hash,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,8 +24,6 @@ from typing import (
|
||||||
no_type_check,
|
no_type_check,
|
||||||
)
|
)
|
||||||
|
|
||||||
import aioredis
|
|
||||||
from aioredis.client import Pipeline
|
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||||
from pydantic.fields import ModelField, Undefined, UndefinedType
|
from pydantic.fields import ModelField, Undefined, UndefinedType
|
||||||
|
@ -35,9 +33,10 @@ from pydantic.utils import Representation
|
||||||
from typing_extensions import Protocol, get_args, get_origin
|
from typing_extensions import Protocol, get_args, get_origin
|
||||||
from ulid import ULID
|
from ulid import ULID
|
||||||
|
|
||||||
|
from .. import redis
|
||||||
from ..checks import has_redis_json, has_redisearch
|
from ..checks import has_redis_json, has_redisearch
|
||||||
from ..connections import get_redis_connection
|
from ..connections import get_redis_connection
|
||||||
from ..unasync_util import ASYNC_MODE
|
from ..util import ASYNC_MODE
|
||||||
from .encoders import jsonable_encoder
|
from .encoders import jsonable_encoder
|
||||||
from .render_tree import render_tree
|
from .render_tree import render_tree
|
||||||
from .token_escaper import TokenEscaper
|
from .token_escaper import TokenEscaper
|
||||||
|
@ -978,7 +977,7 @@ class BaseMeta(Protocol):
|
||||||
global_key_prefix: str
|
global_key_prefix: str
|
||||||
model_key_prefix: str
|
model_key_prefix: str
|
||||||
primary_key_pattern: str
|
primary_key_pattern: str
|
||||||
database: aioredis.Redis
|
database: redis.Redis
|
||||||
primary_key: PrimaryKey
|
primary_key: PrimaryKey
|
||||||
primary_key_creator_cls: Type[PrimaryKeyCreator]
|
primary_key_creator_cls: Type[PrimaryKeyCreator]
|
||||||
index_name: str
|
index_name: str
|
||||||
|
@ -997,7 +996,7 @@ class DefaultMeta:
|
||||||
global_key_prefix: Optional[str] = None
|
global_key_prefix: Optional[str] = None
|
||||||
model_key_prefix: Optional[str] = None
|
model_key_prefix: Optional[str] = None
|
||||||
primary_key_pattern: Optional[str] = None
|
primary_key_pattern: Optional[str] = None
|
||||||
database: Optional[aioredis.Redis] = None
|
database: Optional[redis.Redis] = None
|
||||||
primary_key: Optional[PrimaryKey] = None
|
primary_key: Optional[PrimaryKey] = None
|
||||||
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
|
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
|
||||||
index_name: Optional[str] = None
|
index_name: Optional[str] = None
|
||||||
|
@ -1130,10 +1129,14 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
"""Update this model instance with the specified key-value pairs."""
|
"""Update this model instance with the specified key-value pairs."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
|
async def save(
|
||||||
|
self, pipeline: Optional[redis.client.Pipeline] = None
|
||||||
|
) -> "RedisModel":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
|
async def expire(
|
||||||
|
self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
|
||||||
|
):
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
db = self.db()
|
db = self.db()
|
||||||
else:
|
else:
|
||||||
|
@ -1226,7 +1229,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
async def add(
|
async def add(
|
||||||
cls,
|
cls,
|
||||||
models: Sequence["RedisModel"],
|
models: Sequence["RedisModel"],
|
||||||
pipeline: Optional[Pipeline] = None,
|
pipeline: Optional[redis.client.Pipeline] = None,
|
||||||
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
|
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
|
||||||
) -> Sequence["RedisModel"]:
|
) -> Sequence["RedisModel"]:
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
|
@ -1286,7 +1289,9 @@ class HashModel(RedisModel, abc.ABC):
|
||||||
f"HashModels cannot index dataclass fields. Field: {name}"
|
f"HashModels cannot index dataclass fields. Field: {name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
|
async def save(
|
||||||
|
self, pipeline: Optional[redis.client.Pipeline] = None
|
||||||
|
) -> "HashModel":
|
||||||
self.check()
|
self.check()
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
db = self.db()
|
db = self.db()
|
||||||
|
@ -1458,7 +1463,9 @@ class JsonModel(RedisModel, abc.ABC):
|
||||||
)
|
)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
|
async def save(
|
||||||
|
self, pipeline: Optional[redis.client.Pipeline] = None
|
||||||
|
) -> "JsonModel":
|
||||||
self.check()
|
self.check()
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
db = self.db()
|
db = self.db()
|
||||||
|
|
1
aredis_om/sync_redis.py
Normal file
1
aredis_om/sync_redis.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
import redis
|
|
@ -1,41 +0,0 @@
|
||||||
"""Set of utility functions for unasync that transform into sync counterparts cleanly"""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
|
|
||||||
_original_next = next
|
|
||||||
|
|
||||||
|
|
||||||
def is_async_mode():
|
|
||||||
"""Tests if we're in the async part of the code or not"""
|
|
||||||
|
|
||||||
async def f():
|
|
||||||
"""Unasync transforms async functions in sync functions"""
|
|
||||||
return None
|
|
||||||
|
|
||||||
obj = f()
|
|
||||||
if obj is None:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
obj.close() # prevent unawaited coroutine warning
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
ASYNC_MODE = is_async_mode()
|
|
||||||
|
|
||||||
|
|
||||||
async def anext(x):
|
|
||||||
return await x.__anext__()
|
|
||||||
|
|
||||||
|
|
||||||
async def await_if_coro(x):
|
|
||||||
if inspect.iscoroutine(x):
|
|
||||||
return await x
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
next = _original_next
|
|
||||||
|
|
||||||
|
|
||||||
def return_non_coro(x):
|
|
||||||
return x
|
|
12
aredis_om/util.py
Normal file
12
aredis_om/util.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
def is_async_mode():
|
||||||
|
async def f():
|
||||||
|
"""Unasync transforms async functions in sync functions"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
return inspect.iscoroutinefunction(f)
|
||||||
|
|
||||||
|
|
||||||
|
ASYNC_MODE = is_async_mode()
|
11
make_sync.py
11
make_sync.py
|
@ -5,7 +5,7 @@ import unasync
|
||||||
|
|
||||||
ADDITIONAL_REPLACEMENTS = {
|
ADDITIONAL_REPLACEMENTS = {
|
||||||
"aredis_om": "redis_om",
|
"aredis_om": "redis_om",
|
||||||
"aioredis": "redis",
|
"async_redis": "sync_redis",
|
||||||
":tests.": ":tests_sync.",
|
":tests.": ":tests_sync.",
|
||||||
"pytest_asyncio": "pytest",
|
"pytest_asyncio": "pytest",
|
||||||
"py_test_mark_asyncio": "py_test_mark_sync",
|
"py_test_mark_asyncio": "py_test_mark_sync",
|
||||||
|
@ -26,11 +26,12 @@ def main():
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
filepaths = []
|
filepaths = []
|
||||||
for root, _, filenames in os.walk(
|
for root, _, filenames in os.walk(Path(__file__).absolute().parent):
|
||||||
Path(__file__).absolute().parent
|
|
||||||
):
|
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
if filename.rpartition(".")[-1] in ("py", "pyi",):
|
if filename.rpartition(".")[-1] in (
|
||||||
|
"py",
|
||||||
|
"pyi",
|
||||||
|
):
|
||||||
filepaths.append(os.path.join(root, filename))
|
filepaths.append(os.path.join(root, filename))
|
||||||
|
|
||||||
unasync.unasync_files(filepaths, rules)
|
unasync.unasync_files(filepaths, rules)
|
||||||
|
|
813
poetry.lock
generated
813
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -23,8 +23,7 @@ include=[
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.7"
|
python = "^3.7"
|
||||||
redis = ">=3.5.3,<5.0.0"
|
redis = ">=4.2.0,<5.0.0"
|
||||||
aioredis = "^2.0.0"
|
|
||||||
pydantic = "^1.8.2"
|
pydantic = "^1.8.2"
|
||||||
click = "^8.0.1"
|
click = "^8.0.1"
|
||||||
pptree = "^3.1"
|
pptree = "^3.1"
|
||||||
|
|
|
@ -23,7 +23,9 @@ from aredis_om import (
|
||||||
# We need to run this check as sync code (during tests) even in async mode
|
# We need to run this check as sync code (during tests) even in async mode
|
||||||
# because we call it in the top-level module scope.
|
# because we call it in the top-level module scope.
|
||||||
from redis_om import has_redisearch
|
from redis_om import has_redisearch
|
||||||
from tests.conftest import py_test_mark_asyncio
|
|
||||||
|
from .conftest import py_test_mark_asyncio
|
||||||
|
|
||||||
|
|
||||||
if not has_redisearch():
|
if not has_redisearch():
|
||||||
pytestmark = pytest.mark.skip
|
pytestmark = pytest.mark.skip
|
||||||
|
|
|
@ -25,7 +25,9 @@ from aredis_om import (
|
||||||
# We need to run this check as sync code (during tests) even in async mode
|
# We need to run this check as sync code (during tests) even in async mode
|
||||||
# because we call it in the top-level module scope.
|
# because we call it in the top-level module scope.
|
||||||
from redis_om import has_redis_json
|
from redis_om import has_redis_json
|
||||||
from tests.conftest import py_test_mark_asyncio
|
|
||||||
|
from .conftest import py_test_mark_asyncio
|
||||||
|
|
||||||
|
|
||||||
if not has_redis_json():
|
if not has_redis_json():
|
||||||
pytestmark = pytest.mark.skip
|
pytestmark = pytest.mark.skip
|
||||||
|
|
|
@ -9,7 +9,8 @@ import pytest_asyncio
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from aredis_om import HashModel, Migrator, NotFoundError, RedisModelError
|
from aredis_om import HashModel, Migrator, NotFoundError, RedisModelError
|
||||||
from tests.conftest import py_test_mark_asyncio
|
|
||||||
|
from .conftest import py_test_mark_asyncio
|
||||||
|
|
||||||
|
|
||||||
today = datetime.date.today()
|
today = datetime.date.today()
|
||||||
|
|
10
tests/test_redis_type.py
Normal file
10
tests/test_redis_type.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
from aredis_om import redis
|
||||||
|
from aredis_om.util import ASYNC_MODE
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_type():
|
||||||
|
import redis as sync_redis_module
|
||||||
|
import redis.asyncio as async_redis_module
|
||||||
|
|
||||||
|
mapping = {True: async_redis_module, False: sync_redis_module}
|
||||||
|
assert mapping[ASYNC_MODE] is redis
|
Loading…
Reference in a new issue