WIP on async - test failure due to closed event loop

This commit is contained in:
Andrew Brookins 2021-10-22 06:33:05 -07:00
parent 0f9f7aa868
commit b2c2dd9f6f
22 changed files with 348 additions and 190 deletions

View file

@ -1,4 +1,4 @@
NAME := redis_developer
NAME := redis_om
INSTALL_STAMP := .install.stamp
POETRY := $(shell command -v poetry 2> /dev/null)

View file

@ -52,7 +52,7 @@ Check out this example:
import datetime
from typing import Optional
from redis_developer.model import (
from redis_om.model import (
EmbeddedJsonModel,
JsonModel,
Field,
@ -172,9 +172,9 @@ Don't want to run Redis yourself? RediSearch and RedisJSON are also available on
We'd love your contributions!
**Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new).
**Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-om/redis-om-python/issues/new).
You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new) to get started.
You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-om/redis-om-python/issues/new) to get started.
## License
@ -184,17 +184,17 @@ Redis OM is [MIT licensed][license-url].
[version-svg]: https://img.shields.io/pypi/v/redis-om?style=flat-square
[package-url]: https://pypi.org/project/redis-om/
[ci-svg]: https://img.shields.io/github/workflow/status/redis-developer/redis-developer-python/python?style=flat-square
[ci-url]: https://github.com/redis-developer/redis-developer-python/actions/workflows/build.yml
[ci-svg]: https://img.shields.io/github/workflow/status/redis-om/redis-om-python/python?style=flat-square
[ci-url]: https://github.com/redis-om/redis-om-python/actions/workflows/build.yml
[license-image]: http://img.shields.io/badge/license-MIT-green.svg?style=flat-square
[license-url]: LICENSE
<!-- Links -->
[redis-developer-website]: https://developer.redis.com
[redis-om-js]: https://github.com/redis-developer/redis-om-js
[redis-om-dotnet]: https://github.com/redis-developer/redis-om-dotnet
[redis-om-spring]: https://github.com/redis-developer/redis-om-spring
[redis-om-website]: https://developer.redis.com
[redis-om-js]: https://github.com/redis-om/redis-om-js
[redis-om-dotnet]: https://github.com/redis-om/redis-om-dotnet
[redis-om-spring]: https://github.com/redis-om/redis-om-spring
[redisearch-url]: https://oss.redis.com/redisearch/
[redis-json-url]: https://oss.redis.com/redisjson/
[pydantic-url]: https://github.com/samuelcolvin/pydantic

10
build.py Normal file
View file

@ -0,0 +1,10 @@
import unasync
def build(setup_kwargs):
setup_kwargs.update(
{"cmdclass": {'build_py': unasync.cmdclass_build_py(rules=[
unasync.Rule("/aredis_om/", "/redis_om/"),
unasync.Rule("/aredis_om/tests/", "/redis_om/tests/", additional_replacements={"aredis_om": "redis_om"}),
])}}
)

32
poetry.lock generated
View file

