diff --git a/Makefile b/Makefile index fa06195..9c88b0b 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -NAME := redis_developer +NAME := redis_om INSTALL_STAMP := .install.stamp POETRY := $(shell command -v poetry 2> /dev/null) diff --git a/README.md b/README.md index 246dfd4..950ad06 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Check out this example: import datetime from typing import Optional -from redis_developer.model import ( +from redis_om.model import ( EmbeddedJsonModel, JsonModel, Field, @@ -172,9 +172,9 @@ Don't want to run Redis yourself? RediSearch and RedisJSON are also available on We'd love your contributions! -**Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new). +**Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-om/redis-om-python/issues/new). -You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new) to get started. +You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-om/redis-om-python/issues/new) to get started. ## License @@ -184,17 +184,17 @@ Redis OM is [MIT licensed][license-url]. [version-svg]: https://img.shields.io/pypi/v/redis-om?style=flat-square [package-url]: https://pypi.org/project/redis-om/ -[ci-svg]: https://img.shields.io/github/workflow/status/redis-developer/redis-developer-python/python?style=flat-square -[ci-url]: https://github.com/redis-developer/redis-developer-python/actions/workflows/build.yml +[ci-svg]: https://img.shields.io/github/workflow/status/redis-om/redis-om-python/python?style=flat-square +[ci-url]: https://github.com/redis-om/redis-om-python/actions/workflows/build.yml [license-image]: http://img.shields.io/badge/license-MIT-green.svg?style=flat-square [license-url]: LICENSE -[redis-developer-website]: https://developer.redis.com -[redis-om-js]: https://github.com/redis-developer/redis-om-js -[redis-om-dotnet]: https://github.com/redis-developer/redis-om-dotnet -[redis-om-spring]: https://github.com/redis-developer/redis-om-spring +[redis-om-website]: https://developer.redis.com +[redis-om-js]: https://github.com/redis-om/redis-om-js +[redis-om-dotnet]: https://github.com/redis-om/redis-om-dotnet +[redis-om-spring]: https://github.com/redis-om/redis-om-spring [redisearch-url]: https://oss.redis.com/redisearch/ [redis-json-url]: https://oss.redis.com/redisjson/ [pydantic-url]: https://github.com/samuelcolvin/pydantic diff --git a/build.py b/build.py new file mode 100644 index 0000000..12a8f25 --- /dev/null +++ b/build.py @@ -0,0 +1,10 @@ +import unasync + + +def build(setup_kwargs): + setup_kwargs.update( + {"cmdclass": {'build_py': unasync.cmdclass_build_py(rules=[ + unasync.Rule("/aredis_om/", "/redis_om/"), + unasync.Rule("/aredis_om/tests/", "/redis_om/tests/", additional_replacements={"aredis_om": "redis_om"}), + ])}} + ) diff --git a/poetry.lock b/poetry.lock index ae97c84..c93d667 100644 --- a/poetry.lock +++ b/poetry.lock @@ -538,6 +538,20 @@ toml = "*" [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.16.0" +description = "Pytest support for asyncio." +category = "dev" +optional = false +python-versions = ">= 3.6" + +[package.dependencies] +pytest = ">=5.4.0" + +[package.extras] +testing = ["coverage", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "3.0.0" @@ -707,6 +721,14 @@ category = "main" optional = false python-versions = "*" +[[package]] +name = "unasync" +version = "0.5.0" +description = "The async transformation code." +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" + [[package]] name = "wcwidth" version = "0.2.5" @@ -726,7 +748,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "56b381dd9b79bd082e978019124176491c63f09dd5ce90e5f8ab642a7f79480f" +content-hash = "d2d83b8cd3b094879e1aeb058d0036203942143f12fafa8be03fb0c79460028f" [metadata.files] aioredis = [ @@ -1003,6 +1025,10 @@ pytest = [ {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, ] +pytest-asyncio = [ + {file = "pytest-asyncio-0.16.0.tar.gz", hash = "sha256:7496c5977ce88c34379df64a66459fe395cd05543f0a2f837016e7144391fcfb"}, + {file = "pytest_asyncio-0.16.0-py3-none-any.whl", hash = "sha256:5f2a21273c47b331ae6aa5b36087047b4899e40f03f18397c0e65fa5cca54e9b"}, +] pytest-cov = [ {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, @@ -1148,6 +1174,10 @@ typing-extensions = [ {file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"}, {file = "typing_extensions-3.10.0.2.tar.gz", hash = "sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e"}, ] +unasync = [ + {file = "unasync-0.5.0-py3-none-any.whl", hash = "sha256:8d4536dae85e87b8751dfcc776f7656fd0baf54bb022a7889440dc1b9dc3becb"}, + {file = "unasync-0.5.0.tar.gz", hash = "sha256:b675d87cf56da68bd065d3b7a67ac71df85591978d84c53083c20d79a7e5096d"}, +] wcwidth = [ {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, diff --git a/pyproject.toml b/pyproject.toml index 4f56fe3..56be6de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,10 @@ [tool.poetry] -name = "redis-developer" +name = "redis-om" version = "0.1.0" description = "A high-level library containing useful Redis abstractions and tools, like an ORM and leaderboard." authors = ["Andrew Brookins "] license = "MIT" +build = "build.py" [tool.poetry.dependencies] python = "^3.8" @@ -30,10 +31,11 @@ bandit = "^1.7.0" coverage = "^6.0.2" pytest-cov = "^3.0.0" pytest-xdist = "^2.4.0" - +unasync = "^0.5.0" +pytest-asyncio = "^0.16.0" [tool.poetry.scripts] -migrate = "redis_developer.orm.cli.migrate:migrate" +migrate = "redis_om.orm.cli.migrate:migrate" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/redis_developer/__init__.py b/redis_om/__init__.py similarity index 100% rename from redis_developer/__init__.py rename to redis_om/__init__.py diff --git a/redis_developer/connections.py b/redis_om/connections.py similarity index 57% rename from redis_developer/connections.py rename to redis_om/connections.py index 07858c3..80eecf3 100644 --- a/redis_developer/connections.py +++ b/redis_om/connections.py @@ -1,22 +1,28 @@ import os +from typing import Union import dotenv +import aioredis import redis - +from redis_om.unasync_util import ASYNC_MODE dotenv.load_dotenv() URL = os.environ.get("REDIS_OM_URL", None) +if ASYNC_MODE: + client = aioredis.Redis +else: + client = redis.Redis -def get_redis_connection(**kwargs) -> redis.Redis: +def get_redis_connection(**kwargs) -> Union[aioredis.Redis, redis.Redis]: # If someone passed in a 'url' parameter, or specified a REDIS_OM_URL # environment variable, we'll create the Redis client from the URL. url = kwargs.pop("url", URL) if url: - return redis.from_url(url, **kwargs) + return client.from_url(url, **kwargs) # Decode from UTF-8 by default if "decode_responses" not in kwargs: kwargs["decode_responses"] = True - return redis.Redis(**kwargs) + return client(**kwargs) diff --git a/redis_developer/model/__init__.py b/redis_om/model/__init__.py similarity index 100% rename from redis_developer/model/__init__.py rename to redis_om/model/__init__.py diff --git a/redis_developer/model/cli/__init__.py b/redis_om/model/cli/__init__.py similarity index 100% rename from redis_developer/model/cli/__init__.py rename to redis_om/model/cli/__init__.py diff --git a/redis_developer/model/cli/migrate.py b/redis_om/model/cli/migrate.py similarity index 72% rename from redis_developer/model/cli/migrate.py rename to redis_om/model/cli/migrate.py index 28880c0..5c3c442 100644 --- a/redis_developer/model/cli/migrate.py +++ b/redis_om/model/cli/migrate.py @@ -1,10 +1,10 @@ import click -from redis_developer.model.migrations.migrator import Migrator +from redis_om.model.migrations.migrator import Migrator @click.command() -@click.option("--module", default="redis_developer") +@click.option("--module", default="redis_om") def migrate(module): migrator = Migrator(module) diff --git a/redis_developer/model/encoders.py b/redis_om/model/encoders.py similarity index 100% rename from redis_developer/model/encoders.py rename to redis_om/model/encoders.py diff --git a/redis_developer/model/migrations/__init__.py b/redis_om/model/migrations/__init__.py similarity index 100% rename from redis_developer/model/migrations/__init__.py rename to redis_om/model/migrations/__init__.py diff --git a/redis_developer/model/migrations/migrator.py b/redis_om/model/migrations/migrator.py similarity index 69% rename from redis_developer/model/migrations/migrator.py rename to redis_om/model/migrations/migrator.py index 027a6c3..0f11e11 100644 --- a/redis_developer/model/migrations/migrator.py +++ b/redis_om/model/migrations/migrator.py @@ -2,15 +2,14 @@ import hashlib import logging from dataclasses import dataclass from enum import Enum -from typing import Optional +from typing import Optional, Union -from redis import ResponseError +from redis import ResponseError, Redis +from aioredis import ResponseError as AResponseError, Redis as ARedis -from redis_developer.connections import get_redis_connection -from redis_developer.model.model import model_registry +from redis_om.model.model import model_registry -redis = get_redis_connection() log = logging.getLogger(__name__) @@ -43,12 +42,12 @@ def schema_hash_key(index_name): return f"{index_name}:hash" -def create_index(index_name, schema, current_hash): +async def create_index(redis: Union[Redis, ARedis], index_name, schema, current_hash): try: - redis.execute_command(f"ft.info {index_name}") - except ResponseError: - redis.execute_command(f"ft.create {index_name} {schema}") - redis.set(schema_hash_key(index_name), current_hash) + await redis.execute_command(f"ft.info {index_name}") + except (ResponseError, AResponseError): + await redis.execute_command(f"ft.create {index_name} {schema}") + await redis.set(schema_hash_key(index_name), current_hash) else: log.info("Index already exists, skipping. Index hash: %s", index_name) @@ -65,34 +64,38 @@ class IndexMigration: schema: str hash: str action: MigrationAction + redis: Union[Redis, ARedis] previous_hash: Optional[str] = None - def run(self): + async def run(self): if self.action is MigrationAction.CREATE: - self.create() + await self.create() elif self.action is MigrationAction.DROP: - self.drop() + await self.drop() - def create(self): + async def create(self): try: - return create_index(self.index_name, self.schema, self.hash) + await create_index(self.redis, self.index_name, self.schema, self.hash) except ResponseError: log.info("Index already exists: %s", self.index_name) - def drop(self): + async def drop(self): try: - redis.execute_command(f"FT.DROPINDEX {self.index_name}") + await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}") except ResponseError: log.info("Index does not exist: %s", self.index_name) class Migrator: - def __init__(self, module=None): - # Try to load any modules found under the given path or module name. - if module: - import_submodules(module) - + def __init__(self, redis: Union[Redis, ARedis], module=None): + self.module = module self.migrations = [] + self.redis = redis + + async def run(self): + # Try to load any modules found under the given path or module name. + if self.module: + import_submodules(self.module) for name, cls in model_registry.items(): hash_key = schema_hash_key(cls.Meta.index_name) @@ -104,8 +107,8 @@ class Migrator: current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec try: - redis.execute_command("ft.info", cls.Meta.index_name) - except ResponseError: + await self.redis.execute_command("ft.info", cls.Meta.index_name) + except (ResponseError, AResponseError): self.migrations.append( IndexMigration( name, @@ -113,11 +116,12 @@ class Migrator: schema, current_hash, MigrationAction.CREATE, + self.redis ) ) continue - stored_hash = redis.get(hash_key) + stored_hash = self.redis.get(hash_key) schema_out_of_date = current_hash != stored_hash if schema_out_of_date: @@ -129,7 +133,8 @@ class Migrator: schema, current_hash, MigrationAction.DROP, - stored_hash, + self.redis, + stored_hash ) ) self.migrations.append( @@ -139,12 +144,12 @@ class Migrator: schema, current_hash, MigrationAction.CREATE, - stored_hash, + self.redis, + stored_hash ) ) - def run(self): # TODO: Migration history # TODO: Dry run with output for migration in self.migrations: - migration.run() + await migration.run() diff --git a/redis_developer/model/model.py b/redis_om/model/model.py similarity index 94% rename from redis_developer/model/model.py rename to redis_om/model/model.py index 168f49c..4d295c5 100644 --- a/redis_developer/model/model.py +++ b/redis_om/model/model.py @@ -4,7 +4,7 @@ import decimal import json import logging import operator -from copy import copy, deepcopy +from copy import copy from enum import Enum from functools import reduce from typing import ( @@ -27,6 +27,7 @@ from typing import ( no_type_check, ) +import aioredis import redis from pydantic import BaseModel, validator from pydantic.fields import FieldInfo as PydanticFieldInfo @@ -37,11 +38,11 @@ from pydantic.utils import Representation from redis.client import Pipeline from ulid import ULID -from ..connections import get_redis_connection +from redis_om.connections import get_redis_connection from .encoders import jsonable_encoder from .render_tree import render_tree from .token_escaper import TokenEscaper - +from ..unasync_util import ASYNC_MODE model_registry = {} _T = TypeVar("_T") @@ -521,7 +522,7 @@ class FindQuery: # this is not going to work. log.warning( "Your query against the field %s is for a single character, %s, " - "that is used internally by redis-developer-python. We must ignore " + "that is used internally by redis-om-python. We must ignore " "this portion of the query. Please review your query to find " "an alternative query that uses a string containing more than " "just the character %s.", @@ -680,7 +681,7 @@ class FindQuery: return result - def execute(self, exhaust_results=True): + async def execute(self, exhaust_results=True): args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination] if self.sort_fields: args += self.resolve_redisearch_sort_fields() @@ -691,7 +692,7 @@ class FindQuery: # If the offset is greater than 0, we're paginating through a result set, # so append the new results to results already in the cache. - raw_result = self.model.db().execute_command(*args) + raw_result = await self.model.db().execute_command(*args) count = raw_result[0] results = self.model.from_redis(raw_result) self._model_cache += results @@ -710,31 +711,31 @@ class FindQuery: # Make a query for each pass of the loop, with a new offset equal to the # current offset plus `page_size`, until we stop getting results back. query = query.copy(offset=query.offset + query.page_size) - _results = query.execute(exhaust_results=False) + _results = await query.execute(exhaust_results=False) if not _results: break self._model_cache += _results return self._model_cache - def first(self): + async def first(self): query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields) - results = query.execute() + results = await query.execute() if not results: raise NotFoundError() return results[0] - def all(self, batch_size=10): + async def all(self, batch_size=10): if batch_size != self.page_size: query = self.copy(page_size=batch_size, limit=batch_size) - return query.execute() - return self.execute() + return await query.execute() + return await self.execute() def sort_by(self, *fields: str): if not fields: return self return self.copy(sort_fields=list(fields)) - def update(self, use_transaction=True, **field_values): + async def update(self, use_transaction=True, **field_values): """ Update models that match this query to the given field-value pairs. @@ -743,31 +744,32 @@ class FindQuery: given fields. """ validate_model_fields(self.model, field_values) - pipeline = self.model.db().pipeline() if use_transaction else None + pipeline = await self.model.db().pipeline() if use_transaction else None - for model in self.all(): + # TODO: async for here? + for model in await self.all(): for field, value in field_values.items(): setattr(model, field, value) # TODO: In the non-transaction case, can we do more to detect # failure responses from Redis? - model.save(pipeline=pipeline) + await model.save(pipeline=pipeline) if pipeline: # TODO: Response type? # TODO: Better error detection for transactions. pipeline.execute() - def delete(self): + async def delete(self): """Delete all matching records in this query.""" # TODO: Better response type, error detection - return self.model.db().delete(*[m.key() for m in self.all()]) + return await self.model.db().delete(*[m.key() for m in await self.all()]) - def __iter__(self): + async def __aiter__(self): if self._model_cache: for m in self._model_cache: yield m else: - for m in self.execute(): + for m in await self.execute(): yield m def __getitem__(self, item: int): @@ -784,12 +786,39 @@ class FindQuery: that result, then we should clone the current query and give it a new offset and limit: offset=n, limit=1. """ + if ASYNC_MODE: + raise QuerySyntaxError("Cannot use [] notation with async code. " + "Use FindQuery.get_item() instead.") if self._model_cache and len(self._model_cache) >= item: return self._model_cache[item] query = self.copy(offset=item, limit=1) - return query.execute()[0] + return query.execute()[0] # noqa + + async def get_item(self, item: int): + """ + Given this code: + await Model.find().get_item(1000) + + We should return only the 1000th result. + + 1. If the result is loaded in the query cache for this query, + we can return it directly from the cache. + + 2. If the query cache does not have enough elements to return + that result, then we should clone the current query and + give it a new offset and limit: offset=n, limit=1. + + NOTE: This method is included specifically for async users, who + cannot use the notation Model.find()[1000]. + """ + if self._model_cache and len(self._model_cache) >= item: + return self._model_cache[item] + + query = self.copy(offset=item, limit=1) + result = await query.execute() + return result[0] class PrimaryKeyCreator(Protocol): @@ -913,7 +942,7 @@ class MetaProtocol(Protocol): global_key_prefix: str model_key_prefix: str primary_key_pattern: str - database: redis.Redis + database: aioredis.Redis primary_key: PrimaryKey primary_key_creator_cls: Type[PrimaryKeyCreator] index_name: str @@ -932,7 +961,7 @@ class DefaultMeta: global_key_prefix: Optional[str] = None model_key_prefix: Optional[str] = None primary_key_pattern: Optional[str] = None - database: Optional[redis.Redis] = None + database: Optional[Union[redis.Redis, aioredis.Redis]] = None primary_key: Optional[PrimaryKey] = None primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None index_name: Optional[str] = None @@ -1049,14 +1078,18 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): pk = getattr(self, self._meta.primary_key.field.name) return self.make_primary_key(pk) - def delete(self): - return self.db().delete(self.key()) + async def delete(self): + return await self.db().delete(self.key()) - def update(self, **field_values): + @classmethod + async def get(cls, pk: Any) -> 'RedisModel': + raise NotImplementedError + + async def update(self, **field_values): """Update this model instance with the specified key-value pairs.""" raise NotImplementedError - def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel": + async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel": raise NotImplementedError @validator("pk", always=True) @@ -1158,9 +1191,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): return d @classmethod - def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]: + async def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]: # TODO: Add transaction support - return [model.save() for model in models] + return [await model.save() for model in models] @classmethod def values(cls): @@ -1189,17 +1222,18 @@ class HashModel(RedisModel, abc.ABC): f" or mapping fields. Field: {name}" ) - def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel": + async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel": if pipeline is None: db = self.db() else: db = pipeline document = jsonable_encoder(self.dict()) - db.hset(self.key(), mapping=document) + # TODO: Wrap any Redis response errors in a custom exception? + await db.hset(self.key(), mapping=document) return self @classmethod - def get(cls, pk: Any) -> "HashModel": + async def get(cls, pk: Any) -> "HashModel": document = cls.db().hgetall(cls.make_primary_key(pk)) if not document: raise NotFoundError @@ -1311,23 +1345,24 @@ class JsonModel(RedisModel, abc.ABC): # Generate the RediSearch schema once to validate fields. cls.redisearch_schema() - def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel": + async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel": if pipeline is None: db = self.db() else: db = pipeline - db.execute_command("JSON.SET", self.key(), ".", self.json()) + # TODO: Wrap response errors in a custom exception? + await db.execute_command("JSON.SET", self.key(), ".", self.json()) return self - def update(self, **field_values): + async def update(self, **field_values): validate_model_fields(self.__class__, field_values) for field, value in field_values.items(): setattr(self, field, value) - self.save() + await self.save() @classmethod - def get(cls, pk: Any) -> "JsonModel": - document = cls.db().execute_command("JSON.GET", cls.make_primary_key(pk)) + async def get(cls, pk: Any) -> "JsonModel": + document = await cls.db().execute_command("JSON.GET", cls.make_primary_key(pk)) if not document: raise NotFoundError return cls.parse_raw(document) diff --git a/redis_developer/model/models.py b/redis_om/model/models.py similarity index 75% rename from redis_developer/model/models.py rename to redis_om/model/models.py index c6c1a5c..81655e5 100644 --- a/redis_developer/model/models.py +++ b/redis_om/model/models.py @@ -1,17 +1,17 @@ import abc from typing import Optional -from redis_developer.model.model import HashModel, JsonModel +from redis_om.model.model import HashModel, JsonModel class BaseJsonModel(JsonModel, abc.ABC): class Meta: - global_key_prefix = "redis-developer" + global_key_prefix = "redis-om" class BaseHashModel(HashModel, abc.ABC): class Meta: - global_key_prefix = "redis-developer" + global_key_prefix = "redis-om" # class AddressJson(BaseJsonModel): diff --git a/redis_developer/model/query_resolver.py b/redis_om/model/query_resolver.py similarity index 97% rename from redis_developer/model/query_resolver.py rename to redis_om/model/query_resolver.py index 8616c92..f27fc36 100644 --- a/redis_developer/model/query_resolver.py +++ b/redis_om/model/query_resolver.py @@ -1,7 +1,7 @@ from collections import Sequence from typing import Any, Dict, List, Mapping, Union -from redis_developer.model.model import Expression +from redis_om.model.model import Expression class LogicalOperatorForListOfExpressions(Expression): diff --git a/redis_developer/model/render_tree.py b/redis_om/model/render_tree.py similarity index 100% rename from redis_developer/model/render_tree.py rename to redis_om/model/render_tree.py diff --git a/redis_developer/model/token_escaper.py b/redis_om/model/token_escaper.py similarity index 100% rename from redis_developer/model/token_escaper.py rename to redis_om/model/token_escaper.py diff --git a/redis_om/unasync_util.py b/redis_om/unasync_util.py new file mode 100644 index 0000000..093dcb3 --- /dev/null +++ b/redis_om/unasync_util.py @@ -0,0 +1,40 @@ +"""Set of utility functions for unasync that transform into sync counterparts cleanly""" + +import inspect + +_original_next = next + + +def is_async_mode(): + """Tests if we're in the async part of the code or not""" + + async def f(): + """Unasync transforms async functions in sync functions""" + return None + + obj = f() + if obj is None: + return False + else: + obj.close() # prevent unawaited coroutine warning + return True + + +ASYNC_MODE = is_async_mode() + + +async def anext(x): + return await x.__anext__() + + +async def await_if_coro(x): + if inspect.iscoroutine(x): + return await x + return x + + +next = _original_next + + +def return_non_coro(x): + return x \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 807deb3..ebb25c7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,18 @@ import random import pytest -from redis import Redis -from redis_developer.connections import get_redis_connection +from redis_om.connections import get_redis_connection @pytest.fixture -def redis(): +def redis(event_loop): yield get_redis_connection() -def _delete_test_keys(prefix: str, conn: Redis): +async def _delete_test_keys(prefix: str, conn): keys = [] - for key in conn.scan_iter(f"{prefix}:*"): + async for key in conn.scan_iter(f"{prefix}:*"): keys.append(key) if keys: conn.delete(*keys) @@ -21,11 +20,10 @@ def _delete_test_keys(prefix: str, conn: Redis): @pytest.fixture def key_prefix(redis): - key_prefix = f"redis-developer:{random.random()}" + key_prefix = f"redis-om:{random.random()}" yield key_prefix - _delete_test_keys(key_prefix, redis) @pytest.fixture(autouse=True) -def delete_test_keys(redis, request, key_prefix): - _delete_test_keys(key_prefix, redis) +async def delete_test_keys(redis, request, key_prefix): + await _delete_test_keys(key_prefix, redis) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 3705a84..e67a42e 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -8,9 +8,9 @@ from unittest import mock import pytest from pydantic import ValidationError -from redis_developer.model import Field, HashModel -from redis_developer.model.migrations.migrator import Migrator -from redis_developer.model.model import ( +from redis_om.model import Field, HashModel +from redis_om.model.migrations.migrator import Migrator +from redis_om.model.model import ( NotFoundError, QueryNotSupportedError, RedisModelError, diff --git a/tests/test_json_model.py b/tests/test_json_model.py index ffa6594..b09a4de 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1,4 +1,5 @@ import abc +import asyncio import datetime import decimal from collections import namedtuple @@ -8,9 +9,9 @@ from unittest import mock import pytest from pydantic import ValidationError -from redis_developer.model import EmbeddedJsonModel, Field, JsonModel -from redis_developer.model.migrations.migrator import Migrator -from redis_developer.model.model import ( +from redis_om.model import EmbeddedJsonModel, Field, JsonModel +from redis_om.model.migrations.migrator import Migrator +from redis_om.model.model import ( NotFoundError, QueryNotSupportedError, RedisModelError, @@ -21,7 +22,7 @@ today = datetime.date.today() @pytest.fixture -def m(key_prefix): +async def m(key_prefix, redis): class BaseJsonModel(JsonModel, abc.ABC): class Meta: global_key_prefix = key_prefix @@ -64,7 +65,7 @@ def m(key_prefix): # Creates an embedded list of models. orders: Optional[List[Order]] - Migrator().run() + await Migrator(redis).run() return namedtuple( "Models", ["BaseJsonModel", "Note", "Address", "Item", "Order", "Member"] @@ -83,7 +84,7 @@ def address(m): @pytest.fixture() -def members(address, m): +async def members(address, m): member1 = m.Member( first_name="Andrew", last_name="Brookins", @@ -111,14 +112,15 @@ def members(address, m): address=address, ) - member1.save() - member2.save() - member3.save() + await member1.save() + await member2.save() + await member3.save() yield member1, member2, member3 -def test_validates_required_fields(address, m): +@pytest.mark.asyncio +async def test_validates_required_fields(address, m): # Raises ValidationError address is required with pytest.raises(ValidationError): m.Member( @@ -129,7 +131,8 @@ def test_validates_required_fields(address, m): ) -def test_validates_field(address, m): +@pytest.mark.asyncio +async def test_validates_field(address, m): # Raises ValidationError: join_date is not a date with pytest.raises(ValidationError): m.Member( @@ -141,7 +144,8 @@ def test_validates_field(address, m): # Passes validation -def test_validation_passes(address, m): +@pytest.mark.asyncio +async def test_validation_passes(address, m): member = m.Member( first_name="Andrew", last_name="Brookins", @@ -153,7 +157,10 @@ def test_validation_passes(address, m): assert member.first_name == "Andrew" -def test_saves_model_and_creates_pk(address, m): +@pytest.mark.asyncio +async def test_saves_model_and_creates_pk(address, m, redis): + await Migrator(redis).run() + member = m.Member( first_name="Andrew", last_name="Brookins", @@ -163,15 +170,16 @@ def test_saves_model_and_creates_pk(address, m): address=address, ) # Save a model instance to Redis - member.save() + await member.save() - member2 = m.Member.get(member.pk) + member2 = await m.Member.get(member.pk) assert member2 == member assert member2.address == address @pytest.mark.skip("Not implemented yet") -def test_saves_many(address, m): +@pytest.mark.asyncio +async def test_saves_many(address, m): members = [ m.Member( first_name="Andrew", @@ -193,9 +201,16 @@ def test_saves_many(address, m): m.Member.add(members) +async def save(members): + for m in members: + await m.save() + return members + + @pytest.mark.skip("Not ready yet") -def test_updates_a_model(members, m): - member1, member2, member3 = members +@pytest.mark.asyncio +async def test_updates_a_model(members, m): + member1, member2, member3 = await save(members) # Or, with an implicit save: member1.update(last_name="Smith") @@ -213,18 +228,20 @@ def test_updates_a_model(members, m): ) -def test_paginate_query(members, m): +@pytest.mark.asyncio +async def test_paginate_query(members, m): member1, member2, member3 = members - actual = m.Member.find().sort_by("age").all(batch_size=1) + actual = await m.Member.find().sort_by("age").all(batch_size=1) assert actual == [member2, member1, member3] -def test_access_result_by_index_cached(members, m): +@pytest.mark.asyncio +async def test_access_result_by_index_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") # Load the cache, throw away the result. assert query._model_cache == [] - query.execute() + await query.execute() assert query._model_cache == [member2, member1, member3] # Access an item that should be in the cache. @@ -233,21 +250,23 @@ def test_access_result_by_index_cached(members, m): assert not mock_db.called -def test_access_result_by_index_not_cached(members, m): +@pytest.mark.asyncio +async def test_access_result_by_index_not_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") # Assert that we don't have any models in the cache yet -- we # haven't made any requests of Redis. assert query._model_cache == [] - assert query[0] == member2 - assert query[1] == member1 - assert query[2] == member3 + assert query.get_item(0) == member2 + assert query.get_item(1) == member1 + assert query.get_item(2) == member3 -def test_in_query(members, m): +@pytest.mark.asyncio +async def test_in_query(members, m): member1, member2, member3 = members - actual = ( + actual = await ( m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]) .sort_by("age") .all() @@ -256,12 +275,13 @@ def test_in_query(members, m): @pytest.mark.skip("Not implemented yet") -def test_update_query(members, m): +@pytest.mark.asyncio +async def test_update_query(members, m): member1, member2, member3 = members - m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update( + await m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update( first_name="Bobby" ) - actual = ( + actual = await ( m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]) .sort_by("age") .all() @@ -270,24 +290,25 @@ def test_update_query(members, m): assert all([m.name == "Bobby" for m in actual]) -def test_exact_match_queries(members, m): +@pytest.mark.asyncio +async def test_exact_match_queries(members, m): member1, member2, member3 = members - actual = m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all() + actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all() assert actual == [member2, member1] - actual = m.Member.find( + actual = await m.Member.find( (m.Member.last_name == "Brookins") & ~(m.Member.first_name == "Andrew") ).all() assert actual == [member2] - actual = m.Member.find(~(m.Member.last_name == "Brookins")).all() + actual = await m.Member.find(~(m.Member.last_name == "Brookins")).all() assert actual == [member3] - actual = m.Member.find(m.Member.last_name != "Brookins").all() + actual = await m.Member.find(m.Member.last_name != "Brookins").all() assert actual == [member3] - actual = ( + actual = await ( m.Member.find( (m.Member.last_name == "Brookins") & (m.Member.first_name == "Andrew") | (m.Member.first_name == "Kim") @@ -297,19 +318,20 @@ def test_exact_match_queries(members, m): ) assert actual == [member2, member1] - actual = m.Member.find( + actual = await m.Member.find( m.Member.first_name == "Kim", m.Member.last_name == "Brookins" ).all() assert actual == [member2] - actual = m.Member.find(m.Member.address.city == "Portland").sort_by("age").all() + actual = await m.Member.find(m.Member.address.city == "Portland").sort_by("age").all() assert actual == [member2, member1, member3] -def test_recursive_query_expression_resolution(members, m): +@pytest.mark.asyncio +async def test_recursive_query_expression_resolution(members, m): member1, member2, member3 = members - actual = ( + actual = await ( m.Member.find( (m.Member.last_name == "Brookins") | (m.Member.age == 100) & (m.Member.last_name == "Smith") @@ -320,13 +342,14 @@ def test_recursive_query_expression_resolution(members, m): assert actual == [member2, member1, member3] -def test_recursive_query_field_resolution(members, m): +@pytest.mark.asyncio +async def test_recursive_query_field_resolution(members, m): member1, _, _ = members member1.address.note = m.Note( description="Weird house", created_on=datetime.datetime.now() ) - member1.save() - actual = m.Member.find(m.Member.address.note.description == "Weird house").all() + await member1.save() + actual = await m.Member.find(m.Member.address.note.description == "Weird house").all() assert actual == [member1] member1.orders = [ @@ -336,29 +359,31 @@ def test_recursive_query_field_resolution(members, m): created_on=datetime.datetime.now(), ) ] - member1.save() - actual = m.Member.find(m.Member.orders.items.name == "Ball").all() + await member1.save() + actual = await m.Member.find(m.Member.orders.items.name == "Ball").all() assert actual == [member1] assert actual[0].orders[0].items[0].name == "Ball" -def test_full_text_search(members, m): +@pytest.mark.asyncio +async def test_full_text_search(members, m): member1, member2, _ = members - member1.update(bio="Hates sunsets, likes beaches") - member2.update(bio="Hates beaches, likes forests") + await member1.update(bio="Hates sunsets, likes beaches") + await member2.update(bio="Hates beaches, likes forests") - actual = m.Member.find(m.Member.bio % "beaches").sort_by("age").all() + actual = await m.Member.find(m.Member.bio % "beaches").sort_by("age").all() assert actual == [member2, member1] - actual = m.Member.find(m.Member.bio % "forests").all() + actual = await m.Member.find(m.Member.bio % "forests").all() assert actual == [member2] -def test_tag_queries_boolean_logic(members, m): +@pytest.mark.asyncio +async def test_tag_queries_boolean_logic(members, m): member1, member2, member3 = members actual = ( - m.Member.find( + await m.Member.find( (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith") ) @@ -368,7 +393,8 @@ def test_tag_queries_boolean_logic(members, m): assert actual == [member1, member3] -def test_tag_queries_punctuation(address, m): +@pytest.mark.asyncio +async def test_tag_queries_punctuation(address, m): member1 = m.Member( first_name="Andrew, the Michael", last_name="St. Brookins-on-Pier", @@ -377,7 +403,7 @@ def test_tag_queries_punctuation(address, m): join_date=today, address=address, ) - member1.save() + await member1.save() member2 = m.Member( first_name="Bob", @@ -387,24 +413,25 @@ def test_tag_queries_punctuation(address, m): join_date=today, address=address, ) - member2.save() + await member2.save() assert ( - m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1 + await m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1 ) assert ( - m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1 + await m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1 ) # Notice that when we index and query multiple values that use the internal # TAG separator for single-value exact-match fields, like an indexed string, # the queries will succeed. We apply a workaround that queries for the union # of the two values separated by the tag separator. - assert m.Member.find(m.Member.email == "a|b@example.com").all() == [member1] - assert m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2] + assert await m.Member.find(m.Member.email == "a|b@example.com").all() == [member1] + assert await m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2] -def test_tag_queries_negation(members, m): +@pytest.mark.asyncio +async def test_tag_queries_negation(members, m): member1, member2, member3 = members """ @@ -414,7 +441,7 @@ def test_tag_queries_negation(members, m): """ query = m.Member.find(~(m.Member.first_name == "Andrew")) - assert query.all() == [member2] + assert await query.all() == [member2] """ ┌first_name @@ -429,7 +456,7 @@ def test_tag_queries_negation(members, m): query = m.Member.find( ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") ) - assert query.all() == [member2] + assert await query.all() == [member2] """ ┌first_name @@ -448,7 +475,7 @@ def test_tag_queries_negation(members, m): ~(m.Member.first_name == "Andrew") & ((m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith")) ) - assert query.all() == [member2] + assert await query.all() == [member2] """ ┌first_name @@ -467,67 +494,71 @@ def test_tag_queries_negation(members, m): ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith") ) - assert query.sort_by("age").all() == [member2, member3] + assert await query.sort_by("age").all() == [member2, member3] - actual = m.Member.find( + actual = await m.Member.find( (m.Member.first_name == "Andrew") & ~(m.Member.last_name == "Brookins") ).all() assert actual == [member3] -def test_numeric_queries(members, m): +@pytest.mark.asyncio +async def test_numeric_queries(members, m): member1, member2, member3 = members - actual = m.Member.find(m.Member.age == 34).all() + actual = await m.Member.find(m.Member.age == 34).all() assert actual == [member2] - actual = m.Member.find(m.Member.age > 34).all() + actual = await m.Member.find(m.Member.age > 34).all() assert actual == [member1, member3] - actual = m.Member.find(m.Member.age < 35).all() + actual = await m.Member.find(m.Member.age < 35).all() assert actual == [member2] - actual = m.Member.find(m.Member.age <= 34).all() + actual = await m.Member.find(m.Member.age <= 34).all() assert actual == [member2] - actual = m.Member.find(m.Member.age >= 100).all() + actual = await m.Member.find(m.Member.age >= 100).all() assert actual == [member3] - actual = m.Member.find(~(m.Member.age == 100)).sort_by("age").all() + actual = await m.Member.find(~(m.Member.age == 100)).sort_by("age").all() assert actual == [member2, member1] - actual = m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all() + actual = await m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all() assert actual == [member2, member1] - actual = m.Member.find(m.Member.age != 34).sort_by("age").all() + actual = await m.Member.find(m.Member.age != 34).sort_by("age").all() assert actual == [member1, member3] -def test_sorting(members, m): +@pytest.mark.asyncio +async def test_sorting(members, m): member1, member2, member3 = members - actual = m.Member.find(m.Member.age > 34).sort_by("age").all() + actual = await m.Member.find(m.Member.age > 34).sort_by("age").all() assert actual == [member1, member3] - actual = m.Member.find(m.Member.age > 34).sort_by("-age").all() + actual = await m.Member.find(m.Member.age > 34).sort_by("-age").all() assert actual == [member3, member1] with pytest.raises(QueryNotSupportedError): # This field does not exist. - m.Member.find().sort_by("not-a-real-field").all() + await m.Member.find().sort_by("not-a-real-field").all() with pytest.raises(QueryNotSupportedError): # This field is not sortable. - m.Member.find().sort_by("join_date").all() + await m.Member.find().sort_by("join_date").all() -def test_not_found(m): +@pytest.mark.asyncio +async def test_not_found(m): with pytest.raises(NotFoundError): # This ID does not exist. - m.Member.get(1000) + await m.Member.get(1000) -def test_list_field_limitations(m): +@pytest.mark.asyncio +async def test_list_field_limitations(m, redis): with pytest.raises(RedisModelError): class SortableTarotWitch(m.BaseJsonModel): @@ -571,15 +602,16 @@ def test_list_field_limitations(m): # We need to import and run this manually because we defined # our model classes within a function that runs after the test # suite's migrator has already looked for migrations to run. - Migrator().run() + await Migrator(redis).run() witch = TarotWitch(tarot_cards=["death"]) - witch.save() - actual = TarotWitch.find(TarotWitch.tarot_cards << "death").all() + await witch.save() + actual = await TarotWitch.find(TarotWitch.tarot_cards << "death").all() assert actual == [witch] -def test_schema(m, key_prefix): +@pytest.mark.asyncio +async def test_schema(m, key_prefix): assert ( m.Member.redisearch_schema() == f"ON JSON PREFIX 1 {key_prefix}:tests.test_json_model.Member: SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR | $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |"