WIP on async - test failure due to closed event loop

This commit is contained in:
Andrew Brookins 2021-10-22 06:33:05 -07:00
parent 0f9f7aa868
commit b2c2dd9f6f
22 changed files with 348 additions and 190 deletions

View file

@ -1,4 +1,4 @@
NAME := redis_developer NAME := redis_om
INSTALL_STAMP := .install.stamp INSTALL_STAMP := .install.stamp
POETRY := $(shell command -v poetry 2> /dev/null) POETRY := $(shell command -v poetry 2> /dev/null)

View file

@ -52,7 +52,7 @@ Check out this example:
import datetime import datetime
from typing import Optional from typing import Optional
from redis_developer.model import ( from redis_om.model import (
EmbeddedJsonModel, EmbeddedJsonModel,
JsonModel, JsonModel,
Field, Field,
@ -172,9 +172,9 @@ Don't want to run Redis yourself? RediSearch and RedisJSON are also available on
We'd love your contributions! We'd love your contributions!
**Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new). **Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-om/redis-om-python/issues/new).
You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new) to get started. You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-om/redis-om-python/issues/new) to get started.
## License ## License
@ -184,17 +184,17 @@ Redis OM is [MIT licensed][license-url].
[version-svg]: https://img.shields.io/pypi/v/redis-om?style=flat-square [version-svg]: https://img.shields.io/pypi/v/redis-om?style=flat-square
[package-url]: https://pypi.org/project/redis-om/ [package-url]: https://pypi.org/project/redis-om/
[ci-svg]: https://img.shields.io/github/workflow/status/redis-developer/redis-developer-python/python?style=flat-square [ci-svg]: https://img.shields.io/github/workflow/status/redis-om/redis-om-python/python?style=flat-square
[ci-url]: https://github.com/redis-developer/redis-developer-python/actions/workflows/build.yml [ci-url]: https://github.com/redis-om/redis-om-python/actions/workflows/build.yml
[license-image]: http://img.shields.io/badge/license-MIT-green.svg?style=flat-square [license-image]: http://img.shields.io/badge/license-MIT-green.svg?style=flat-square
[license-url]: LICENSE [license-url]: LICENSE
<!-- Links --> <!-- Links -->
[redis-developer-website]: https://developer.redis.com [redis-om-website]: https://developer.redis.com
[redis-om-js]: https://github.com/redis-developer/redis-om-js [redis-om-js]: https://github.com/redis-om/redis-om-js
[redis-om-dotnet]: https://github.com/redis-developer/redis-om-dotnet [redis-om-dotnet]: https://github.com/redis-om/redis-om-dotnet
[redis-om-spring]: https://github.com/redis-developer/redis-om-spring [redis-om-spring]: https://github.com/redis-om/redis-om-spring
[redisearch-url]: https://oss.redis.com/redisearch/ [redisearch-url]: https://oss.redis.com/redisearch/
[redis-json-url]: https://oss.redis.com/redisjson/ [redis-json-url]: https://oss.redis.com/redisjson/
[pydantic-url]: https://github.com/samuelcolvin/pydantic [pydantic-url]: https://github.com/samuelcolvin/pydantic

10
build.py Normal file
View file

@ -0,0 +1,10 @@
import unasync
def build(setup_kwargs):
setup_kwargs.update(
{"cmdclass": {'build_py': unasync.cmdclass_build_py(rules=[
unasync.Rule("/aredis_om/", "/redis_om/"),
unasync.Rule("/aredis_om/tests/", "/redis_om/tests/", additional_replacements={"aredis_om": "redis_om"}),
])}}
)

32
poetry.lock generated
View file