@ -538,6 +538,20 @@ toml = "*"
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
[[package]]
name = "pytest-asyncio"
version = "0.16.0"
description = "Pytest support for asyncio."
category = "dev"
optional = false
python-versions = ">= 3.6"
[package.dependencies]
pytest = ">=5.4.0"
[package.extras]
testing = ["coverage", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
version = "3.0.0"
@ -707,6 +721,14 @@ category = "main"
optional = false
python-versions = "*"
[[package]]
name = "unasync"
version = "0.5.0"
description = "The async transformation code."
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
[[package]]
name = "wcwidth"
version = "0.2.5"
@ -726,7 +748,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "56b381dd9b79bd082e978019124176491c63f09dd5ce90e5f8ab642a7f79480f"
content-hash = "d2d83b8cd3b094879e1aeb058d0036203942143f12fafa8be03fb0c79460028f"
[metadata.files]
aioredis = [
@ -1003,6 +1025,10 @@ pytest = [
{file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
{file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
]
pytest-asyncio = [
{file = "pytest-asyncio-0.16.0.tar.gz", hash = "sha256:7496c5977ce88c34379df64a66459fe395cd05543f0a2f837016e7144391fcfb"},
{file = "pytest_asyncio-0.16.0-py3-none-any.whl", hash = "sha256:5f2a21273c47b331ae6aa5b36087047b4899e40f03f18397c0e65fa5cca54e9b"},
]
pytest-cov = [
{file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
{file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},
@ -1148,6 +1174,10 @@ typing-extensions = [
{file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"},
{file = "typing_extensions-3.10.0.2.tar.gz", hash = "sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e"},
]
unasync = [
{file = "unasync-0.5.0-py3-none-any.whl", hash = "sha256:8d4536dae85e87b8751dfcc776f7656fd0baf54bb022a7889440dc1b9dc3becb"},
{file = "unasync-0.5.0.tar.gz", hash = "sha256:b675d87cf56da68bd065d3b7a67ac71df85591978d84c53083c20d79a7e5096d"},
]
wcwidth = [
{file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
{file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"},

View file

@ -1,9 +1,10 @@
[tool.poetry]
name = "redis-developer"
name = "redis-om"
version = "0.1.0"
description = "A high-level library containing useful Redis abstractions and tools, like an ORM and leaderboard."
authors = ["Andrew Brookins <andrew.brookins@redislabs.com>"]
license = "MIT"
build = "build.py"
[tool.poetry.dependencies]
python = "^3.8"
@ -30,10 +31,11 @@ bandit = "^1.7.0"
coverage = "^6.0.2"
pytest-cov = "^3.0.0"
pytest-xdist = "^2.4.0"
unasync = "^0.5.0"
pytest-asyncio = "^0.16.0"
[tool.poetry.scripts]
migrate = "redis_developer.orm.cli.migrate:migrate"
migrate = "redis_om.orm.cli.migrate:migrate"
[build-system]
requires = ["poetry-core>=1.0.0"]

View file

@ -1,22 +1,28 @@
import os
from typing import Union
import dotenv
import aioredis
import redis
from redis_om.unasync_util import ASYNC_MODE
dotenv.load_dotenv()
URL = os.environ.get("REDIS_OM_URL", None)
if ASYNC_MODE:
client = aioredis.Redis
else:
client = redis.Redis
def get_redis_connection(**kwargs) -> redis.Redis:
def get_redis_connection(**kwargs) -> Union[aioredis.Redis, redis.Redis]:
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
# environment variable, we'll create the Redis client from the URL.
url = kwargs.pop("url", URL)
if url:
return redis.from_url(url, **kwargs)
return client.from_url(url, **kwargs)
# Decode from UTF-8 by default
if "decode_responses" not in kwargs:
kwargs["decode_responses"] = True
return redis.Redis(**kwargs)
return client(**kwargs)

View file

@ -1,10 +1,10 @@
import click
from redis_developer.model.migrations.migrator import Migrator
from redis_om.model.migrations.migrator import Migrator
@click.command()
@click.option("--module", default="redis_developer")
@click.option("--module", default="redis_om")
def migrate(module):
migrator = Migrator(module)

View file

@ -2,15 +2,14 @@ import hashlib
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import Optional, Union
from redis import ResponseError
from redis import ResponseError, Redis
from aioredis import ResponseError as AResponseError, Redis as ARedis
from redis_developer.connections import get_redis_connection
from redis_developer.model.model import model_registry
from redis_om.model.model import model_registry
redis = get_redis_connection()
log = logging.getLogger(__name__)
@ -43,12 +42,12 @@ def schema_hash_key(index_name):
return f"{index_name}:hash"
def create_index(index_name, schema, current_hash):
async def create_index(redis: Union[Redis, ARedis], index_name, schema, current_hash):
try:
redis.execute_command(f"ft.info {index_name}")
except ResponseError:
redis.execute_command(f"ft.create {index_name} {schema}")
redis.set(schema_hash_key(index_name), current_hash)
await redis.execute_command(f"ft.info {index_name}")
except (ResponseError, AResponseError):
await redis.execute_command(f"ft.create {index_name} {schema}")
await redis.set(schema_hash_key(index_name), current_hash)
else:
log.info("Index already exists, skipping. Index hash: %s", index_name)
@ -65,34 +64,38 @@ class IndexMigration:
schema: str
hash: str
action: MigrationAction
redis: Union[Redis, ARedis]
previous_hash: Optional[str] = None
def run(self):
async def run(self):
if self.action is MigrationAction.CREATE:
self.create()
await self.create()
elif self.action is MigrationAction.DROP:
self.drop()
await self.drop()
def create(self):
async def create(self):
try:
return create_index(self.index_name, self.schema, self.hash)
await create_index(self.redis, self.index_name, self.schema, self.hash)
except ResponseError:
log.info("Index already exists: %s", self.index_name)
def drop(self):
async def drop(self):
try:
redis.execute_command(f"FT.DROPINDEX {self.index_name}")
await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
except ResponseError:
log.info("Index does not exist: %s", self.index_name)
class Migrator:
def __init__(self, module=None):
# Try to load any modules found under the given path or module name.
if module:
import_submodules(module)
def __init__(self, redis: Union[Redis, ARedis], module=None):
self.module = module
self.migrations = []
self.redis = redis
async def run(self):
# Try to load any modules found under the given path or module name.
if self.module:
import_submodules(self.module)
for name, cls in model_registry.items():
hash_key = schema_hash_key(cls.Meta.index_name)
@ -104,8 +107,8 @@ class Migrator:
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec
try:
redis.execute_command("ft.info", cls.Meta.index_name)
except ResponseError:
await self.redis.execute_command("ft.info", cls.Meta.index_name)
except (ResponseError, AResponseError):
self.migrations.append(
IndexMigration(
name,
@ -113,11 +116,12 @@ class Migrator:
schema,
current_hash,
MigrationAction.CREATE,
self.redis
)
)
continue
stored_hash = redis.get(hash_key)
stored_hash = self.redis.get(hash_key)
schema_out_of_date = current_hash != stored_hash
if schema_out_of_date:
@ -129,7 +133,8 @@ class Migrator:
schema,
current_hash,
MigrationAction.DROP,
stored_hash,
self.redis,
stored_hash
)
)
self.migrations.append(
@ -139,12 +144,12 @@ class Migrator:
schema,
current_hash,
MigrationAction.CREATE,
stored_hash,
self.redis,
stored_hash
)
)
def run(self):
# TODO: Migration history
# TODO: Dry run with output
for migration in self.migrations:
migration.run()
await migration.run()

View file

@ -4,7 +4,7 @@ import decimal
import json
import logging
import operator
from copy import copy, deepcopy
from copy import copy
from enum import Enum
from functools import reduce
from typing import (
@ -27,6 +27,7 @@ from typing import (
no_type_check,
)
import aioredis
import redis
from pydantic import BaseModel, validator
from pydantic.fields import FieldInfo as PydanticFieldInfo
@ -37,11 +38,11 @@ from pydantic.utils import Representation
from redis.client import Pipeline
from ulid import ULID
from ..connections import get_redis_connection
from redis_om.connections import get_redis_connection
from .encoders import jsonable_encoder
from .render_tree import render_tree
from .token_escaper import TokenEscaper
from ..unasync_util import ASYNC_MODE
model_registry = {}
_T = TypeVar("_T")
@ -521,7 +522,7 @@ class FindQuery:
# this is not going to work.
log.warning(
"Your query against the field %s is for a single character, %s, "
"that is used internally by redis-developer-python. We must ignore "
"that is used internally by redis-om-python. We must ignore "
"this portion of the query. Please review your query to find "
"an alternative query that uses a string containing more than "
"just the character %s.",
@ -680,7 +681,7 @@ class FindQuery:
return result
def execute(self, exhaust_results=True):
async def execute(self, exhaust_results=True):
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
if self.sort_fields:
args += self.resolve_redisearch_sort_fields()
@ -691,7 +692,7 @@ class FindQuery:
# If the offset is greater than 0, we're paginating through a result set,
# so append the new results to results already in the cache.
raw_result = self.model.db().execute_command(*args)
raw_result = await self.model.db().execute_command(*args)
count = raw_result[0]
results = self.model.from_redis(raw_result)
self._model_cache += results
@ -710,31 +711,31 @@ class FindQuery:
# Make a query for each pass of the loop, with a new offset equal to the
# current offset plus `page_size`, until we stop getting results back.
query = query.copy(offset=query.offset + query.page_size)
_results = query.execute(exhaust_results=False)
_results = await query.execute(exhaust_results=False)
if not _results:
break
self._model_cache += _results
return self._model_cache
def first(self):
async def first(self):
query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
results = query.execute()
results = await query.execute()
if not results:
raise NotFoundError()
return results[0]
def all(self, batch_size=10):
async def all(self, batch_size=10):
if batch_size != self.page_size:
query = self.copy(page_size=batch_size, limit=batch_size)
return query.execute()
return self.execute()
return await query.execute()
return await self.execute()
def sort_by(self, *fields: str):
if not fields:
return self
return self.copy(sort_fields=list(fields))
def update(self, use_transaction=True, **field_values):
async def update(self, use_transaction=True, **field_values):
"""
Update models that match this query to the given field-value pairs.
@ -743,31 +744,32 @@ class FindQuery:
given fields.
"""
validate_model_fields(self.model, field_values)
pipeline = self.model.db().pipeline() if use_transaction else None
pipeline = await self.model.db().pipeline() if use_transaction else None
for model in self.all():
# TODO: async for here?
for model in await self.all():
for field, value in field_values.items():
setattr(model, field, value)
# TODO: In the non-transaction case, can we do more to detect
# failure responses from Redis?
model.save(pipeline=pipeline)
await model.save(pipeline=pipeline)
if pipeline:
# TODO: Response type?
# TODO: Better error detection for transactions.
pipeline.execute()
def delete(self):
async def delete(self):
"""Delete all matching records in this query."""
# TODO: Better response type, error detection
return self.model.db().delete(*[m.key() for m in self.all()])
return await self.model.db().delete(*[m.key() for m in await self.all()])
def __iter__(self):
async def __aiter__(self):
if self._model_cache:
for m in self._model_cache:
yield m
else:
for m in self.execute():
for m in await self.execute():
yield m
def __getitem__(self, item: int):
@ -784,12 +786,39 @@ class FindQuery:
that result, then we should clone the current query and
give it a new offset and limit: offset=n, limit=1.
"""
if ASYNC_MODE:
raise QuerySyntaxError("Cannot use [] notation with async code. "
"Use FindQuery.get_item() instead.")
if self._model_cache and len(self._model_cache) >= item:
return self._model_cache[item]
query = self.copy(offset=item, limit=1)
return query.execute()[0]
return query.execute()[0] # noqa
async def get_item(self, item: int):
"""
Given this code:
await Model.find().get_item(1000)
We should return only the 1000th result.
1. If the result is loaded in the query cache for this query,
we can return it directly from the cache.
2. If the query cache does not have enough elements to return
that result, then we should clone the current query and
give it a new offset and limit: offset=n, limit=1.
NOTE: This method is included specifically for async users, who
cannot use the notation Model.find()[1000].
"""
if self._model_cache and len(self._model_cache) >= item:
return self._model_cache[item]
query = self.copy(offset=item, limit=1)
result = await query.execute()
return result[0]
class PrimaryKeyCreator(Protocol):
@ -913,7 +942,7 @@ class MetaProtocol(Protocol):
global_key_prefix: str
model_key_prefix: str
primary_key_pattern: str
database: redis.Redis
database: aioredis.Redis
primary_key: PrimaryKey
primary_key_creator_cls: Type[PrimaryKeyCreator]
index_name: str
@ -932,7 +961,7 @@ class DefaultMeta:
global_key_prefix: Optional[str] = None
model_key_prefix: Optional[str] = None
primary_key_pattern: Optional[str] = None
database: Optional[redis.Redis] = None
database: Optional[Union[redis.Redis, aioredis.Redis]] = None
primary_key: Optional[PrimaryKey] = None
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
index_name: Optional[str] = None
@ -1049,14 +1078,18 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
pk = getattr(self, self._meta.primary_key.field.name)
return self.make_primary_key(pk)
def delete(self):
return self.db().delete(self.key())
async def delete(self):
return await self.db().delete(self.key())
def update(self, **field_values):
@classmethod
async def get(cls, pk: Any) -> 'RedisModel':
raise NotImplementedError
async def update(self, **field_values):
"""Update this model instance with the specified key-value pairs."""
raise NotImplementedError
def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
raise NotImplementedError
@validator("pk", always=True)
@ -1158,9 +1191,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
return d
@classmethod
def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]:
async def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]:
# TODO: Add transaction support
return [model.save() for model in models]
return [await model.save() for model in models]
@classmethod
def values(cls):
@ -1189,17 +1222,18 @@ class HashModel(RedisModel, abc.ABC):
f" or mapping fields. Field: {name}"
)
def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
if pipeline is None:
db = self.db()
else:
db = pipeline
document = jsonable_encoder(self.dict())
db.hset(self.key(), mapping=document)
# TODO: Wrap any Redis response errors in a custom exception?
await db.hset(self.key(), mapping=document)
return self
@classmethod
def get(cls, pk: Any) -> "HashModel":
async def get(cls, pk: Any) -> "HashModel":
document = cls.db().hgetall(cls.make_primary_key(pk))
if not document:
raise NotFoundError
@ -1311,23 +1345,24 @@ class JsonModel(RedisModel, abc.ABC):
# Generate the RediSearch schema once to validate fields.
cls.redisearch_schema()
def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
if pipeline is None:
db = self.db()
else:
db = pipeline
db.execute_command("JSON.SET", self.key(), ".", self.json())
# TODO: Wrap response errors in a custom exception?
await db.execute_command("JSON.SET", self.key(), ".", self.json())
return self
def update(self, **field_values):
async def update(self, **field_values):
validate_model_fields(self.__class__, field_values)
for field, value in field_values.items():
setattr(self, field, value)
self.save()
await self.save()
@classmethod
def get(cls, pk: Any) -> "JsonModel":
document = cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
async def get(cls, pk: Any) -> "JsonModel":
document = await cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
if not document:
raise NotFoundError
return cls.parse_raw(document)

View file

@ -1,17 +1,17 @@
import abc
from typing import Optional
from redis_developer.model.model import HashModel, JsonModel
from redis_om.model.model import HashModel, JsonModel
class BaseJsonModel(JsonModel, abc.ABC):
class Meta:
global_key_prefix = "redis-developer"
global_key_prefix = "redis-om"
class BaseHashModel(HashModel, abc.ABC):
class Meta:
global_key_prefix = "redis-developer"
global_key_prefix = "redis-om"
# class AddressJson(BaseJsonModel):

View file

@ -1,7 +1,7 @@
from collections import Sequence
from typing import Any, Dict, List, Mapping, Union
from redis_developer.model.model import Expression
from redis_om.model.model import Expression
class LogicalOperatorForListOfExpressions(Expression):

40
redis_om/unasync_util.py Normal file
View file

@ -0,0 +1,40 @@
"""Set of utility functions for unasync that transform into sync counterparts cleanly"""
import inspect
_original_next = next
def is_async_mode():
"""Tests if we're in the async part of the code or not"""
async def f():
"""Unasync transforms async functions in sync functions"""
return None
obj = f()
if obj is None:
return False
else:
obj.close() # prevent unawaited coroutine warning
return True
ASYNC_MODE = is_async_mode()
async def anext(x):
return await x.__anext__()
async def await_if_coro(x):
if inspect.iscoroutine(x):
return await x
return x
next = _original_next
def return_non_coro(x):
return x

View file

@ -1,19 +1,18 @@
import random
import pytest
from redis import Redis
from redis_developer.connections import get_redis_connection
from redis_om.connections import get_redis_connection
@pytest.fixture
def redis():
def redis(event_loop):
yield get_redis_connection()
def _delete_test_keys(prefix: str, conn: Redis):
async def _delete_test_keys(prefix: str, conn):
keys = []
for key in conn.scan_iter(f"{prefix}:*"):
async for key in conn.scan_iter(f"{prefix}:*"):
keys.append(key)
if keys:
conn.delete(*keys)
@ -21,11 +20,10 @@ def _delete_test_keys(prefix: str, conn: Redis):
@pytest.fixture
def key_prefix(redis):
key_prefix = f"redis-developer:{random.random()}"
key_prefix = f"redis-om:{random.random()}"
yield key_prefix
_delete_test_keys(key_prefix, redis)
@pytest.fixture(autouse=True)
def delete_test_keys(redis, request, key_prefix):
_delete_test_keys(key_prefix, redis)
async def delete_test_keys(redis, request, key_prefix):
await _delete_test_keys(key_prefix, redis)

View file

@ -8,9 +8,9 @@ from unittest import mock
import pytest
from pydantic import ValidationError
from redis_developer.model import Field, HashModel
from redis_developer.model.migrations.migrator import Migrator
from redis_developer.model.model import (
from redis_om.model import Field, HashModel
from redis_om.model.migrations.migrator import Migrator
from redis_om.model.model import (
NotFoundError,
QueryNotSupportedError,
RedisModelError,

View file

@ -1,4 +1,5 @@
import abc
import asyncio
import datetime
import decimal
from collections import namedtuple
@ -8,9 +9,9 @@ from unittest import mock
import pytest
from pydantic import ValidationError
from redis_developer.model import EmbeddedJsonModel, Field, JsonModel
from redis_developer.model.migrations.migrator import Migrator
from redis_developer.model.model import (
from redis_om.model import EmbeddedJsonModel, Field, JsonModel
from redis_om.model.migrations.migrator import Migrator
from redis_om.model.model import (
NotFoundError,
QueryNotSupportedError,
RedisModelError,
@ -21,7 +22,7 @@ today = datetime.date.today()
@pytest.fixture
def m(key_prefix):
async def m(key_prefix, redis):
class BaseJsonModel(JsonModel, abc.ABC):
class Meta:
global_key_prefix = key_prefix
@ -64,7 +65,7 @@ def m(key_prefix):
# Creates an embedded list of models.
orders: Optional[List[Order]]
Migrator().run()
await Migrator(redis).run()
return namedtuple(
"Models", ["BaseJsonModel", "Note", "Address", "Item", "Order", "Member"]
@ -83,7 +84,7 @@ def address(m):
@pytest.fixture()
def members(address, m):
async def members(address, m):
member1 = m.Member(
first_name="Andrew",
last_name="Brookins",
@ -111,14 +112,15 @@ def members(address, m):
address=address,
)
member1.save()
member2.save()
member3.save()
await member1.save()
await member2.save()
await member3.save()
yield member1, member2, member3
def test_validates_required_fields(address, m):
@pytest.mark.asyncio
async def test_validates_required_fields(address, m):
# Raises ValidationError address is required
with pytest.raises(ValidationError):
m.Member(
@ -129,7 +131,8 @@ def test_validates_required_fields(address, m):
)
def test_validates_field(address, m):
@pytest.mark.asyncio
async def test_validates_field(address, m):
# Raises ValidationError: join_date is not a date
with pytest.raises(ValidationError):
m.Member(
@ -141,7 +144,8 @@ def test_validates_field(address, m):
# Passes validation
def test_validation_passes(address, m):
@pytest.mark.asyncio
async def test_validation_passes(address, m):
member = m.Member(
first_name="Andrew",
last_name="Brookins",
@ -153,7 +157,10 @@ def test_validation_passes(address, m):
assert member.first_name == "Andrew"
def test_saves_model_and_creates_pk(address, m):
@pytest.mark.asyncio
async def test_saves_model_and_creates_pk(address, m, redis):
await Migrator(redis).run()
member = m.Member(
first_name="Andrew",
last_name="Brookins",
@ -163,15 +170,16 @@ def test_saves_model_and_creates_pk(address, m):
address=address,
)
# Save a model instance to Redis
member.save()
await member.save()
member2 = m.Member.get(member.pk)
member2 = await m.Member.get(member.pk)
assert member2 == member
assert member2.address == address
@pytest.mark.skip("Not implemented yet")
def test_saves_many(address, m):
@pytest.mark.asyncio
async def test_saves_many(address, m):
members = [
m.Member(
first_name="Andrew",
@ -193,9 +201,16 @@ def test_saves_many(address, m):
m.Member.add(members)
async def save(members):
for m in members:
await m.save()
return members
@pytest.mark.skip("Not ready yet")
def test_updates_a_model(members, m):
member1, member2, member3 = members
@pytest.mark.asyncio
async def test_updates_a_model(members, m):
member1, member2, member3 = await save(members)
# Or, with an implicit save:
member1.update(last_name="Smith")
@ -213,18 +228,20 @@ def test_updates_a_model(members, m):
)
def test_paginate_query(members, m):
@pytest.mark.asyncio
async def test_paginate_query(members, m):
member1, member2, member3 = members
actual = m.Member.find().sort_by("age").all(batch_size=1)
actual = await m.Member.find().sort_by("age").all(batch_size=1)
assert actual == [member2, member1, member3]
def test_access_result_by_index_cached(members, m):
@pytest.mark.asyncio
async def test_access_result_by_index_cached(members, m):
member1, member2, member3 = members
query = m.Member.find().sort_by("age")
# Load the cache, throw away the result.
assert query._model_cache == []
query.execute()
await query.execute()
assert query._model_cache == [member2, member1, member3]
# Access an item that should be in the cache.
@ -233,21 +250,23 @@ def test_access_result_by_index_cached(members, m):
assert not mock_db.called
def test_access_result_by_index_not_cached(members, m):
@pytest.mark.asyncio
async def test_access_result_by_index_not_cached(members, m):
member1, member2, member3 = members
query = m.Member.find().sort_by("age")
# Assert that we don't have any models in the cache yet -- we
# haven't made any requests of Redis.
assert query._model_cache == []
assert query[0] == member2
assert query[1] == member1
assert query[2] == member3
assert query.get_item(0) == member2
assert query.get_item(1) == member1
assert query.get_item(2) == member3
def test_in_query(members, m):
@pytest.mark.asyncio
async def test_in_query(members, m):
member1, member2, member3 = members
actual = (
actual = await (
m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
.sort_by("age")
.all()
@ -256,12 +275,13 @@ def test_in_query(members, m):
@pytest.mark.skip("Not implemented yet")
def test_update_query(members, m):
@pytest.mark.asyncio
async def test_update_query(members, m):
member1, member2, member3 = members
m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update(
await m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update(
first_name="Bobby"
)
actual = (
actual = await (
m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
.sort_by("age")
.all()
@ -270,24 +290,25 @@ def test_update_query(members, m):
assert all([m.name == "Bobby" for m in actual])
def test_exact_match_queries(members, m):
@pytest.mark.asyncio
async def test_exact_match_queries(members, m):
member1, member2, member3 = members
actual = m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
assert actual == [member2, member1]
actual = m.Member.find(
actual = await m.Member.find(
(m.Member.last_name == "Brookins") & ~(m.Member.first_name == "Andrew")
).all()
assert actual == [member2]
actual = m.Member.find(~(m.Member.last_name == "Brookins")).all()
actual = await m.Member.find(~(m.Member.last_name == "Brookins")).all()
assert actual == [member3]
actual = m.Member.find(m.Member.last_name != "Brookins").all()
actual = await m.Member.find(m.Member.last_name != "Brookins").all()
assert actual == [member3]
actual = (
actual = await (
m.Member.find(
(m.Member.last_name == "Brookins") & (m.Member.first_name == "Andrew")
| (m.Member.first_name == "Kim")
@ -297,19 +318,20 @@ def test_exact_match_queries(members, m):
)
assert actual == [member2, member1]
actual = m.Member.find(
actual = await m.Member.find(
m.Member.first_name == "Kim", m.Member.last_name == "Brookins"
).all()
assert actual == [member2]
actual = m.Member.find(m.Member.address.city == "Portland").sort_by("age").all()
actual = await m.Member.find(m.Member.address.city == "Portland").sort_by("age").all()
assert actual == [member2, member1, member3]
def test_recursive_query_expression_resolution(members, m):
@pytest.mark.asyncio
async def test_recursive_query_expression_resolution(members, m):
member1, member2, member3 = members
actual = (
actual = await (
m.Member.find(
(m.Member.last_name == "Brookins")
| (m.Member.age == 100) & (m.Member.last_name == "Smith")
@ -320,13 +342,14 @@ def test_recursive_query_expression_resolution(members, m):
assert actual == [member2, member1, member3]
def test_recursive_query_field_resolution(members, m):
@pytest.mark.asyncio
async def test_recursive_query_field_resolution(members, m):
member1, _, _ = members
member1.address.note = m.Note(
description="Weird house", created_on=datetime.datetime.now()
)
member1.save()
actual = m.Member.find(m.Member.address.note.description == "Weird house").all()
await member1.save()
actual = await m.Member.find(m.Member.address.note.description == "Weird house").all()
assert actual == [member1]
member1.orders = [
@ -336,29 +359,31 @@ def test_recursive_query_field_resolution(members, m):
created_on=datetime.datetime.now(),
)
]
member1.save()
actual = m.Member.find(m.Member.orders.items.name == "Ball").all()
await member1.save()
actual = await m.Member.find(m.Member.orders.items.name == "Ball").all()
assert actual == [member1]
assert actual[0].orders[0].items[0].name == "Ball"
def test_full_text_search(members, m):
@pytest.mark.asyncio
async def test_full_text_search(members, m):
member1, member2, _ = members
member1.update(bio="Hates sunsets, likes beaches")
member2.update(bio="Hates beaches, likes forests")
await member1.update(bio="Hates sunsets, likes beaches")
await member2.update(bio="Hates beaches, likes forests")
actual = m.Member.find(m.Member.bio % "beaches").sort_by("age").all()
actual = await m.Member.find(m.Member.bio % "beaches").sort_by("age").all()
assert actual == [member2, member1]
actual = m.Member.find(m.Member.bio % "forests").all()
actual = await m.Member.find(m.Member.bio % "forests").all()
assert actual == [member2]
def test_tag_queries_boolean_logic(members, m):
@pytest.mark.asyncio
async def test_tag_queries_boolean_logic(members, m):
member1, member2, member3 = members
actual = (
m.Member.find(
await m.Member.find(
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith")
)
@ -368,7 +393,8 @@ def test_tag_queries_boolean_logic(members, m):
assert actual == [member1, member3]
def test_tag_queries_punctuation(address, m):
@pytest.mark.asyncio
async def test_tag_queries_punctuation(address, m):
member1 = m.Member(
first_name="Andrew, the Michael",
last_name="St. Brookins-on-Pier",
@ -377,7 +403,7 @@ def test_tag_queries_punctuation(address, m):
join_date=today,
address=address,
)
member1.save()
await member1.save()
member2 = m.Member(
first_name="Bob",
@ -387,24 +413,25 @@ def test_tag_queries_punctuation(address, m):
join_date=today,
address=address,
)
member2.save()
await member2.save()
assert (
m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1
await m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1
)
assert (
m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1
await m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1
)
# Notice that when we index and query multiple values that use the internal
# TAG separator for single-value exact-match fields, like an indexed string,
# the queries will succeed. We apply a workaround that queries for the union
# of the two values separated by the tag separator.
assert m.Member.find(m.Member.email == "a|b@example.com").all() == [member1]
assert m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2]
assert await m.Member.find(m.Member.email == "a|b@example.com").all() == [member1]
assert await m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2]
def test_tag_queries_negation(members, m):
@pytest.mark.asyncio
async def test_tag_queries_negation(members, m):
member1, member2, member3 = members
"""
@ -414,7 +441,7 @@ def test_tag_queries_negation(members, m):
"""
query = m.Member.find(~(m.Member.first_name == "Andrew"))
assert query.all() == [member2]
assert await query.all() == [member2]
"""
first_name
@ -429,7 +456,7 @@ def test_tag_queries_negation(members, m):
query = m.Member.find(
~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
)
assert query.all() == [member2]
assert await query.all() == [member2]
"""
first_name
@ -448,7 +475,7 @@ def test_tag_queries_negation(members, m):
~(m.Member.first_name == "Andrew")
& ((m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith"))
)
assert query.all() == [member2]
assert await query.all() == [member2]
"""
first_name
@ -467,67 +494,71 @@ def test_tag_queries_negation(members, m):
~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith")
)
assert query.sort_by("age").all() == [member2, member3]
assert await query.sort_by("age").all() == [member2, member3]
actual = m.Member.find(
actual = await m.Member.find(
(m.Member.first_name == "Andrew") & ~(m.Member.last_name == "Brookins")
).all()
assert actual == [member3]
def test_numeric_queries(members, m):
@pytest.mark.asyncio
async def test_numeric_queries(members, m):
member1, member2, member3 = members
actual = m.Member.find(m.Member.age == 34).all()
actual = await m.Member.find(m.Member.age == 34).all()
assert actual == [member2]
actual = m.Member.find(m.Member.age > 34).all()
actual = await m.Member.find(m.Member.age > 34).all()
assert actual == [member1, member3]
actual = m.Member.find(m.Member.age < 35).all()
actual = await m.Member.find(m.Member.age < 35).all()
assert actual == [member2]
actual = m.Member.find(m.Member.age <= 34).all()
actual = await m.Member.find(m.Member.age <= 34).all()
assert actual == [member2]
actual = m.Member.find(m.Member.age >= 100).all()
actual = await m.Member.find(m.Member.age >= 100).all()
assert actual == [member3]
actual = m.Member.find(~(m.Member.age == 100)).sort_by("age").all()
actual = await m.Member.find(~(m.Member.age == 100)).sort_by("age").all()
assert actual == [member2, member1]
actual = m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all()
actual = await m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all()
assert actual == [member2, member1]
actual = m.Member.find(m.Member.age != 34).sort_by("age").all()
actual = await m.Member.find(m.Member.age != 34).sort_by("age").all()
assert actual == [member1, member3]
def test_sorting(members, m):
@pytest.mark.asyncio
async def test_sorting(members, m):
member1, member2, member3 = members
actual = m.Member.find(m.Member.age > 34).sort_by("age").all()
actual = await m.Member.find(m.Member.age > 34).sort_by("age").all()
assert actual == [member1, member3]
actual = m.Member.find(m.Member.age > 34).sort_by("-age").all()
actual = await m.Member.find(m.Member.age > 34).sort_by("-age").all()
assert actual == [member3, member1]
with pytest.raises(QueryNotSupportedError):
# This field does not exist.
m.Member.find().sort_by("not-a-real-field").all()
await m.Member.find().sort_by("not-a-real-field").all()
with pytest.raises(QueryNotSupportedError):
# This field is not sortable.
m.Member.find().sort_by("join_date").all()
await m.Member.find().sort_by("join_date").all()
def test_not_found(m):
@pytest.mark.asyncio
async def test_not_found(m):
with pytest.raises(NotFoundError):
# This ID does not exist.
m.Member.get(1000)
await m.Member.get(1000)
def test_list_field_limitations(m):
@pytest.mark.asyncio
async def test_list_field_limitations(m, redis):
with pytest.raises(RedisModelError):
class SortableTarotWitch(m.BaseJsonModel):
@ -571,15 +602,16 @@ def test_list_field_limitations(m):
# We need to import and run this manually because we defined
# our model classes within a function that runs after the test
# suite's migrator has already looked for migrations to run.
Migrator().run()
await Migrator(redis).run()
witch = TarotWitch(tarot_cards=["death"])
witch.save()
actual = TarotWitch.find(TarotWitch.tarot_cards << "death").all()
await witch.save()
actual = await TarotWitch.find(TarotWitch.tarot_cards << "death").all()
assert actual == [witch]
def test_schema(m, key_prefix):
@pytest.mark.asyncio
async def test_schema(m, key_prefix):
assert (
m.Member.redisearch_schema()
== f"ON JSON PREFIX 1 {key_prefix}:tests.test_json_model.Member: SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR | $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |"