Migrate from aioredis to redis-py with asyncio support (#233)

* Migrate from aioredis to redis with asyncio support

Add test for redis type
Fix imports from wrong module (for tests_sync)

* Fix merge conflicts and update the lock file

Co-authored-by: Chayim I. Kirshen <c@kirshen.com>
Branch: main
Author: Serhii Charykov (committed 2 years ago via GitHub)
Commit: 4661459ddd, parent: e5e887229a

@@ -1,3 +1,4 @@
+from .async_redis import redis  # isort:skip
 from .checks import has_redis_json, has_redisearch
 from .connections import get_redis_connection
 from .model.migrations.migrator import MigrationError, Migrator

@@ -0,0 +1 @@
+from redis import asyncio as redis
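For context: redis-py 4.2.0 absorbed the aioredis codebase, so the old client now lives in the redis.asyncio namespace and the new one-line module above simply aliases it. A minimal sketch of using that alias directly (the host and port are placeholders, not something this commit configures):

import asyncio

from redis import asyncio as redis  # bundled with redis-py >= 4.2.0


async def main():
    # Same client surface aioredis 2.x exposed, now shipped inside redis-py.
    client = redis.Redis(host="localhost", port=6379, decode_responses=True)
    print(await client.ping())  # True when a server is reachable
    await client.close()


asyncio.run(main())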

@ -1,19 +1,19 @@
import os import os
import aioredis from . import redis
URL = os.environ.get("REDIS_OM_URL", None) URL = os.environ.get("REDIS_OM_URL", None)
def get_redis_connection(**kwargs) -> aioredis.Redis: def get_redis_connection(**kwargs) -> redis.Redis:
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL # If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
# environment variable, we'll create the Redis client from the URL. # environment variable, we'll create the Redis client from the URL.
url = kwargs.pop("url", URL) url = kwargs.pop("url", URL)
if url: if url:
return aioredis.Redis.from_url(url, **kwargs) return redis.Redis.from_url(url, **kwargs)
# Decode from UTF-8 by default # Decode from UTF-8 by default
if "decode_responses" not in kwargs: if "decode_responses" not in kwargs:
kwargs["decode_responses"] = True kwargs["decode_responses"] = True
return aioredis.Redis(**kwargs) return redis.Redis(**kwargs)
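A hypothetical caller of the updated helper, to show the unchanged behavior under the new import (the host/port values and the "greeting" key are illustrative only):

import asyncio

from aredis_om import get_redis_connection


async def main():
    # No url kwarg and no REDIS_OM_URL in the environment, so this takes
    # the second branch and gets decode_responses=True by default.
    conn = get_redis_connection(host="localhost", port=6379)
    await conn.set("greeting", "hello")
    print(await conn.get("greeting"))  # "hello", already decoded to str


asyncio.run(main())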

@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import List, Optional

-from aioredis import Redis, ResponseError
+from ... import redis

 log = logging.getLogger(__name__)
@@ -39,18 +39,19 @@ def schema_hash_key(index_name):
     return f"{index_name}:hash"


-async def create_index(redis: Redis, index_name, schema, current_hash):
-    db_number = redis.connection_pool.connection_kwargs.get("db")
+async def create_index(conn: redis.Redis, index_name, schema, current_hash):
+    db_number = conn.connection_pool.connection_kwargs.get("db")
     if db_number and db_number > 0:
         raise MigrationError(
             "Creating search indexes is only supported in database 0. "
             f"You attempted to create an index in database {db_number}"
         )
     try:
-        await redis.execute_command(f"ft.info {index_name}")
-    except ResponseError:
-        await redis.execute_command(f"ft.create {index_name} {schema}")
-        await redis.set(schema_hash_key(index_name), current_hash)
+        await conn.execute_command(f"ft.info {index_name}")
+    except redis.ResponseError:
+        await conn.execute_command(f"ft.create {index_name} {schema}")
+        # TODO: remove "type: ignore" when type stubs will be fixed
+        await conn.set(schema_hash_key(index_name), current_hash)  # type: ignore
     else:
         log.info("Index already exists, skipping. Index hash: %s", index_name)
@@ -67,7 +68,7 @@ class IndexMigration:
     schema: str
     hash: str
     action: MigrationAction
-    redis: Redis
+    conn: redis.Redis
     previous_hash: Optional[str] = None

     async def run(self):
@@ -78,14 +79,14 @@ class IndexMigration:
     async def create(self):
         try:
-            await create_index(self.redis, self.index_name, self.schema, self.hash)
-        except ResponseError:
+            await create_index(self.conn, self.index_name, self.schema, self.hash)
+        except redis.ResponseError:
             log.info("Index already exists: %s", self.index_name)

     async def drop(self):
         try:
-            await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
-        except ResponseError:
+            await self.conn.execute_command(f"FT.DROPINDEX {self.index_name}")
+        except redis.ResponseError:
             log.info("Index does not exist: %s", self.index_name)
@@ -105,7 +106,7 @@ class Migrator:
         for name, cls in model_registry.items():
             hash_key = schema_hash_key(cls.Meta.index_name)
-            redis = cls.db()
+            conn = cls.db()
             try:
                 schema = cls.redisearch_schema()
             except NotImplementedError:
@@ -114,8 +115,8 @@ class Migrator:
             current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest()  # nosec

             try:
-                await redis.execute_command("ft.info", cls.Meta.index_name)
-            except ResponseError:
+                await conn.execute_command("ft.info", cls.Meta.index_name)
+            except redis.ResponseError:
                 self.migrations.append(
                     IndexMigration(
                         name,
@@ -123,12 +124,12 @@ class Migrator:
                         schema,
                         current_hash,
                         MigrationAction.CREATE,
-                        redis,
+                        conn,
                     )
                 )
                 continue

-            stored_hash = await redis.get(hash_key)
+            stored_hash = await conn.get(hash_key)
             schema_out_of_date = current_hash != stored_hash

             if schema_out_of_date:
@@ -140,7 +141,7 @@ class Migrator:
                         schema,
                         current_hash,
                         MigrationAction.DROP,
-                        redis,
+                        conn,
                         stored_hash,
                     )
                 )
@@ -151,7 +152,7 @@ class Migrator:
                         schema,
                         current_hash,
                         MigrationAction.CREATE,
-                        redis,
+                        conn,
                         stored_hash,
                     )
                 )
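All of the renamed conn handles above rely on the same probe-and-create idiom: FT.INFO raises a ResponseError for an unknown index, so the exception doubles as an existence check. A standalone sketch of that idiom (the function name is ours, not the migrator's):

from redis import asyncio as redis


async def index_exists(conn: redis.Redis, index_name: str) -> bool:
    # RediSearch answers FT.INFO with a ResponseError when the index
    # does not exist; success means it does.
    try:
        await conn.execute_command("FT.INFO", index_name)
    except redis.ResponseError:
        return False
    return True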

@@ -24,8 +24,6 @@ from typing import (
     no_type_check,
 )

-import aioredis
-from aioredis.client import Pipeline
 from pydantic import BaseModel, validator
 from pydantic.fields import FieldInfo as PydanticFieldInfo
 from pydantic.fields import ModelField, Undefined, UndefinedType
@@ -35,9 +33,10 @@ from pydantic.utils import Representation
 from typing_extensions import Protocol, get_args, get_origin
 from ulid import ULID

+from .. import redis
 from ..checks import has_redis_json, has_redisearch
 from ..connections import get_redis_connection
-from ..unasync_util import ASYNC_MODE
+from ..util import ASYNC_MODE
 from .encoders import jsonable_encoder
 from .render_tree import render_tree
 from .token_escaper import TokenEscaper
@@ -978,7 +977,7 @@ class BaseMeta(Protocol):
     global_key_prefix: str
     model_key_prefix: str
     primary_key_pattern: str
-    database: aioredis.Redis
+    database: redis.Redis
     primary_key: PrimaryKey
     primary_key_creator_cls: Type[PrimaryKeyCreator]
     index_name: str
@@ -997,7 +996,7 @@ class DefaultMeta:
     global_key_prefix: Optional[str] = None
     model_key_prefix: Optional[str] = None
    primary_key_pattern: Optional[str] = None
-    database: Optional[aioredis.Redis] = None
+    database: Optional[redis.Redis] = None
     primary_key: Optional[PrimaryKey] = None
     primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
     index_name: Optional[str] = None
@@ -1130,10 +1129,14 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
         """Update this model instance with the specified key-value pairs."""
         raise NotImplementedError

-    async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
+    async def save(
+        self, pipeline: Optional[redis.client.Pipeline] = None
+    ) -> "RedisModel":
         raise NotImplementedError

-    async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
+    async def expire(
+        self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
+    ):
         if pipeline is None:
             db = self.db()
         else:
@@ -1226,7 +1229,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
     async def add(
         cls,
         models: Sequence["RedisModel"],
-        pipeline: Optional[Pipeline] = None,
+        pipeline: Optional[redis.client.Pipeline] = None,
         pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
     ) -> Sequence["RedisModel"]:
         if pipeline is None:
@@ -1286,7 +1289,9 @@ class HashModel(RedisModel, abc.ABC):
                     f"HashModels cannot index dataclass fields. Field: {name}"
                 )

-    async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
+    async def save(
+        self, pipeline: Optional[redis.client.Pipeline] = None
+    ) -> "HashModel":
         self.check()
         if pipeline is None:
             db = self.db()
@@ -1458,7 +1463,9 @@ class JsonModel(RedisModel, abc.ABC):
             )
             super().__init__(*args, **kwargs)

-    async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
+    async def save(
+        self, pipeline: Optional[redis.client.Pipeline] = None
+    ) -> "JsonModel":
         self.check()
         if pipeline is None:
             db = self.db()
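Since save() now accepts redis-py's asyncio Pipeline, callers can batch writes the same way they did with aioredis. A hedged sketch (Customer is a hypothetical model, not defined in this commit):

from aredis_om import HashModel, get_redis_connection


class Customer(HashModel):
    first_name: str


async def save_batch(customers):
    db = get_redis_connection()
    async with db.pipeline(transaction=True) as pipe:
        for customer in customers:
            await customer.save(pipeline=pipe)  # queues the write on the pipeline
        await pipe.execute()  # flushes every queued command in one round trip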

@@ -0,0 +1 @@
+import redis

@ -1,41 +0,0 @@
"""Set of utility functions for unasync that transform into sync counterparts cleanly"""
import inspect
_original_next = next
def is_async_mode():
"""Tests if we're in the async part of the code or not"""
async def f():
"""Unasync transforms async functions in sync functions"""
return None
obj = f()
if obj is None:
return False
else:
obj.close() # prevent unawaited coroutine warning
return True
ASYNC_MODE = is_async_mode()
async def anext(x):
return await x.__anext__()
async def await_if_coro(x):
if inspect.iscoroutine(x):
return await x
return x
next = _original_next
def return_non_coro(x):
return x

@@ -0,0 +1,12 @@
+import inspect
+
+
+def is_async_mode():
+    async def f():
+        """Unasync transforms async functions in sync functions"""
+        return None
+
+    return inspect.iscoroutinefunction(f)
+
+
+ASYNC_MODE = is_async_mode()
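The replacement module works because unasync rewrites async def to plain def in the generated sync tree, which flips the inspect.iscoroutinefunction answer. An illustration of the mechanism (probe_rewritten mimics what unasync would emit; it is not part of this commit):

import inspect


async def probe():
    return None


def probe_rewritten():  # what unasync emits for probe in the sync build
    return None


print(inspect.iscoroutinefunction(probe))            # True  -> ASYNC_MODE = True
print(inspect.iscoroutinefunction(probe_rewritten))  # False -> ASYNC_MODE = False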

@@ -5,7 +5,7 @@ import unasync
 ADDITIONAL_REPLACEMENTS = {
     "aredis_om": "redis_om",
-    "aioredis": "redis",
+    "async_redis": "sync_redis",
     ":tests.": ":tests_sync.",
     "pytest_asyncio": "pytest",
     "py_test_mark_asyncio": "py_test_mark_sync",
@@ -26,11 +26,12 @@ def main():
         ),
     ]
     filepaths = []
-    for root, _, filenames in os.walk(
-        Path(__file__).absolute().parent
-    ):
+    for root, _, filenames in os.walk(Path(__file__).absolute().parent):
         for filename in filenames:
-            if filename.rpartition(".")[-1] in ("py", "pyi",):
+            if filename.rpartition(".")[-1] in (
+                "py",
+                "pyi",
+            ):
                 filepaths.append(os.path.join(root, filename))

     unasync.unasync_files(filepaths, rules)
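For readers unfamiliar with the tool: each unasync.Rule copies a source tree, strips async/await syntax, and applies the replacement table to identifiers, which is how "async_redis" becomes "sync_redis" in the generated package. A minimal sketch, assuming the repo keeps its async code in aredis_om/ and generates redis_om/:

import unasync

rules = [
    unasync.Rule(
        fromdir="/aredis_om/",
        todir="/redis_om/",
        additional_replacements={
            "aredis_om": "redis_om",
            "async_redis": "sync_redis",
        },
    ),
]
# Rewrites the listed files into their sync counterparts.
unasync.unasync_files(["aredis_om/connections.py"], rules)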

poetry.lock (generated, 797 changed lines): diff suppressed because it is too large.

@@ -23,8 +23,7 @@ include=[

 [tool.poetry.dependencies]
 python = "^3.7"
-redis = ">=3.5.3,<5.0.0"
-aioredis = "^2.0.0"
+redis = ">=4.2.0,<5.0.0"
 pydantic = "^1.8.2"
 click = "^8.0.1"
 pptree = "^3.1"
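The dependency floor moves to 4.2.0 because that is the first redis-py release to ship the merged aioredis client as redis.asyncio, which lets the separate aioredis requirement be dropped entirely. A quick, illustrative sanity check:

import redis

# ImportError here would mean the installed redis-py predates 4.2.0.
from redis import asyncio as redis_asyncio

print(redis.__version__)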

@@ -23,7 +23,9 @@ from aredis_om import (
 # We need to run this check as sync code (during tests) even in async mode
 # because we call it in the top-level module scope.
 from redis_om import has_redisearch
-from tests.conftest import py_test_mark_asyncio
+
+from .conftest import py_test_mark_asyncio
+

 if not has_redisearch():
     pytestmark = pytest.mark.skip

@@ -25,7 +25,9 @@ from aredis_om import (
 # We need to run this check as sync code (during tests) even in async mode
 # because we call it in the top-level module scope.
 from redis_om import has_redis_json
-from tests.conftest import py_test_mark_asyncio
+
+from .conftest import py_test_mark_asyncio
+

 if not has_redis_json():
     pytestmark = pytest.mark.skip

@@ -9,7 +9,8 @@ import pytest_asyncio
 from pydantic import ValidationError

 from aredis_om import HashModel, Migrator, NotFoundError, RedisModelError
-from tests.conftest import py_test_mark_asyncio
+
+from .conftest import py_test_mark_asyncio

 today = datetime.date.today()

@@ -0,0 +1,10 @@
+from aredis_om import redis
+from aredis_om.util import ASYNC_MODE
+
+
+def test_redis_type():
+    import redis as sync_redis_module
+    import redis.asyncio as async_redis_module
+
+    mapping = {True: async_redis_module, False: sync_redis_module}
+    assert mapping[ASYNC_MODE] is redis
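Roughly what the generated sync build of this test should look like once the replacement table maps aredis_om to redis_om (a sketch of the expected output, not a file in this commit):

from redis_om import redis
from redis_om.util import ASYNC_MODE  # False in the generated sync build


def test_redis_type():
    import redis as sync_redis_module

    assert ASYNC_MODE is False
    assert redis is sync_redis_module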