@ -538,6 +538,20 @@ toml = "*"
[package.extras] [package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
[[package]]
name = "pytest-asyncio"
version = "0.16.0"
description = "Pytest support for asyncio."
category = "dev"
optional = false
python-versions = ">= 3.6"
[package.dependencies]
pytest = ">=5.4.0"
[package.extras]
testing = ["coverage", "hypothesis (>=5.7.1)"]
[[package]] [[package]]
name = "pytest-cov" name = "pytest-cov"
version = "3.0.0" version = "3.0.0"
@ -707,6 +721,14 @@ category = "main"
optional = false optional = false
python-versions = "*" python-versions = "*"
[[package]]
name = "unasync"
version = "0.5.0"
description = "The async transformation code."
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
[[package]] [[package]]
name = "wcwidth" name = "wcwidth"
version = "0.2.5" version = "0.2.5"
@ -726,7 +748,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "56b381dd9b79bd082e978019124176491c63f09dd5ce90e5f8ab642a7f79480f" content-hash = "d2d83b8cd3b094879e1aeb058d0036203942143f12fafa8be03fb0c79460028f"
[metadata.files] [metadata.files]
aioredis = [ aioredis = [
@ -1003,6 +1025,10 @@ pytest = [
{file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
{file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
] ]
pytest-asyncio = [
{file = "pytest-asyncio-0.16.0.tar.gz", hash = "sha256:7496c5977ce88c34379df64a66459fe395cd05543f0a2f837016e7144391fcfb"},
{file = "pytest_asyncio-0.16.0-py3-none-any.whl", hash = "sha256:5f2a21273c47b331ae6aa5b36087047b4899e40f03f18397c0e65fa5cca54e9b"},
]
pytest-cov = [ pytest-cov = [
{file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
{file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},
@ -1148,6 +1174,10 @@ typing-extensions = [
{file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"}, {file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"},
{file = "typing_extensions-3.10.0.2.tar.gz", hash = "sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e"}, {file = "typing_extensions-3.10.0.2.tar.gz", hash = "sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e"},
] ]
unasync = [
{file = "unasync-0.5.0-py3-none-any.whl", hash = "sha256:8d4536dae85e87b8751dfcc776f7656fd0baf54bb022a7889440dc1b9dc3becb"},
{file = "unasync-0.5.0.tar.gz", hash = "sha256:b675d87cf56da68bd065d3b7a67ac71df85591978d84c53083c20d79a7e5096d"},
]
wcwidth = [ wcwidth = [
{file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
{file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"},

View file

@ -1,9 +1,10 @@
[tool.poetry] [tool.poetry]
name = "redis-developer" name = "redis-om"
version = "0.1.0" version = "0.1.0"
description = "A high-level library containing useful Redis abstractions and tools, like an ORM and leaderboard." description = "A high-level library containing useful Redis abstractions and tools, like an ORM and leaderboard."
authors = ["Andrew Brookins <andrew.brookins@redislabs.com>"] authors = ["Andrew Brookins <andrew.brookins@redislabs.com>"]
license = "MIT" license = "MIT"
build = "build.py"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.8" python = "^3.8"
@ -30,10 +31,11 @@ bandit = "^1.7.0"
coverage = "^6.0.2" coverage = "^6.0.2"
pytest-cov = "^3.0.0" pytest-cov = "^3.0.0"
pytest-xdist = "^2.4.0" pytest-xdist = "^2.4.0"
unasync = "^0.5.0"
pytest-asyncio = "^0.16.0"
[tool.poetry.scripts] [tool.poetry.scripts]
migrate = "redis_developer.orm.cli.migrate:migrate" migrate = "redis_om.orm.cli.migrate:migrate"
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]

View file

@ -1,22 +1,28 @@
import os import os
from typing import Union
import dotenv import dotenv
import aioredis
import redis import redis
from redis_om.unasync_util import ASYNC_MODE
dotenv.load_dotenv() dotenv.load_dotenv()
URL = os.environ.get("REDIS_OM_URL", None) URL = os.environ.get("REDIS_OM_URL", None)
if ASYNC_MODE:
client = aioredis.Redis
else:
client = redis.Redis
def get_redis_connection(**kwargs) -> redis.Redis: def get_redis_connection(**kwargs) -> Union[aioredis.Redis, redis.Redis]:
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL # If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
# environment variable, we'll create the Redis client from the URL. # environment variable, we'll create the Redis client from the URL.
url = kwargs.pop("url", URL) url = kwargs.pop("url", URL)
if url: if url:
return redis.from_url(url, **kwargs) return client.from_url(url, **kwargs)
# Decode from UTF-8 by default # Decode from UTF-8 by default
if "decode_responses" not in kwargs: if "decode_responses" not in kwargs:
kwargs["decode_responses"] = True kwargs["decode_responses"] = True
return redis.Redis(**kwargs) return client(**kwargs)

View file

@ -1,10 +1,10 @@
import click import click
from redis_developer.model.migrations.migrator import Migrator from redis_om.model.migrations.migrator import Migrator
@click.command() @click.command()
@click.option("--module", default="redis_developer") @click.option("--module", default="redis_om")
def migrate(module): def migrate(module):
migrator = Migrator(module) migrator = Migrator(module)

View file

@ -2,15 +2,14 @@ import hashlib
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional, Union
from redis import ResponseError from redis import ResponseError, Redis
from aioredis import ResponseError as AResponseError, Redis as ARedis
from redis_developer.connections import get_redis_connection from redis_om.model.model import model_registry
from redis_developer.model.model import model_registry
redis = get_redis_connection()
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -43,12 +42,12 @@ def schema_hash_key(index_name):
return f"{index_name}:hash" return f"{index_name}:hash"
def create_index(index_name, schema, current_hash): async def create_index(redis: Union[Redis, ARedis], index_name, schema, current_hash):
try: try:
redis.execute_command(f"ft.info {index_name}") await redis.execute_command(f"ft.info {index_name}")
except ResponseError: except (ResponseError, AResponseError):
redis.execute_command(f"ft.create {index_name} {schema}") await redis.execute_command(f"ft.create {index_name} {schema}")
redis.set(schema_hash_key(index_name), current_hash) await redis.set(schema_hash_key(index_name), current_hash)
else: else:
log.info("Index already exists, skipping. Index hash: %s", index_name) log.info("Index already exists, skipping. Index hash: %s", index_name)
@ -65,34 +64,38 @@ class IndexMigration:
schema: str schema: str
hash: str hash: str
action: MigrationAction action: MigrationAction
redis: Union[Redis, ARedis]
previous_hash: Optional[str] = None previous_hash: Optional[str] = None
def run(self): async def run(self):
if self.action is MigrationAction.CREATE: if self.action is MigrationAction.CREATE:
self.create() await self.create()
elif self.action is MigrationAction.DROP: elif self.action is MigrationAction.DROP:
self.drop() await self.drop()
def create(self): async def create(self):
try: try:
return create_index(self.index_name, self.schema, self.hash) await create_index(self.redis, self.index_name, self.schema, self.hash)
except ResponseError: except ResponseError:
log.info("Index already exists: %s", self.index_name) log.info("Index already exists: %s", self.index_name)
def drop(self): async def drop(self):
try: try:
redis.execute_command(f"FT.DROPINDEX {self.index_name}") await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
except ResponseError: except ResponseError:
log.info("Index does not exist: %s", self.index_name) log.info("Index does not exist: %s", self.index_name)
class Migrator: class Migrator:
def __init__(self, module=None): def __init__(self, redis: Union[Redis, ARedis], module=None):
# Try to load any modules found under the given path or module name. self.module = module
if module:
import_submodules(module)
self.migrations = [] self.migrations = []
self.redis = redis
async def run(self):
# Try to load any modules found under the given path or module name.
if self.module:
import_submodules(self.module)
for name, cls in model_registry.items(): for name, cls in model_registry.items():
hash_key = schema_hash_key(cls.Meta.index_name) hash_key = schema_hash_key(cls.Meta.index_name)
@ -104,8 +107,8 @@ class Migrator:
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec
try: try:
redis.execute_command("ft.info", cls.Meta.index_name) await self.redis.execute_command("ft.info", cls.Meta.index_name)
except ResponseError: except (ResponseError, AResponseError):
self.migrations.append( self.migrations.append(
IndexMigration( IndexMigration(
name, name,
@ -113,11 +116,12 @@ class Migrator:
schema, schema,
current_hash, current_hash,
MigrationAction.CREATE, MigrationAction.CREATE,
self.redis
) )
) )
continue continue
stored_hash = redis.get(hash_key) stored_hash = self.redis.get(hash_key)
schema_out_of_date = current_hash != stored_hash schema_out_of_date = current_hash != stored_hash
if schema_out_of_date: if schema_out_of_date:
@ -129,7 +133,8 @@ class Migrator:
schema, schema,
current_hash, current_hash,
MigrationAction.DROP, MigrationAction.DROP,
stored_hash, self.redis,
stored_hash
) )
) )
self.migrations.append( self.migrations.append(
@ -139,12 +144,12 @@ class Migrator:
schema, schema,
current_hash, current_hash,
MigrationAction.CREATE, MigrationAction.CREATE,
stored_hash, self.redis,
stored_hash
) )
) )
def run(self):
# TODO: Migration history # TODO: Migration history
# TODO: Dry run with output # TODO: Dry run with output
for migration in self.migrations: for migration in self.migrations:
migration.run() await migration.run()

View file

@ -4,7 +4,7 @@ import decimal
import json import json
import logging import logging
import operator import operator
from copy import copy, deepcopy from copy import copy
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from typing import ( from typing import (
@ -27,6 +27,7 @@ from typing import (
no_type_check, no_type_check,
) )
import aioredis
import redis import redis
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import FieldInfo as PydanticFieldInfo
@ -37,11 +38,11 @@ from pydantic.utils import Representation
from redis.client import Pipeline from redis.client import Pipeline
from ulid import ULID from ulid import ULID
from ..connections import get_redis_connection from redis_om.connections import get_redis_connection
from .encoders import jsonable_encoder from .encoders import jsonable_encoder
from .render_tree import render_tree from .render_tree import render_tree
from .token_escaper import TokenEscaper from .token_escaper import TokenEscaper
from ..unasync_util import ASYNC_MODE
model_registry = {} model_registry = {}
_T = TypeVar("_T") _T = TypeVar("_T")
@ -521,7 +522,7 @@ class FindQuery:
# this is not going to work. # this is not going to work.
log.warning( log.warning(
"Your query against the field %s is for a single character, %s, " "Your query against the field %s is for a single character, %s, "
"that is used internally by redis-developer-python. We must ignore " "that is used internally by redis-om-python. We must ignore "
"this portion of the query. Please review your query to find " "this portion of the query. Please review your query to find "
"an alternative query that uses a string containing more than " "an alternative query that uses a string containing more than "
"just the character %s.", "just the character %s.",
@ -680,7 +681,7 @@ class FindQuery:
return result return result
def execute(self, exhaust_results=True): async def execute(self, exhaust_results=True):
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination] args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
if self.sort_fields: if self.sort_fields:
args += self.resolve_redisearch_sort_fields() args += self.resolve_redisearch_sort_fields()
@ -691,7 +692,7 @@ class FindQuery:
# If the offset is greater than 0, we're paginating through a result set, # If the offset is greater than 0, we're paginating through a result set,
# so append the new results to results already in the cache. # so append the new results to results already in the cache.
raw_result = self.model.db().execute_command(*args) raw_result = await self.model.db().execute_command(*args)
count = raw_result[0] count = raw_result[0]
results = self.model.from_redis(raw_result) results = self.model.from_redis(raw_result)
self._model_cache += results self._model_cache += results
@ -710,31 +711,31 @@ class FindQuery:
# Make a query for each pass of the loop, with a new offset equal to the # Make a query for each pass of the loop, with a new offset equal to the
# current offset plus `page_size`, until we stop getting results back. # current offset plus `page_size`, until we stop getting results back.
query = query.copy(offset=query.offset + query.page_size) query = query.copy(offset=query.offset + query.page_size)
_results = query.execute(exhaust_results=False) _results = await query.execute(exhaust_results=False)
if not _results: if not _results:
break break
self._model_cache += _results self._model_cache += _results
return self._model_cache return self._model_cache
def first(self): async def first(self):
query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields) query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
results = query.execute() results = await query.execute()
if not results: if not results:
raise NotFoundError() raise NotFoundError()
return results[0] return results[0]
def all(self, batch_size=10): async def all(self, batch_size=10):
if batch_size != self.page_size: if batch_size != self.page_size:
query = self.copy(page_size=batch_size, limit=batch_size) query = self.copy(page_size=batch_size, limit=batch_size)
return query.execute() return await query.execute()
return self.execute() return await self.execute()
def sort_by(self, *fields: str): def sort_by(self, *fields: str):
if not fields: if not fields:
return self return self
return self.copy(sort_fields=list(fields)) return self.copy(sort_fields=list(fields))
def update(self, use_transaction=True, **field_values): async def update(self, use_transaction=True, **field_values):
""" """
Update models that match this query to the given field-value pairs. Update models that match this query to the given field-value pairs.
@ -743,31 +744,32 @@ class FindQuery:
given fields. given fields.
""" """
validate_model_fields(self.model, field_values) validate_model_fields(self.model, field_values)
pipeline = self.model.db().pipeline() if use_transaction else None pipeline = await self.model.db().pipeline() if use_transaction else None
for model in self.all(): # TODO: async for here?
for model in await self.all():
for field, value in field_values.items(): for field, value in field_values.items():
setattr(model, field, value) setattr(model, field, value)
# TODO: In the non-transaction case, can we do more to detect # TODO: In the non-transaction case, can we do more to detect
# failure responses from Redis? # failure responses from Redis?
model.save(pipeline=pipeline) await model.save(pipeline=pipeline)
if pipeline: if pipeline:
# TODO: Response type? # TODO: Response type?
# TODO: Better error detection for transactions. # TODO: Better error detection for transactions.
pipeline.execute() pipeline.execute()
def delete(self): async def delete(self):
"""Delete all matching records in this query.""" """Delete all matching records in this query."""
# TODO: Better response type, error detection # TODO: Better response type, error detection
return self.model.db().delete(*[m.key() for m in self.all()]) return await self.model.db().delete(*[m.key() for m in await self.all()])
def __iter__(self): async def __aiter__(self):
if self._model_cache: if self._model_cache:
for m in self._model_cache: for m in self._model_cache:
yield m yield m
else: else:
for m in self.execute(): for m in await self.execute():
yield m yield m
def __getitem__(self, item: int): def __getitem__(self, item: int):
@ -784,12 +786,39 @@ class FindQuery:
that result, then we should clone the current query and that result, then we should clone the current query and
give it a new offset and limit: offset=n, limit=1. give it a new offset and limit: offset=n, limit=1.
""" """
if ASYNC_MODE:
raise QuerySyntaxError("Cannot use [] notation with async code. "
"Use FindQuery.get_item() instead.")
if self._model_cache and len(self._model_cache) >= item: if self._model_cache and len(self._model_cache) >= item:
return self._model_cache[item] return self._model_cache[item]
query = self.copy(offset=item, limit=1) query = self.copy(offset=item, limit=1)
return query.execute()[0] return query.execute()[0] # noqa
async def get_item(self, item: int):
"""
Given this code:
await Model.find().get_item(1000)
We should return only the 1000th result.
1. If the result is loaded in the query cache for this query,
we can return it directly from the cache.
2. If the query cache does not have enough elements to return
that result, then we should clone the current query and
give it a new offset and limit: offset=n, limit=1.
NOTE: This method is included specifically for async users, who
cannot use the notation Model.find()[1000].
"""
if self._model_cache and len(self._model_cache) >= item:
return self._model_cache[item]
query = self.copy(offset=item, limit=1)
result = await query.execute()
return result[0]
class PrimaryKeyCreator(Protocol): class PrimaryKeyCreator(Protocol):
@ -913,7 +942,7 @@ class MetaProtocol(Protocol):
global_key_prefix: str global_key_prefix: str
model_key_prefix: str model_key_prefix: str
primary_key_pattern: str primary_key_pattern: str
database: redis.Redis database: aioredis.Redis
primary_key: PrimaryKey primary_key: PrimaryKey
primary_key_creator_cls: Type[PrimaryKeyCreator] primary_key_creator_cls: Type[PrimaryKeyCreator]
index_name: str index_name: str
@ -932,7 +961,7 @@ class DefaultMeta:
global_key_prefix: Optional[str] = None global_key_prefix: Optional[str] = None
model_key_prefix: Optional[str] = None model_key_prefix: Optional[str] = None
primary_key_pattern: Optional[str] = None primary_key_pattern: Optional[str] = None
database: Optional[redis.Redis] = None database: Optional[Union[redis.Redis, aioredis.Redis]] = None
primary_key: Optional[PrimaryKey] = None primary_key: Optional[PrimaryKey] = None
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
index_name: Optional[str] = None index_name: Optional[str] = None
@ -1049,14 +1078,18 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
pk = getattr(self, self._meta.primary_key.field.name) pk = getattr(self, self._meta.primary_key.field.name)
return self.make_primary_key(pk) return self.make_primary_key(pk)
def delete(self): async def delete(self):
return self.db().delete(self.key()) return await self.db().delete(self.key())
def update(self, **field_values): @classmethod
async def get(cls, pk: Any) -> 'RedisModel':
raise NotImplementedError
async def update(self, **field_values):
"""Update this model instance with the specified key-value pairs.""" """Update this model instance with the specified key-value pairs."""
raise NotImplementedError raise NotImplementedError
def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel": async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
raise NotImplementedError raise NotImplementedError
@validator("pk", always=True) @validator("pk", always=True)
@ -1158,9 +1191,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
return d return d
@classmethod @classmethod
def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]: async def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]:
# TODO: Add transaction support # TODO: Add transaction support
return [model.save() for model in models] return [await model.save() for model in models]
@classmethod @classmethod
def values(cls): def values(cls):
@ -1189,17 +1222,18 @@ class HashModel(RedisModel, abc.ABC):
f" or mapping fields. Field: {name}" f" or mapping fields. Field: {name}"
) )
def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel": async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
if pipeline is None: if pipeline is None:
db = self.db() db = self.db()
else: else:
db = pipeline db = pipeline
document = jsonable_encoder(self.dict()) document = jsonable_encoder(self.dict())
db.hset(self.key(), mapping=document) # TODO: Wrap any Redis response errors in a custom exception?
await db.hset(self.key(), mapping=document)
return self return self
@classmethod @classmethod
def get(cls, pk: Any) -> "HashModel": async def get(cls, pk: Any) -> "HashModel":
document = cls.db().hgetall(cls.make_primary_key(pk)) document = cls.db().hgetall(cls.make_primary_key(pk))
if not document: if not document:
raise NotFoundError raise NotFoundError
@ -1311,23 +1345,24 @@ class JsonModel(RedisModel, abc.ABC):
# Generate the RediSearch schema once to validate fields. # Generate the RediSearch schema once to validate fields.
cls.redisearch_schema() cls.redisearch_schema()
def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel": async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
if pipeline is None: if pipeline is None:
db = self.db() db = self.db()
else: else:
db = pipeline db = pipeline
db.execute_command("JSON.SET", self.key(), ".", self.json()) # TODO: Wrap response errors in a custom exception?
await db.execute_command("JSON.SET", self.key(), ".", self.json())
return self return self
def update(self, **field_values): async def update(self, **field_values):
validate_model_fields(self.__class__, field_values) validate_model_fields(self.__class__, field_values)
for field, value in field_values.items(): for field, value in field_values.items():
setattr(self, field, value) setattr(self, field, value)
self.save() await self.save()
@classmethod @classmethod
def get(cls, pk: Any) -> "JsonModel": async def get(cls, pk: Any) -> "JsonModel":
document = cls.db().execute_command("JSON.GET", cls.make_primary_key(pk)) document = await cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
if not document: if not document:
raise NotFoundError raise NotFoundError
return cls.parse_raw(document) return cls.parse_raw(document)

