WIP on async - test failure due to closed event loop
This commit is contained in:
parent
0f9f7aa868
commit
b2c2dd9f6f
22 changed files with 348 additions and 190 deletions
2
Makefile
2
Makefile
|
@ -1,4 +1,4 @@
|
|||
NAME := redis_developer
|
||||
NAME := redis_om
|
||||
INSTALL_STAMP := .install.stamp
|
||||
POETRY := $(shell command -v poetry 2> /dev/null)
|
||||
|
||||
|
|
18
README.md
18
README.md
|
@ -52,7 +52,7 @@ Check out this example:
|
|||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from redis_developer.model import (
|
||||
from redis_om.model import (
|
||||
EmbeddedJsonModel,
|
||||
JsonModel,
|
||||
Field,
|
||||
|
@ -172,9 +172,9 @@ Don't want to run Redis yourself? RediSearch and RedisJSON are also available on
|
|||
|
||||
We'd love your contributions!
|
||||
|
||||
**Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new).
|
||||
**Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-om/redis-om-python/issues/new).
|
||||
|
||||
You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new) to get started.
|
||||
You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-om/redis-om-python/issues/new) to get started.
|
||||
|
||||
## License
|
||||
|
||||
|
@ -184,17 +184,17 @@ Redis OM is [MIT licensed][license-url].
|
|||
|
||||
[version-svg]: https://img.shields.io/pypi/v/redis-om?style=flat-square
|
||||
[package-url]: https://pypi.org/project/redis-om/
|
||||
[ci-svg]: https://img.shields.io/github/workflow/status/redis-developer/redis-developer-python/python?style=flat-square
|
||||
[ci-url]: https://github.com/redis-developer/redis-developer-python/actions/workflows/build.yml
|
||||
[ci-svg]: https://img.shields.io/github/workflow/status/redis-om/redis-om-python/python?style=flat-square
|
||||
[ci-url]: https://github.com/redis-om/redis-om-python/actions/workflows/build.yml
|
||||
[license-image]: http://img.shields.io/badge/license-MIT-green.svg?style=flat-square
|
||||
[license-url]: LICENSE
|
||||
|
||||
<!-- Links -->
|
||||
|
||||
[redis-developer-website]: https://developer.redis.com
|
||||
[redis-om-js]: https://github.com/redis-developer/redis-om-js
|
||||
[redis-om-dotnet]: https://github.com/redis-developer/redis-om-dotnet
|
||||
[redis-om-spring]: https://github.com/redis-developer/redis-om-spring
|
||||
[redis-om-website]: https://developer.redis.com
|
||||
[redis-om-js]: https://github.com/redis-om/redis-om-js
|
||||
[redis-om-dotnet]: https://github.com/redis-om/redis-om-dotnet
|
||||
[redis-om-spring]: https://github.com/redis-om/redis-om-spring
|
||||
[redisearch-url]: https://oss.redis.com/redisearch/
|
||||
[redis-json-url]: https://oss.redis.com/redisjson/
|
||||
[pydantic-url]: https://github.com/samuelcolvin/pydantic
|
||||
|
|
10
build.py
Normal file
10
build.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
import unasync
|
||||
|
||||
|
||||
def build(setup_kwargs):
|
||||
setup_kwargs.update(
|
||||
{"cmdclass": {'build_py': unasync.cmdclass_build_py(rules=[
|
||||
unasync.Rule("/aredis_om/", "/redis_om/"),
|
||||
unasync.Rule("/aredis_om/tests/", "/redis_om/tests/", additional_replacements={"aredis_om": "redis_om"}),
|
||||
])}}
|
||||
)
|
32
poetry.lock
generated
32
poetry.lock
generated
|
@ -538,6 +538,20 @@ toml = "*"
|
|||
[package.extras]
|
||||
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "0.16.0"
|
||||
description = "Pytest support for asyncio."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">= 3.6"
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=5.4.0"
|
||||
|
||||
[package.extras]
|
||||
testing = ["coverage", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "3.0.0"
|
||||
|
@ -707,6 +721,14 @@ category = "main"
|
|||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "unasync"
|
||||
version = "0.5.0"
|
||||
description = "The async transformation code."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
|
||||
|
||||
[[package]]
|
||||
name = "wcwidth"
|
||||
version = "0.2.5"
|
||||
|
@ -726,7 +748,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "56b381dd9b79bd082e978019124176491c63f09dd5ce90e5f8ab642a7f79480f"
|
||||
content-hash = "d2d83b8cd3b094879e1aeb058d0036203942143f12fafa8be03fb0c79460028f"
|
||||
|
||||
[metadata.files]
|
||||
aioredis = [
|
||||
|
@ -1003,6 +1025,10 @@ pytest = [
|
|||
{file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
|
||||
{file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
|
||||
]
|
||||
pytest-asyncio = [
|
||||
{file = "pytest-asyncio-0.16.0.tar.gz", hash = "sha256:7496c5977ce88c34379df64a66459fe395cd05543f0a2f837016e7144391fcfb"},
|
||||
{file = "pytest_asyncio-0.16.0-py3-none-any.whl", hash = "sha256:5f2a21273c47b331ae6aa5b36087047b4899e40f03f18397c0e65fa5cca54e9b"},
|
||||
]
|
||||
pytest-cov = [
|
||||
{file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
|
||||
{file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},
|
||||
|
@ -1148,6 +1174,10 @@ typing-extensions = [
|
|||
{file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"},
|
||||
{file = "typing_extensions-3.10.0.2.tar.gz", hash = "sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e"},
|
||||
]
|
||||
unasync = [
|
||||
{file = "unasync-0.5.0-py3-none-any.whl", hash = "sha256:8d4536dae85e87b8751dfcc776f7656fd0baf54bb022a7889440dc1b9dc3becb"},
|
||||
{file = "unasync-0.5.0.tar.gz", hash = "sha256:b675d87cf56da68bd065d3b7a67ac71df85591978d84c53083c20d79a7e5096d"},
|
||||
]
|
||||
wcwidth = [
|
||||
{file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
|
||||
{file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"},
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
[tool.poetry]
|
||||
name = "redis-developer"
|
||||
name = "redis-om"
|
||||
version = "0.1.0"
|
||||
description = "A high-level library containing useful Redis abstractions and tools, like an ORM and leaderboard."
|
||||
authors = ["Andrew Brookins <andrew.brookins@redislabs.com>"]
|
||||
license = "MIT"
|
||||
build = "build.py"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8"
|
||||
|
@ -30,10 +31,11 @@ bandit = "^1.7.0"
|
|||
coverage = "^6.0.2"
|
||||
pytest-cov = "^3.0.0"
|
||||
pytest-xdist = "^2.4.0"
|
||||
|
||||
unasync = "^0.5.0"
|
||||
pytest-asyncio = "^0.16.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
migrate = "redis_developer.orm.cli.migrate:migrate"
|
||||
migrate = "redis_om.orm.cli.migrate:migrate"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
|
|
@ -1,22 +1,28 @@
|
|||
import os
|
||||
from typing import Union
|
||||
|
||||
import dotenv
|
||||
import aioredis
|
||||
import redis
|
||||
|
||||
from redis_om.unasync_util import ASYNC_MODE
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
URL = os.environ.get("REDIS_OM_URL", None)
|
||||
if ASYNC_MODE:
|
||||
client = aioredis.Redis
|
||||
else:
|
||||
client = redis.Redis
|
||||
|
||||
|
||||
def get_redis_connection(**kwargs) -> redis.Redis:
|
||||
def get_redis_connection(**kwargs) -> Union[aioredis.Redis, redis.Redis]:
|
||||
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
|
||||
# environment variable, we'll create the Redis client from the URL.
|
||||
url = kwargs.pop("url", URL)
|
||||
if url:
|
||||
return redis.from_url(url, **kwargs)
|
||||
return client.from_url(url, **kwargs)
|
||||
|
||||
# Decode from UTF-8 by default
|
||||
if "decode_responses" not in kwargs:
|
||||
kwargs["decode_responses"] = True
|
||||
return redis.Redis(**kwargs)
|
||||
return client(**kwargs)
|
|
@ -1,10 +1,10 @@
|
|||
import click
|
||||
|
||||
from redis_developer.model.migrations.migrator import Migrator
|
||||
from redis_om.model.migrations.migrator import Migrator
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--module", default="redis_developer")
|
||||
@click.option("--module", default="redis_om")
|
||||
def migrate(module):
|
||||
migrator = Migrator(module)
|
||||
|
|
@ -2,15 +2,14 @@ import hashlib
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from redis import ResponseError
|
||||
from redis import ResponseError, Redis
|
||||
from aioredis import ResponseError as AResponseError, Redis as ARedis
|
||||
|
||||
from redis_developer.connections import get_redis_connection
|
||||
from redis_developer.model.model import model_registry
|
||||
from redis_om.model.model import model_registry
|
||||
|
||||
|
||||
redis = get_redis_connection()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -43,12 +42,12 @@ def schema_hash_key(index_name):
|
|||
return f"{index_name}:hash"
|
||||
|
||||
|
||||
def create_index(index_name, schema, current_hash):
|
||||
async def create_index(redis: Union[Redis, ARedis], index_name, schema, current_hash):
|
||||
try:
|
||||
redis.execute_command(f"ft.info {index_name}")
|
||||
except ResponseError:
|
||||
redis.execute_command(f"ft.create {index_name} {schema}")
|
||||
redis.set(schema_hash_key(index_name), current_hash)
|
||||
await redis.execute_command(f"ft.info {index_name}")
|
||||
except (ResponseError, AResponseError):
|
||||
await redis.execute_command(f"ft.create {index_name} {schema}")
|
||||
await redis.set(schema_hash_key(index_name), current_hash)
|
||||
else:
|
||||
log.info("Index already exists, skipping. Index hash: %s", index_name)
|
||||
|
||||
|
@ -65,34 +64,38 @@ class IndexMigration:
|
|||
schema: str
|
||||
hash: str
|
||||
action: MigrationAction
|
||||
redis: Union[Redis, ARedis]
|
||||
previous_hash: Optional[str] = None
|
||||
|
||||
def run(self):
|
||||
async def run(self):
|
||||
if self.action is MigrationAction.CREATE:
|
||||
self.create()
|
||||
await self.create()
|
||||
elif self.action is MigrationAction.DROP:
|
||||
self.drop()
|
||||
await self.drop()
|
||||
|
||||
def create(self):
|
||||
async def create(self):
|
||||
try:
|
||||
return create_index(self.index_name, self.schema, self.hash)
|
||||
await create_index(self.redis, self.index_name, self.schema, self.hash)
|
||||
except ResponseError:
|
||||
log.info("Index already exists: %s", self.index_name)
|
||||
|
||||
def drop(self):
|
||||
async def drop(self):
|
||||
try:
|
||||
redis.execute_command(f"FT.DROPINDEX {self.index_name}")
|
||||
await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
|
||||
except ResponseError:
|
||||
log.info("Index does not exist: %s", self.index_name)
|
||||
|
||||
|
||||
class Migrator:
|
||||
def __init__(self, module=None):
|
||||
# Try to load any modules found under the given path or module name.
|
||||
if module:
|
||||
import_submodules(module)
|
||||
|
||||
def __init__(self, redis: Union[Redis, ARedis], module=None):
|
||||
self.module = module
|
||||
self.migrations = []
|
||||
self.redis = redis
|
||||
|
||||
async def run(self):
|
||||
# Try to load any modules found under the given path or module name.
|
||||
if self.module:
|
||||
import_submodules(self.module)
|
||||
|
||||
for name, cls in model_registry.items():
|
||||
hash_key = schema_hash_key(cls.Meta.index_name)
|
||||
|
@ -104,8 +107,8 @@ class Migrator:
|
|||
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec
|
||||
|
||||
try:
|
||||
redis.execute_command("ft.info", cls.Meta.index_name)
|
||||
except ResponseError:
|
||||
await self.redis.execute_command("ft.info", cls.Meta.index_name)
|
||||
except (ResponseError, AResponseError):
|
||||
self.migrations.append(
|
||||
IndexMigration(
|
||||
name,
|
||||
|
@ -113,11 +116,12 @@ class Migrator:
|
|||
schema,
|
||||
current_hash,
|
||||
MigrationAction.CREATE,
|
||||
self.redis
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
stored_hash = redis.get(hash_key)
|
||||
stored_hash = self.redis.get(hash_key)
|
||||
schema_out_of_date = current_hash != stored_hash
|
||||
|
||||
if schema_out_of_date:
|
||||
|
@ -129,7 +133,8 @@ class Migrator:
|
|||
schema,
|
||||
current_hash,
|
||||
MigrationAction.DROP,
|
||||
stored_hash,
|
||||
self.redis,
|
||||
stored_hash
|
||||
)
|
||||
)
|
||||
self.migrations.append(
|
||||
|
@ -139,12 +144,12 @@ class Migrator:
|
|||
schema,
|
||||
current_hash,
|
||||
MigrationAction.CREATE,
|
||||
stored_hash,
|
||||
self.redis,
|
||||
stored_hash
|
||||
)
|
||||
)
|
||||
|
||||
def run(self):
|
||||
# TODO: Migration history
|
||||
# TODO: Dry run with output
|
||||
for migration in self.migrations:
|
||||
migration.run()
|
||||
await migration.run()
|
|
@ -4,7 +4,7 @@ import decimal
|
|||
import json
|
||||
import logging
|
||||
import operator
|
||||
from copy import copy, deepcopy
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import (
|
||||
|
@ -27,6 +27,7 @@ from typing import (
|
|||
no_type_check,
|
||||
)
|
||||
|
||||
import aioredis
|
||||
import redis
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||
|
@ -37,11 +38,11 @@ from pydantic.utils import Representation
|
|||
from redis.client import Pipeline
|
||||
from ulid import ULID
|
||||
|
||||
from ..connections import get_redis_connection
|
||||
from redis_om.connections import get_redis_connection
|
||||
from .encoders import jsonable_encoder
|
||||
from .render_tree import render_tree
|
||||
from .token_escaper import TokenEscaper
|
||||
|
||||
from ..unasync_util import ASYNC_MODE
|
||||
|
||||
model_registry = {}
|
||||
_T = TypeVar("_T")
|
||||
|
@ -521,7 +522,7 @@ class FindQuery:
|
|||
# this is not going to work.
|
||||
log.warning(
|
||||
"Your query against the field %s is for a single character, %s, "
|
||||
"that is used internally by redis-developer-python. We must ignore "
|
||||
"that is used internally by redis-om-python. We must ignore "
|
||||
"this portion of the query. Please review your query to find "
|
||||
"an alternative query that uses a string containing more than "
|
||||
"just the character %s.",
|
||||
|
@ -680,7 +681,7 @@ class FindQuery:
|
|||
|
||||
return result
|
||||
|
||||
def execute(self, exhaust_results=True):
|
||||
async def execute(self, exhaust_results=True):
|
||||
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
|
||||
if self.sort_fields:
|
||||
args += self.resolve_redisearch_sort_fields()
|
||||
|
@ -691,7 +692,7 @@ class FindQuery:
|
|||
|
||||
# If the offset is greater than 0, we're paginating through a result set,
|
||||
# so append the new results to results already in the cache.
|
||||
raw_result = self.model.db().execute_command(*args)
|
||||
raw_result = await self.model.db().execute_command(*args)
|
||||
count = raw_result[0]
|
||||
results = self.model.from_redis(raw_result)
|
||||
self._model_cache += results
|
||||
|
@ -710,31 +711,31 @@ class FindQuery:
|
|||
# Make a query for each pass of the loop, with a new offset equal to the
|
||||
# current offset plus `page_size`, until we stop getting results back.
|
||||
query = query.copy(offset=query.offset + query.page_size)
|
||||
_results = query.execute(exhaust_results=False)
|
||||
_results = await query.execute(exhaust_results=False)
|
||||
if not _results:
|
||||
break
|
||||
self._model_cache += _results
|
||||
return self._model_cache
|
||||
|
||||
def first(self):
|
||||
async def first(self):
|
||||
query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
|
||||
results = query.execute()
|
||||
results = await query.execute()
|
||||
if not results:
|
||||
raise NotFoundError()
|
||||
return results[0]
|
||||
|
||||
def all(self, batch_size=10):
|
||||
async def all(self, batch_size=10):
|
||||
if batch_size != self.page_size:
|
||||
query = self.copy(page_size=batch_size, limit=batch_size)
|
||||
return query.execute()
|
||||
return self.execute()
|
||||
return await query.execute()
|
||||
return await self.execute()
|
||||
|
||||
def sort_by(self, *fields: str):
|
||||
if not fields:
|
||||
return self
|
||||
return self.copy(sort_fields=list(fields))
|
||||
|
||||
def update(self, use_transaction=True, **field_values):
|
||||
async def update(self, use_transaction=True, **field_values):
|
||||
"""
|
||||
Update models that match this query to the given field-value pairs.
|
||||
|
||||
|
@ -743,31 +744,32 @@ class FindQuery:
|
|||
given fields.
|
||||
"""
|
||||
validate_model_fields(self.model, field_values)
|
||||
pipeline = self.model.db().pipeline() if use_transaction else None
|
||||
pipeline = await self.model.db().pipeline() if use_transaction else None
|
||||
|
||||
for model in self.all():
|
||||
# TODO: async for here?
|
||||
for model in await self.all():
|
||||
for field, value in field_values.items():
|
||||
setattr(model, field, value)
|
||||
# TODO: In the non-transaction case, can we do more to detect
|
||||
# failure responses from Redis?
|
||||
model.save(pipeline=pipeline)
|
||||
await model.save(pipeline=pipeline)
|
||||
|
||||
if pipeline:
|
||||
# TODO: Response type?
|
||||
# TODO: Better error detection for transactions.
|
||||
pipeline.execute()
|
||||
|
||||
def delete(self):
|
||||
async def delete(self):
|
||||
"""Delete all matching records in this query."""
|
||||
# TODO: Better response type, error detection
|
||||
return self.model.db().delete(*[m.key() for m in self.all()])
|
||||
return await self.model.db().delete(*[m.key() for m in await self.all()])
|
||||
|
||||
def __iter__(self):
|
||||
async def __aiter__(self):
|
||||
if self._model_cache:
|
||||
for m in self._model_cache:
|
||||
yield m
|
||||
else:
|
||||
for m in self.execute():
|
||||
for m in await self.execute():
|
||||
yield m
|
||||
|
||||
def __getitem__(self, item: int):
|
||||
|
@ -784,12 +786,39 @@ class FindQuery:
|
|||
that result, then we should clone the current query and
|
||||
give it a new offset and limit: offset=n, limit=1.
|
||||
"""
|
||||
if ASYNC_MODE:
|
||||
raise QuerySyntaxError("Cannot use [] notation with async code. "
|
||||
"Use FindQuery.get_item() instead.")
|
||||
if self._model_cache and len(self._model_cache) >= item:
|
||||
return self._model_cache[item]
|
||||
|
||||
query = self.copy(offset=item, limit=1)
|
||||
|
||||
return query.execute()[0]
|
||||
return query.execute()[0] # noqa
|
||||
|
||||
async def get_item(self, item: int):
|
||||
"""
|
||||
Given this code:
|
||||
await Model.find().get_item(1000)
|
||||
|
||||
We should return only the 1000th result.
|
||||
|
||||
1. If the result is loaded in the query cache for this query,
|
||||
we can return it directly from the cache.
|
||||
|
||||
2. If the query cache does not have enough elements to return
|
||||
that result, then we should clone the current query and
|
||||
give it a new offset and limit: offset=n, limit=1.
|
||||
|
||||
NOTE: This method is included specifically for async users, who
|
||||
cannot use the notation Model.find()[1000].
|
||||
"""
|
||||
if self._model_cache and len(self._model_cache) >= item:
|
||||
return self._model_cache[item]
|
||||
|
||||
query = self.copy(offset=item, limit=1)
|
||||
result = await query.execute()
|
||||
return result[0]
|
||||
|
||||
|
||||
class PrimaryKeyCreator(Protocol):
|
||||
|
@ -913,7 +942,7 @@ class MetaProtocol(Protocol):
|
|||
global_key_prefix: str
|
||||
model_key_prefix: str
|
||||
primary_key_pattern: str
|
||||
database: redis.Redis
|
||||
database: aioredis.Redis
|
||||
primary_key: PrimaryKey
|
||||
primary_key_creator_cls: Type[PrimaryKeyCreator]
|
||||
index_name: str
|
||||
|
@ -932,7 +961,7 @@ class DefaultMeta:
|
|||
global_key_prefix: Optional[str] = None
|
||||
model_key_prefix: Optional[str] = None
|
||||
primary_key_pattern: Optional[str] = None
|
||||
database: Optional[redis.Redis] = None
|
||||
database: Optional[Union[redis.Redis, aioredis.Redis]] = None
|
||||
primary_key: Optional[PrimaryKey] = None
|
||||
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
|
||||
index_name: Optional[str] = None
|
||||
|
@ -1049,14 +1078,18 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
pk = getattr(self, self._meta.primary_key.field.name)
|
||||
return self.make_primary_key(pk)
|
||||
|
||||
def delete(self):
|
||||
return self.db().delete(self.key())
|
||||
async def delete(self):
|
||||
return await self.db().delete(self.key())
|
||||
|
||||
def update(self, **field_values):
|
||||
@classmethod
|
||||
async def get(cls, pk: Any) -> 'RedisModel':
|
||||
raise NotImplementedError
|
||||
|
||||
async def update(self, **field_values):
|
||||
"""Update this model instance with the specified key-value pairs."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
|
||||
async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
|
||||
raise NotImplementedError
|
||||
|
||||
@validator("pk", always=True)
|
||||
|
@ -1158,9 +1191,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
return d
|
||||
|
||||
@classmethod
|
||||
def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]:
|
||||
async def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]:
|
||||
# TODO: Add transaction support
|
||||
return [model.save() for model in models]
|
||||
return [await model.save() for model in models]
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
|
@ -1189,17 +1222,18 @@ class HashModel(RedisModel, abc.ABC):
|
|||
f" or mapping fields. Field: {name}"
|
||||
)
|
||||
|
||||
def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
|
||||
async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
|
||||
if pipeline is None:
|
||||
db = self.db()
|
||||
else:
|
||||
db = pipeline
|
||||
document = jsonable_encoder(self.dict())
|
||||
db.hset(self.key(), mapping=document)
|
||||
# TODO: Wrap any Redis response errors in a custom exception?
|
||||
await db.hset(self.key(), mapping=document)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get(cls, pk: Any) -> "HashModel":
|
||||
async def get(cls, pk: Any) -> "HashModel":
|
||||
document = cls.db().hgetall(cls.make_primary_key(pk))
|
||||
if not document:
|
||||
raise NotFoundError
|
||||
|
@ -1311,23 +1345,24 @@ class JsonModel(RedisModel, abc.ABC):
|
|||
# Generate the RediSearch schema once to validate fields.
|
||||
cls.redisearch_schema()
|
||||
|
||||
def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
|
||||
async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
|
||||
if pipeline is None:
|
||||
db = self.db()
|
||||
else:
|
||||
db = pipeline
|
||||
db.execute_command("JSON.SET", self.key(), ".", self.json())
|
||||
# TODO: Wrap response errors in a custom exception?
|
||||
await db.execute_command("JSON.SET", self.key(), ".", self.json())
|
||||
return self
|
||||
|
||||
def update(self, **field_values):
|
||||
async def update(self, **field_values):
|
||||
validate_model_fields(self.__class__, field_values)
|
||||
for field, value in field_values.items():
|
||||
setattr(self, field, value)
|
||||
self.save()
|
||||
await self.save()
|
||||
|
||||
@classmethod
|
||||
def get(cls, pk: Any) -> "JsonModel":
|
||||
document = cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
|
||||
async def get(cls, pk: Any) -> "JsonModel":
|
||||
document = await cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
|
||||
if not document:
|
||||
raise NotFoundError
|
||||
return cls.parse_raw(document)
|
|
@ -1,17 +1,17 @@
|
|||
import abc
|
||||
from typing import Optional
|
||||
|
||||
from redis_developer.model.model import HashModel, JsonModel
|
||||
from redis_om.model.model import HashModel, JsonModel
|
||||
|
||||
|
||||
class BaseJsonModel(JsonModel, abc.ABC):
|
||||
class Meta:
|
||||
global_key_prefix = "redis-developer"
|
||||
global_key_prefix = "redis-om"
|
||||
|
||||
|
||||
class BaseHashModel(HashModel, abc.ABC):
|
||||
class Meta:
|
||||
global_key_prefix = "redis-developer"
|
||||
global_key_prefix = "redis-om"
|
||||
|
||||
|
||||
# class AddressJson(BaseJsonModel):
|
|
@ -1,7 +1,7 @@
|
|||
from collections import Sequence
|
||||
from typing import Any, Dict, List, Mapping, Union
|
||||
|
||||
from redis_developer.model.model import Expression
|
||||
from redis_om.model.model import Expression
|
||||
|
||||
|
||||
class LogicalOperatorForListOfExpressions(Expression):
|
40
redis_om/unasync_util.py
Normal file
40
redis_om/unasync_util.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
"""Set of utility functions for unasync that transform into sync counterparts cleanly"""
|
||||
|
||||
import inspect
|
||||
|
||||
_original_next = next
|
||||
|
||||
|
||||
def is_async_mode():
|
||||
"""Tests if we're in the async part of the code or not"""
|
||||
|
||||
async def f():
|
||||
"""Unasync transforms async functions in sync functions"""
|
||||
return None
|
||||
|
||||
obj = f()
|
||||
if obj is None:
|
||||
return False
|
||||
else:
|
||||
obj.close() # prevent unawaited coroutine warning
|
||||
return True
|
||||
|
||||
|
||||
ASYNC_MODE = is_async_mode()
|
||||
|
||||
|
||||
async def anext(x):
|
||||
return await x.__anext__()
|
||||
|
||||
|
||||
async def await_if_coro(x):
|
||||
if inspect.iscoroutine(x):
|
||||
return await x
|
||||
return x
|
||||
|
||||
|
||||
next = _original_next
|
||||
|
||||
|
||||
def return_non_coro(x):
|
||||
return x
|
|
@ -1,19 +1,18 @@
|
|||
import random
|
||||
|
||||
import pytest
|
||||
from redis import Redis
|
||||
|
||||
from redis_developer.connections import get_redis_connection
|
||||
from redis_om.connections import get_redis_connection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis():
|
||||
def redis(event_loop):
|
||||
yield get_redis_connection()
|
||||
|
||||
|
||||
def _delete_test_keys(prefix: str, conn: Redis):
|
||||
async def _delete_test_keys(prefix: str, conn):
|
||||
keys = []
|
||||
for key in conn.scan_iter(f"{prefix}:*"):
|
||||
async for key in conn.scan_iter(f"{prefix}:*"):
|
||||
keys.append(key)
|
||||
if keys:
|
||||
conn.delete(*keys)
|
||||
|
@ -21,11 +20,10 @@ def _delete_test_keys(prefix: str, conn: Redis):
|
|||
|
||||
@pytest.fixture
|
||||
def key_prefix(redis):
|
||||
key_prefix = f"redis-developer:{random.random()}"
|
||||
key_prefix = f"redis-om:{random.random()}"
|
||||
yield key_prefix
|
||||
_delete_test_keys(key_prefix, redis)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def delete_test_keys(redis, request, key_prefix):
|
||||
_delete_test_keys(key_prefix, redis)
|
||||
async def delete_test_keys(redis, request, key_prefix):
|
||||
await _delete_test_keys(key_prefix, redis)
|
||||
|
|
|
@ -8,9 +8,9 @@ from unittest import mock
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from redis_developer.model import Field, HashModel
|
||||
from redis_developer.model.migrations.migrator import Migrator
|
||||
from redis_developer.model.model import (
|
||||
from redis_om.model import Field, HashModel
|
||||
from redis_om.model.migrations.migrator import Migrator
|
||||
from redis_om.model.model import (
|
||||
NotFoundError,
|
||||
QueryNotSupportedError,
|
||||
RedisModelError,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import datetime
|
||||
import decimal
|
||||
from collections import namedtuple
|
||||
|
@ -8,9 +9,9 @@ from unittest import mock
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from redis_developer.model import EmbeddedJsonModel, Field, JsonModel
|
||||
from redis_developer.model.migrations.migrator import Migrator
|
||||
from redis_developer.model.model import (
|
||||
from redis_om.model import EmbeddedJsonModel, Field, JsonModel
|
||||
from redis_om.model.migrations.migrator import Migrator
|
||||
from redis_om.model.model import (
|
||||
NotFoundError,
|
||||
QueryNotSupportedError,
|
||||
RedisModelError,
|
||||
|
@ -21,7 +22,7 @@ today = datetime.date.today()
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def m(key_prefix):
|
||||
async def m(key_prefix, redis):
|
||||
class BaseJsonModel(JsonModel, abc.ABC):
|
||||
class Meta:
|
||||
global_key_prefix = key_prefix
|
||||
|
@ -64,7 +65,7 @@ def m(key_prefix):
|
|||
# Creates an embedded list of models.
|
||||
orders: Optional[List[Order]]
|
||||
|
||||
Migrator().run()
|
||||
await Migrator(redis).run()
|
||||
|
||||
return namedtuple(
|
||||
"Models", ["BaseJsonModel", "Note", "Address", "Item", "Order", "Member"]
|
||||
|
@ -83,7 +84,7 @@ def address(m):
|
|||
|
||||
|
||||
@pytest.fixture()
|
||||
def members(address, m):
|
||||
async def members(address, m):
|
||||
member1 = m.Member(
|
||||
first_name="Andrew",
|
||||
last_name="Brookins",
|
||||
|
@ -111,14 +112,15 @@ def members(address, m):
|
|||
address=address,
|
||||
)
|
||||
|
||||
member1.save()
|
||||
member2.save()
|
||||
member3.save()
|
||||
await member1.save()
|
||||
await member2.save()
|
||||
await member3.save()
|
||||
|
||||
yield member1, member2, member3
|
||||
|
||||
|
||||
def test_validates_required_fields(address, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_validates_required_fields(address, m):
|
||||
# Raises ValidationError address is required
|
||||
with pytest.raises(ValidationError):
|
||||
m.Member(
|
||||
|
@ -129,7 +131,8 @@ def test_validates_required_fields(address, m):
|
|||
)
|
||||
|
||||
|
||||
def test_validates_field(address, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_validates_field(address, m):
|
||||
# Raises ValidationError: join_date is not a date
|
||||
with pytest.raises(ValidationError):
|
||||
m.Member(
|
||||
|
@ -141,7 +144,8 @@ def test_validates_field(address, m):
|
|||
|
||||
|
||||
# Passes validation
|
||||
def test_validation_passes(address, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_passes(address, m):
|
||||
member = m.Member(
|
||||
first_name="Andrew",
|
||||
last_name="Brookins",
|
||||
|
@ -153,7 +157,10 @@ def test_validation_passes(address, m):
|
|||
assert member.first_name == "Andrew"
|
||||
|
||||
|
||||
def test_saves_model_and_creates_pk(address, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_saves_model_and_creates_pk(address, m, redis):
|
||||
await Migrator(redis).run()
|
||||
|
||||
member = m.Member(
|
||||
first_name="Andrew",
|
||||
last_name="Brookins",
|
||||
|
@ -163,15 +170,16 @@ def test_saves_model_and_creates_pk(address, m):
|
|||
address=address,
|
||||
)
|
||||
# Save a model instance to Redis
|
||||
member.save()
|
||||
await member.save()
|
||||
|
||||
member2 = m.Member.get(member.pk)
|
||||
member2 = await m.Member.get(member.pk)
|
||||
assert member2 == member
|
||||
assert member2.address == address
|
||||
|
||||
|
||||
@pytest.mark.skip("Not implemented yet")
|
||||
def test_saves_many(address, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_saves_many(address, m):
|
||||
members = [
|
||||
m.Member(
|
||||
first_name="Andrew",
|
||||
|
@ -193,9 +201,16 @@ def test_saves_many(address, m):
|
|||
m.Member.add(members)
|
||||
|
||||
|
||||
async def save(members):
|
||||
for m in members:
|
||||
await m.save()
|
||||
return members
|
||||
|
||||
|
||||
@pytest.mark.skip("Not ready yet")
|
||||
def test_updates_a_model(members, m):
|
||||
member1, member2, member3 = members
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_a_model(members, m):
|
||||
member1, member2, member3 = await save(members)
|
||||
|
||||
# Or, with an implicit save:
|
||||
member1.update(last_name="Smith")
|
||||
|
@ -213,18 +228,20 @@ def test_updates_a_model(members, m):
|
|||
)
|
||||
|
||||
|
||||
def test_paginate_query(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_paginate_query(members, m):
|
||||
member1, member2, member3 = members
|
||||
actual = m.Member.find().sort_by("age").all(batch_size=1)
|
||||
actual = await m.Member.find().sort_by("age").all(batch_size=1)
|
||||
assert actual == [member2, member1, member3]
|
||||
|
||||
|
||||
def test_access_result_by_index_cached(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_access_result_by_index_cached(members, m):
|
||||
member1, member2, member3 = members
|
||||
query = m.Member.find().sort_by("age")
|
||||
# Load the cache, throw away the result.
|
||||
assert query._model_cache == []
|
||||
query.execute()
|
||||
await query.execute()
|
||||
assert query._model_cache == [member2, member1, member3]
|
||||
|
||||
# Access an item that should be in the cache.
|
||||
|
@ -233,21 +250,23 @@ def test_access_result_by_index_cached(members, m):
|
|||
assert not mock_db.called
|
||||
|
||||
|
||||
def test_access_result_by_index_not_cached(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_access_result_by_index_not_cached(members, m):
|
||||
member1, member2, member3 = members
|
||||
query = m.Member.find().sort_by("age")
|
||||
|
||||
# Assert that we don't have any models in the cache yet -- we
|
||||
# haven't made any requests of Redis.
|
||||
assert query._model_cache == []
|
||||
assert query[0] == member2
|
||||
assert query[1] == member1
|
||||
assert query[2] == member3
|
||||
assert query.get_item(0) == member2
|
||||
assert query.get_item(1) == member1
|
||||
assert query.get_item(2) == member3
|
||||
|
||||
|
||||
def test_in_query(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_in_query(members, m):
|
||||
member1, member2, member3 = members
|
||||
actual = (
|
||||
actual = await (
|
||||
m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
|
||||
.sort_by("age")
|
||||
.all()
|
||||
|
@ -256,12 +275,13 @@ def test_in_query(members, m):
|
|||
|
||||
|
||||
@pytest.mark.skip("Not implemented yet")
|
||||
def test_update_query(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_query(members, m):
|
||||
member1, member2, member3 = members
|
||||
m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update(
|
||||
await m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update(
|
||||
first_name="Bobby"
|
||||
)
|
||||
actual = (
|
||||
actual = await (
|
||||
m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
|
||||
.sort_by("age")
|
||||
.all()
|
||||
|
@ -270,24 +290,25 @@ def test_update_query(members, m):
|
|||
assert all([m.name == "Bobby" for m in actual])
|
||||
|
||||
|
||||
def test_exact_match_queries(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_match_queries(members, m):
|
||||
member1, member2, member3 = members
|
||||
|
||||
actual = m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
|
||||
actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
|
||||
assert actual == [member2, member1]
|
||||
|
||||
actual = m.Member.find(
|
||||
actual = await m.Member.find(
|
||||
(m.Member.last_name == "Brookins") & ~(m.Member.first_name == "Andrew")
|
||||
).all()
|
||||
assert actual == [member2]
|
||||
|
||||
actual = m.Member.find(~(m.Member.last_name == "Brookins")).all()
|
||||
actual = await m.Member.find(~(m.Member.last_name == "Brookins")).all()
|
||||
assert actual == [member3]
|
||||
|
||||
actual = m.Member.find(m.Member.last_name != "Brookins").all()
|
||||
actual = await m.Member.find(m.Member.last_name != "Brookins").all()
|
||||
assert actual == [member3]
|
||||
|
||||
actual = (
|
||||
actual = await (
|
||||
m.Member.find(
|
||||
(m.Member.last_name == "Brookins") & (m.Member.first_name == "Andrew")
|
||||
| (m.Member.first_name == "Kim")
|
||||
|
@ -297,19 +318,20 @@ def test_exact_match_queries(members, m):
|
|||
)
|
||||
assert actual == [member2, member1]
|
||||
|
||||
actual = m.Member.find(
|
||||
actual = await m.Member.find(
|
||||
m.Member.first_name == "Kim", m.Member.last_name == "Brookins"
|
||||
).all()
|
||||
assert actual == [member2]
|
||||
|
||||
actual = m.Member.find(m.Member.address.city == "Portland").sort_by("age").all()
|
||||
actual = await m.Member.find(m.Member.address.city == "Portland").sort_by("age").all()
|
||||
assert actual == [member2, member1, member3]
|
||||
|
||||
|
||||
def test_recursive_query_expression_resolution(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_recursive_query_expression_resolution(members, m):
|
||||
member1, member2, member3 = members
|
||||
|
||||
actual = (
|
||||
actual = await (
|
||||
m.Member.find(
|
||||
(m.Member.last_name == "Brookins")
|
||||
| (m.Member.age == 100) & (m.Member.last_name == "Smith")
|
||||
|
@ -320,13 +342,14 @@ def test_recursive_query_expression_resolution(members, m):
|
|||
assert actual == [member2, member1, member3]
|
||||
|
||||
|
||||
def test_recursive_query_field_resolution(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_recursive_query_field_resolution(members, m):
|
||||
member1, _, _ = members
|
||||
member1.address.note = m.Note(
|
||||
description="Weird house", created_on=datetime.datetime.now()
|
||||
)
|
||||
member1.save()
|
||||
actual = m.Member.find(m.Member.address.note.description == "Weird house").all()
|
||||
await member1.save()
|
||||
actual = await m.Member.find(m.Member.address.note.description == "Weird house").all()
|
||||
assert actual == [member1]
|
||||
|
||||
member1.orders = [
|
||||
|
@ -336,29 +359,31 @@ def test_recursive_query_field_resolution(members, m):
|
|||
created_on=datetime.datetime.now(),
|
||||
)
|
||||
]
|
||||
member1.save()
|
||||
actual = m.Member.find(m.Member.orders.items.name == "Ball").all()
|
||||
await member1.save()
|
||||
actual = await m.Member.find(m.Member.orders.items.name == "Ball").all()
|
||||
assert actual == [member1]
|
||||
assert actual[0].orders[0].items[0].name == "Ball"
|
||||
|
||||
|
||||
def test_full_text_search(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_text_search(members, m):
|
||||
member1, member2, _ = members
|
||||
member1.update(bio="Hates sunsets, likes beaches")
|
||||
member2.update(bio="Hates beaches, likes forests")
|
||||
await member1.update(bio="Hates sunsets, likes beaches")
|
||||
await member2.update(bio="Hates beaches, likes forests")
|
||||
|
||||
actual = m.Member.find(m.Member.bio % "beaches").sort_by("age").all()
|
||||
actual = await m.Member.find(m.Member.bio % "beaches").sort_by("age").all()
|
||||
assert actual == [member2, member1]
|
||||
|
||||
actual = m.Member.find(m.Member.bio % "forests").all()
|
||||
actual = await m.Member.find(m.Member.bio % "forests").all()
|
||||
assert actual == [member2]
|
||||
|
||||
|
||||
def test_tag_queries_boolean_logic(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tag_queries_boolean_logic(members, m):
|
||||
member1, member2, member3 = members
|
||||
|
||||
actual = (
|
||||
m.Member.find(
|
||||
await m.Member.find(
|
||||
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
|
||||
| (m.Member.last_name == "Smith")
|
||||
)
|
||||
|
@ -368,7 +393,8 @@ def test_tag_queries_boolean_logic(members, m):
|
|||
assert actual == [member1, member3]
|
||||
|
||||
|
||||
def test_tag_queries_punctuation(address, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tag_queries_punctuation(address, m):
|
||||
member1 = m.Member(
|
||||
first_name="Andrew, the Michael",
|
||||
last_name="St. Brookins-on-Pier",
|
||||
|
@ -377,7 +403,7 @@ def test_tag_queries_punctuation(address, m):
|
|||
join_date=today,
|
||||
address=address,
|
||||
)
|
||||
member1.save()
|
||||
await member1.save()
|
||||
|
||||
member2 = m.Member(
|
||||
first_name="Bob",
|
||||
|
@ -387,24 +413,25 @@ def test_tag_queries_punctuation(address, m):
|
|||
join_date=today,
|
||||
address=address,
|
||||
)
|
||||
member2.save()
|
||||
await member2.save()
|
||||
|
||||
assert (
|
||||
m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1
|
||||
await m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1
|
||||
)
|
||||
assert (
|
||||
m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1
|
||||
await m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1
|
||||
)
|
||||
|
||||
# Notice that when we index and query multiple values that use the internal
|
||||
# TAG separator for single-value exact-match fields, like an indexed string,
|
||||
# the queries will succeed. We apply a workaround that queries for the union
|
||||
# of the two values separated by the tag separator.
|
||||
assert m.Member.find(m.Member.email == "a|b@example.com").all() == [member1]
|
||||
assert m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2]
|
||||
assert await m.Member.find(m.Member.email == "a|b@example.com").all() == [member1]
|
||||
assert await m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2]
|
||||
|
||||
|
||||
def test_tag_queries_negation(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tag_queries_negation(members, m):
|
||||
member1, member2, member3 = members
|
||||
|
||||
"""
|
||||
|
@ -414,7 +441,7 @@ def test_tag_queries_negation(members, m):
|
|||
|
||||
"""
|
||||
query = m.Member.find(~(m.Member.first_name == "Andrew"))
|
||||
assert query.all() == [member2]
|
||||
assert await query.all() == [member2]
|
||||
|
||||
"""
|
||||
┌first_name
|
||||
|
@ -429,7 +456,7 @@ def test_tag_queries_negation(members, m):
|
|||
query = m.Member.find(
|
||||
~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
|
||||
)
|
||||
assert query.all() == [member2]
|
||||
assert await query.all() == [member2]
|
||||
|
||||
"""
|
||||
┌first_name
|
||||
|
@ -448,7 +475,7 @@ def test_tag_queries_negation(members, m):
|
|||
~(m.Member.first_name == "Andrew")
|
||||
& ((m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith"))
|
||||
)
|
||||
assert query.all() == [member2]
|
||||
assert await query.all() == [member2]
|
||||
|
||||
"""
|
||||
┌first_name
|
||||
|
@ -467,67 +494,71 @@ def test_tag_queries_negation(members, m):
|
|||
~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
|
||||
| (m.Member.last_name == "Smith")
|
||||
)
|
||||
assert query.sort_by("age").all() == [member2, member3]
|
||||
assert await query.sort_by("age").all() == [member2, member3]
|
||||
|
||||
actual = m.Member.find(
|
||||
actual = await m.Member.find(
|
||||
(m.Member.first_name == "Andrew") & ~(m.Member.last_name == "Brookins")
|
||||
).all()
|
||||
assert actual == [member3]
|
||||
|
||||
|
||||
def test_numeric_queries(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_numeric_queries(members, m):
|
||||
member1, member2, member3 = members
|
||||
|
||||
actual = m.Member.find(m.Member.age == 34).all()
|
||||
actual = await m.Member.find(m.Member.age == 34).all()
|
||||
assert actual == [member2]
|
||||
|
||||
actual = m.Member.find(m.Member.age > 34).all()
|
||||
actual = await m.Member.find(m.Member.age > 34).all()
|
||||
assert actual == [member1, member3]
|
||||
|
||||
actual = m.Member.find(m.Member.age < 35).all()
|
||||
actual = await m.Member.find(m.Member.age < 35).all()
|
||||
assert actual == [member2]
|
||||
|
||||
actual = m.Member.find(m.Member.age <= 34).all()
|
||||
actual = await m.Member.find(m.Member.age <= 34).all()
|
||||
assert actual == [member2]
|
||||
|
||||
actual = m.Member.find(m.Member.age >= 100).all()
|
||||
actual = await m.Member.find(m.Member.age >= 100).all()
|
||||
assert actual == [member3]
|
||||
|
||||
actual = m.Member.find(~(m.Member.age == 100)).sort_by("age").all()
|
||||
actual = await m.Member.find(~(m.Member.age == 100)).sort_by("age").all()
|
||||
assert actual == [member2, member1]
|
||||
|
||||
actual = m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all()
|
||||
actual = await m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all()
|
||||
assert actual == [member2, member1]
|
||||
|
||||
actual = m.Member.find(m.Member.age != 34).sort_by("age").all()
|
||||
actual = await m.Member.find(m.Member.age != 34).sort_by("age").all()
|
||||
assert actual == [member1, member3]
|
||||
|
||||
|
||||
def test_sorting(members, m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_sorting(members, m):
|
||||
member1, member2, member3 = members
|
||||
|
||||
actual = m.Member.find(m.Member.age > 34).sort_by("age").all()
|
||||
actual = await m.Member.find(m.Member.age > 34).sort_by("age").all()
|
||||
assert actual == [member1, member3]
|
||||
|
||||
actual = m.Member.find(m.Member.age > 34).sort_by("-age").all()
|
||||
actual = await m.Member.find(m.Member.age > 34).sort_by("-age").all()
|
||||
assert actual == [member3, member1]
|
||||
|
||||
with pytest.raises(QueryNotSupportedError):
|
||||
# This field does not exist.
|
||||
m.Member.find().sort_by("not-a-real-field").all()
|
||||
await m.Member.find().sort_by("not-a-real-field").all()
|
||||
|
||||
with pytest.raises(QueryNotSupportedError):
|
||||
# This field is not sortable.
|
||||
m.Member.find().sort_by("join_date").all()
|
||||
await m.Member.find().sort_by("join_date").all()
|
||||
|
||||
|
||||
def test_not_found(m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(m):
|
||||
with pytest.raises(NotFoundError):
|
||||
# This ID does not exist.
|
||||
m.Member.get(1000)
|
||||
await m.Member.get(1000)
|
||||
|
||||
|
||||
def test_list_field_limitations(m):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_field_limitations(m, redis):
|
||||
with pytest.raises(RedisModelError):
|
||||
|
||||
class SortableTarotWitch(m.BaseJsonModel):
|
||||
|
@ -571,15 +602,16 @@ def test_list_field_limitations(m):
|
|||
# We need to import and run this manually because we defined
|
||||
# our model classes within a function that runs after the test
|
||||
# suite's migrator has already looked for migrations to run.
|
||||
Migrator().run()
|
||||
await Migrator(redis).run()
|
||||
|
||||
witch = TarotWitch(tarot_cards=["death"])
|
||||
witch.save()
|
||||
actual = TarotWitch.find(TarotWitch.tarot_cards << "death").all()
|
||||
await witch.save()
|
||||
actual = await TarotWitch.find(TarotWitch.tarot_cards << "death").all()
|
||||
assert actual == [witch]
|
||||
|
||||
|
||||
def test_schema(m, key_prefix):
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema(m, key_prefix):
|
||||
assert (
|
||||
m.Member.redisearch_schema()
|
||||
== f"ON JSON PREFIX 1 {key_prefix}:tests.test_json_model.Member: SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR | $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |"
|
||||
|
|
Loading…
Reference in a new issue