View file

@ -1,17 +1,17 @@
import abc import abc
from typing import Optional from typing import Optional
from redis_developer.model.model import HashModel, JsonModel from redis_om.model.model import HashModel, JsonModel
class BaseJsonModel(JsonModel, abc.ABC): class BaseJsonModel(JsonModel, abc.ABC):
class Meta: class Meta:
global_key_prefix = "redis-developer" global_key_prefix = "redis-om"
class BaseHashModel(HashModel, abc.ABC): class BaseHashModel(HashModel, abc.ABC):
class Meta: class Meta:
global_key_prefix = "redis-developer" global_key_prefix = "redis-om"
# class AddressJson(BaseJsonModel): # class AddressJson(BaseJsonModel):

View file

@ -1,7 +1,7 @@
from collections import Sequence from collections import Sequence
from typing import Any, Dict, List, Mapping, Union from typing import Any, Dict, List, Mapping, Union
from redis_developer.model.model import Expression from redis_om.model.model import Expression
class LogicalOperatorForListOfExpressions(Expression): class LogicalOperatorForListOfExpressions(Expression):

40
redis_om/unasync_util.py Normal file
View file

@ -0,0 +1,40 @@
"""Set of utility functions for unasync that transform into sync counterparts cleanly"""
import inspect
_original_next = next
def is_async_mode():
"""Tests if we're in the async part of the code or not"""
async def f():
"""Unasync transforms async functions in sync functions"""
return None
obj = f()
if obj is None:
return False
else:
obj.close() # prevent unawaited coroutine warning
return True
ASYNC_MODE = is_async_mode()
async def anext(x):
return await x.__anext__()
async def await_if_coro(x):
if inspect.iscoroutine(x):
return await x
return x
next = _original_next
def return_non_coro(x):
return x

View file

@ -1,19 +1,18 @@
import random import random
import pytest import pytest
from redis import Redis
from redis_developer.connections import get_redis_connection from redis_om.connections import get_redis_connection
@pytest.fixture @pytest.fixture
def redis(): def redis(event_loop):
yield get_redis_connection() yield get_redis_connection()
def _delete_test_keys(prefix: str, conn: Redis): async def _delete_test_keys(prefix: str, conn):
keys = [] keys = []
for key in conn.scan_iter(f"{prefix}:*"): async for key in conn.scan_iter(f"{prefix}:*"):
keys.append(key) keys.append(key)
if keys: if keys:
conn.delete(*keys) conn.delete(*keys)
@ -21,11 +20,10 @@ def _delete_test_keys(prefix: str, conn: Redis):
@pytest.fixture @pytest.fixture
def key_prefix(redis): def key_prefix(redis):
key_prefix = f"redis-developer:{random.random()}" key_prefix = f"redis-om:{random.random()}"
yield key_prefix yield key_prefix
_delete_test_keys(key_prefix, redis)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def delete_test_keys(redis, request, key_prefix): async def delete_test_keys(redis, request, key_prefix):
_delete_test_keys(key_prefix, redis) await _delete_test_keys(key_prefix, redis)

View file

@ -8,9 +8,9 @@ from unittest import mock
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from redis_developer.model import Field, HashModel from redis_om.model import Field, HashModel
from redis_developer.model.migrations.migrator import Migrator from redis_om.model.migrations.migrator import Migrator
from redis_developer.model.model import ( from redis_om.model.model import (
NotFoundError, NotFoundError,
QueryNotSupportedError, QueryNotSupportedError,
RedisModelError, RedisModelError,

View file

@ -1,4 +1,5 @@
import abc import abc
import asyncio
import datetime import datetime
import decimal import decimal
from collections import namedtuple from collections import namedtuple
@ -8,9 +9,9 @@ from unittest import mock
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from redis_developer.model import EmbeddedJsonModel, Field, JsonModel from redis_om.model import EmbeddedJsonModel, Field, JsonModel
from redis_developer.model.migrations.migrator import Migrator from redis_om.model.migrations.migrator import Migrator
from redis_developer.model.model import ( from redis_om.model.model import (
NotFoundError, NotFoundError,
QueryNotSupportedError, QueryNotSupportedError,
RedisModelError, RedisModelError,
@ -21,7 +22,7 @@ today = datetime.date.today()
@pytest.fixture @pytest.fixture
def m(key_prefix): async def m(key_prefix, redis):
class BaseJsonModel(JsonModel, abc.ABC): class BaseJsonModel(JsonModel, abc.ABC):
class Meta: class Meta:
global_key_prefix = key_prefix global_key_prefix = key_prefix
@ -64,7 +65,7 @@ def m(key_prefix):
# Creates an embedded list of models. # Creates an embedded list of models.
orders: Optional[List[Order]] orders: Optional[List[Order]]
Migrator().run() await Migrator(redis).run()
return namedtuple( return namedtuple(
"Models", ["BaseJsonModel", "Note", "Address", "Item", "Order", "Member"] "Models", ["BaseJsonModel", "Note", "Address", "Item", "Order", "Member"]
@ -83,7 +84,7 @@ def address(m):
@pytest.fixture() @pytest.fixture()
def members(address, m): async def members(address, m):
member1 = m.Member( member1 = m.Member(
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
@ -111,14 +112,15 @@ def members(address, m):
address=address, address=address,
) )
member1.save() await member1.save()
member2.save() await member2.save()
member3.save() await member3.save()
yield member1, member2, member3 yield member1, member2, member3
def test_validates_required_fields(address, m): @pytest.mark.asyncio
async def test_validates_required_fields(address, m):
# Raises ValidationError address is required # Raises ValidationError address is required
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
m.Member( m.Member(
@ -129,7 +131,8 @@ def test_validates_required_fields(address, m):
) )
def test_validates_field(address, m): @pytest.mark.asyncio
async def test_validates_field(address, m):
# Raises ValidationError: join_date is not a date # Raises ValidationError: join_date is not a date
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
m.Member( m.Member(
@ -141,7 +144,8 @@ def test_validates_field(address, m):
# Passes validation # Passes validation
def test_validation_passes(address, m): @pytest.mark.asyncio
async def test_validation_passes(address, m):
member = m.Member( member = m.Member(
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
@ -153,7 +157,10 @@ def test_validation_passes(address, m):
assert member.first_name == "Andrew" assert member.first_name == "Andrew"
def test_saves_model_and_creates_pk(address, m): @pytest.mark.asyncio
async def test_saves_model_and_creates_pk(address, m, redis):
await Migrator(redis).run()
member = m.Member( member = m.Member(
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
@ -163,15 +170,16 @@ def test_saves_model_and_creates_pk(address, m):
address=address, address=address,
) )
# Save a model instance to Redis # Save a model instance to Redis
member.save() await member.save()
member2 = m.Member.get(member.pk) member2 = await m.Member.get(member.pk)
assert member2 == member assert member2 == member
assert member2.address == address assert member2.address == address
@pytest.mark.skip("Not implemented yet") @pytest.mark.skip("Not implemented yet")
def test_saves_many(address, m): @pytest.mark.asyncio
async def test_saves_many(address, m):
members = [ members = [
m.Member( m.Member(
first_name="Andrew", first_name="Andrew",
@ -193,9 +201,16 @@ def test_saves_many(address, m):
m.Member.add(members) m.Member.add(members)
async def save(members):
for m in members:
await m.save()
return members
@pytest.mark.skip("Not ready yet") @pytest.mark.skip("Not ready yet")
def test_updates_a_model(members, m): @pytest.mark.asyncio
member1, member2, member3 = members async def test_updates_a_model(members, m):
member1, member2, member3 = await save(members)
# Or, with an implicit save: # Or, with an implicit save:
member1.update(last_name="Smith") member1.update(last_name="Smith")
@ -213,18 +228,20 @@ def test_updates_a_model(members, m):
) )
def test_paginate_query(members, m): @pytest.mark.asyncio
async def test_paginate_query(members, m):
member1, member2, member3 = members member1, member2, member3 = members
actual = m.Member.find().sort_by("age").all(batch_size=1) actual = await m.Member.find().sort_by("age").all(batch_size=1)
assert actual == [member2, member1, member3] assert actual == [member2, member1, member3]
def test_access_result_by_index_cached(members, m): @pytest.mark.asyncio
async def test_access_result_by_index_cached(members, m):
member1, member2, member3 = members member1, member2, member3 = members
query = m.Member.find().sort_by("age") query = m.Member.find().sort_by("age")
# Load the cache, throw away the result. # Load the cache, throw away the result.
assert query._model_cache == [] assert query._model_cache == []
query.execute() await query.execute()
assert query._model_cache == [member2, member1, member3] assert query._model_cache == [member2, member1, member3]
# Access an item that should be in the cache. # Access an item that should be in the cache.
@ -233,21 +250,23 @@ def test_access_result_by_index_cached(members, m):
assert not mock_db.called assert not mock_db.called
def test_access_result_by_index_not_cached(members, m): @pytest.mark.asyncio
async def test_access_result_by_index_not_cached(members, m):
member1, member2, member3 = members member1, member2, member3 = members
query = m.Member.find().sort_by("age") query = m.Member.find().sort_by("age")
# Assert that we don't have any models in the cache yet -- we # Assert that we don't have any models in the cache yet -- we
# haven't made any requests of Redis. # haven't made any requests of Redis.
assert query._model_cache == [] assert query._model_cache == []
assert query[0] == member2 assert query.get_item(0) == member2
assert query[1] == member1 assert query.get_item(1) == member1
assert query[2] == member3 assert query.get_item(2) == member3
def test_in_query(members, m): @pytest.mark.asyncio
async def test_in_query(members, m):
member1, member2, member3 = members member1, member2, member3 = members
actual = ( actual = await (
m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]) m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
.sort_by("age") .sort_by("age")
.all() .all()
@ -256,12 +275,13 @@ def test_in_query(members, m):
@pytest.mark.skip("Not implemented yet") @pytest.mark.skip("Not implemented yet")
def test_update_query(members, m): @pytest.mark.asyncio
async def test_update_query(members, m):
member1, member2, member3 = members member1, member2, member3 = members
m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update( await m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update(
first_name="Bobby" first_name="Bobby"
) )
actual = ( actual = await (
m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]) m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
.sort_by("age") .sort_by("age")
.all() .all()
@ -270,24 +290,25 @@ def test_update_query(members, m):
assert all([m.name == "Bobby" for m in actual]) assert all([m.name == "Bobby" for m in actual])
def test_exact_match_queries(members, m): @pytest.mark.asyncio
async def test_exact_match_queries(members, m):
member1, member2, member3 = members member1, member2, member3 = members
actual = m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all() actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
assert actual == [member2, member1] assert actual == [member2, member1]
actual = m.Member.find( actual = await m.Member.find(
(m.Member.last_name == "Brookins") & ~(m.Member.first_name == "Andrew") (m.Member.last_name == "Brookins") & ~(m.Member.first_name == "Andrew")
).all() ).all()
assert actual == [member2] assert actual == [member2]
actual = m.Member.find(~(m.Member.last_name == "Brookins")).all() actual = await m.Member.find(~(m.Member.last_name == "Brookins")).all()
assert actual == [member3] assert actual == [member3]
actual = m.Member.find(m.Member.last_name != "Brookins").all() actual = await m.Member.find(m.Member.last_name != "Brookins").all()
assert actual == [member3] assert actual == [member3]
actual = ( actual = await (
m.Member.find( m.Member.find(
(m.Member.last_name == "Brookins") & (m.Member.first_name == "Andrew") (m.Member.last_name == "Brookins") & (m.Member.first_name == "Andrew")
| (m.Member.first_name == "Kim") | (m.Member.first_name == "Kim")
@ -297,19 +318,20 @@ def test_exact_match_queries(members, m):
) )
assert actual == [member2, member1] assert actual == [member2, member1]
actual = m.Member.find( actual = await m.Member.find(
m.Member.first_name == "Kim", m.Member.last_name == "Brookins" m.Member.first_name == "Kim", m.Member.last_name == "Brookins"
).all() ).all()
assert actual == [member2] assert actual == [member2]
actual = m.Member.find(m.Member.address.city == "Portland").sort_by("age").all() actual = await m.Member.find(m.Member.address.city == "Portland").sort_by("age").all()
assert actual == [member2, member1, member3] assert actual == [member2, member1, member3]
def test_recursive_query_expression_resolution(members, m): @pytest.mark.asyncio
async def test_recursive_query_expression_resolution(members, m):
member1, member2, member3 = members member1, member2, member3 = members
actual = ( actual = await (
m.Member.find( m.Member.find(
(m.Member.last_name == "Brookins") (m.Member.last_name == "Brookins")
| (m.Member.age == 100) & (m.Member.last_name == "Smith") | (m.Member.age == 100) & (m.Member.last_name == "Smith")
@ -320,13 +342,14 @@ def test_recursive_query_expression_resolution(members, m):
assert actual == [member2, member1, member3] assert actual == [member2, member1, member3]
def test_recursive_query_field_resolution(members, m): @pytest.mark.asyncio
async def test_recursive_query_field_resolution(members, m):
member1, _, _ = members member1, _, _ = members
member1.address.note = m.Note( member1.address.note = m.Note(
description="Weird house", created_on=datetime.datetime.now() description="Weird house", created_on=datetime.datetime.now()
) )
member1.save() await member1.save()
actual = m.Member.find(m.Member.address.note.description == "Weird house").all() actual = await m.Member.find(m.Member.address.note.description == "Weird house").all()
assert actual == [member1] assert actual == [member1]
member1.orders = [ member1.orders = [
@ -336,29 +359,31 @@ def test_recursive_query_field_resolution(members, m):
created_on=datetime.datetime.now(), created_on=datetime.datetime.now(),
) )
] ]
member1.save() await member1.save()
actual = m.Member.find(m.Member.orders.items.name == "Ball").all() actual = await m.Member.find(m.Member.orders.items.name == "Ball").all()
assert actual == [member1] assert actual == [member1]
assert actual[0].orders[0].items[0].name == "Ball" assert actual[0].orders[0].items[0].name == "Ball"
def test_full_text_search(members, m): @pytest.mark.asyncio
async def test_full_text_search(members, m):
member1, member2, _ = members member1, member2, _ = members
member1.update(bio="Hates sunsets, likes beaches") await member1.update(bio="Hates sunsets, likes beaches")
member2.update(bio="Hates beaches, likes forests") await member2.update(bio="Hates beaches, likes forests")
actual = m.Member.find(m.Member.bio % "beaches").sort_by("age").all() actual = await m.Member.find(m.Member.bio % "beaches").sort_by("age").all()
assert actual == [member2, member1] assert actual == [member2, member1]
actual = m.Member.find(m.Member.bio % "forests").all() actual = await m.Member.find(m.Member.bio % "forests").all()
assert actual == [member2] assert actual == [member2]
def test_tag_queries_boolean_logic(members, m): @pytest.mark.asyncio
async def test_tag_queries_boolean_logic(members, m):
member1, member2, member3 = members member1, member2, member3 = members
actual = ( actual = (
m.Member.find( await m.Member.find(
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith") | (m.Member.last_name == "Smith")
) )
@ -368,7 +393,8 @@ def test_tag_queries_boolean_logic(members, m):
assert actual == [member1, member3] assert actual == [member1, member3]
def test_tag_queries_punctuation(address, m): @pytest.mark.asyncio
async def test_tag_queries_punctuation(address, m):
member1 = m.Member( member1 = m.Member(
first_name="Andrew, the Michael", first_name="Andrew, the Michael",
last_name="St. Brookins-on-Pier", last_name="St. Brookins-on-Pier",
@ -377,7 +403,7 @@ def test_tag_queries_punctuation(address, m):
join_date=today, join_date=today,
address=address, address=address,
) )
member1.save() await member1.save()
member2 = m.Member( member2 = m.Member(
first_name="Bob", first_name="Bob",
@ -387,24 +413,25 @@ def test_tag_queries_punctuation(address, m):
join_date=today, join_date=today,
address=address, address=address,
) )
member2.save() await member2.save()
assert ( assert (
m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1 await m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1
) )
assert ( assert (
m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1 await m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1
) )
# Notice that when we index and query multiple values that use the internal # Notice that when we index and query multiple values that use the internal
# TAG separator for single-value exact-match fields, like an indexed string, # TAG separator for single-value exact-match fields, like an indexed string,
# the queries will succeed. We apply a workaround that queries for the union # the queries will succeed. We apply a workaround that queries for the union
# of the two values separated by the tag separator. # of the two values separated by the tag separator.
assert m.Member.find(m.Member.email == "a|b@example.com").all() == [member1] assert await m.Member.find(m.Member.email == "a|b@example.com").all() == [member1]
assert m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2] assert await m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2]
def test_tag_queries_negation(members, m): @pytest.mark.asyncio
async def test_tag_queries_negation(members, m):
member1, member2, member3 = members member1, member2, member3 = members
""" """
@ -414,7 +441,7 @@ def test_tag_queries_negation(members, m):
""" """
query = m.Member.find(~(m.Member.first_name == "Andrew")) query = m.Member.find(~(m.Member.first_name == "Andrew"))
assert query.all() == [member2] assert await query.all() == [member2]
""" """
first_name first_name
@ -429,7 +456,7 @@ def test_tag_queries_negation(members, m):
query = m.Member.find( query = m.Member.find(
~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
) )
assert query.all() == [member2] assert await query.all() == [member2]
""" """
first_name first_name
@ -448,7 +475,7 @@ def test_tag_queries_negation(members, m):
~(m.Member.first_name == "Andrew") ~(m.Member.first_name == "Andrew")
& ((m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith")) & ((m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith"))
) )
assert query.all() == [member2] assert await query.all() == [member2]
""" """
first_name first_name
@ -467,67 +494,71 @@ def test_tag_queries_negation(members, m):
~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith") | (m.Member.last_name == "Smith")
) )
assert query.sort_by("age").all() == [member2, member3] assert await query.sort_by("age").all() == [member2, member3]
actual = m.Member.find( actual = await m.Member.find(
(m.Member.first_name == "Andrew") & ~(m.Member.last_name == "Brookins") (m.Member.first_name == "Andrew") & ~(m.Member.last_name == "Brookins")
).all() ).all()
assert actual == [member3] assert actual == [member3]
def test_numeric_queries(members, m): @pytest.mark.asyncio
async def test_numeric_queries(members, m):
member1, member2, member3 = members member1, member2, member3 = members
actual = m.Member.find(m.Member.age == 34).all() actual = await m.Member.find(m.Member.age == 34).all()
assert actual == [member2] assert actual == [member2]
actual = m.Member.find(m.Member.age > 34).all() actual = await m.Member.find(m.Member.age > 34).all()
assert actual == [member1, member3] assert actual == [member1, member3]
actual = m.Member.find(m.Member.age < 35).all() actual = await m.Member.find(m.Member.age < 35).all()
assert actual == [member2] assert actual == [member2]
actual = m.Member.find(m.Member.age <= 34).all() actual = await m.Member.find(m.Member.age <= 34).all()
assert actual == [member2] assert actual == [member2]
actual = m.Member.find(m.Member.age >= 100).all() actual = await m.Member.find(m.Member.age >= 100).all()
assert actual == [member3] assert actual == [member3]
actual = m.Member.find(~(m.Member.age == 100)).sort_by("age").all() actual = await m.Member.find(~(m.Member.age == 100)).sort_by("age").all()
assert actual == [member2, member1] assert actual == [member2, member1]
actual = m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all() actual = await m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all()
assert actual == [member2, member1] assert actual == [member2, member1]
actual = m.Member.find(m.Member.age != 34).sort_by("age").all() actual = await m.Member.find(m.Member.age != 34).sort_by("age").all()
assert actual == [member1, member3] assert actual == [member1, member3]
def test_sorting(members, m): @pytest.mark.asyncio
async def test_sorting(members, m):
member1, member2, member3 = members member1, member2, member3 = members
actual = m.Member.find(m.Member.age > 34).sort_by("age").all() actual = await m.Member.find(m.Member.age > 34).sort_by("age").all()
assert actual == [member1, member3] assert actual == [member1, member3]
actual = m.Member.find(m.Member.age > 34).sort_by("-age").all() actual = await m.Member.find(m.Member.age > 34).sort_by("-age").all()
assert actual == [member3, member1] assert actual == [member3, member1]
with pytest.raises(QueryNotSupportedError): with pytest.raises(QueryNotSupportedError):
# This field does not exist. # This field does not exist.
m.Member.find().sort_by("not-a-real-field").all() await m.Member.find().sort_by("not-a-real-field").all()
with pytest.raises(QueryNotSupportedError): with pytest.raises(QueryNotSupportedError):
# This field is not sortable. # This field is not sortable.
m.Member.find().sort_by("join_date").all() await m.Member.find().sort_by("join_date").all()
def test_not_found(m): @pytest.mark.asyncio
async def test_not_found(m):
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
# This ID does not exist. # This ID does not exist.
m.Member.get(1000) await m.Member.get(1000)
def test_list_field_limitations(m): @pytest.mark.asyncio
async def test_list_field_limitations(m, redis):
with pytest.raises(RedisModelError): with pytest.raises(RedisModelError):
class SortableTarotWitch(m.BaseJsonModel): class SortableTarotWitch(m.BaseJsonModel):
@ -571,15 +602,16 @@ def test_list_field_limitations(m):
# We need to import and run this manually because we defined # We need to import and run this manually because we defined
# our model classes within a function that runs after the test # our model classes within a function that runs after the test
# suite's migrator has already looked for migrations to run. # suite's migrator has already looked for migrations to run.
Migrator().run() await Migrator(redis).run()
witch = TarotWitch(tarot_cards=["death"]) witch = TarotWitch(tarot_cards=["death"])
witch.save() await witch.save()
actual = TarotWitch.find(TarotWitch.tarot_cards << "death").all() actual = await TarotWitch.find(TarotWitch.tarot_cards << "death").all()
assert actual == [witch] assert actual == [witch]
def test_schema(m, key_prefix): @pytest.mark.asyncio
async def test_schema(m, key_prefix):
assert ( assert (
m.Member.redisearch_schema() m.Member.redisearch_schema()
== f"ON JSON PREFIX 1 {key_prefix}:tests.test_json_model.Member: SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR | $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |" == f"ON JSON PREFIX 1 {key_prefix}:tests.test_json_model.Member: SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR | $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |"