WIP on async - test failure due to closed event loop
Commit b2c2dd9f6f (parent 0f9f7aa868)

22 changed files with 348 additions and 190 deletions
Makefile (2 changed lines)

@@ -1,4 +1,4 @@
-NAME := redis_developer
+NAME := redis_om
 INSTALL_STAMP := .install.stamp
 POETRY := $(shell command -v poetry 2> /dev/null)
 
README.md (18 changed lines)

@@ -52,7 +52,7 @@ Check out this example:
 import datetime
 from typing import Optional
 
-from redis_developer.model import (
+from redis_om.model import (
     EmbeddedJsonModel,
     JsonModel,
     Field,

@@ -172,9 +172,9 @@ Don't want to run Redis yourself? RediSearch and RedisJSON are also available on
 
 We'd love your contributions!
 
-**Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new).
+**Bug reports** are especially helpful at this stage of the project. [You can open a bug report on GitHub](https://github.com/redis-om/redis-om-python/issues/new).
 
-You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-developer/redis-developer-python/issues/new) to get started.
+You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/redis-om/redis-om-python/issues/new) to get started.
 
 ## License
 

@@ -184,17 +184,17 @@ Redis OM is [MIT licensed][license-url].
 
 [version-svg]: https://img.shields.io/pypi/v/redis-om?style=flat-square
 [package-url]: https://pypi.org/project/redis-om/
-[ci-svg]: https://img.shields.io/github/workflow/status/redis-developer/redis-developer-python/python?style=flat-square
+[ci-svg]: https://img.shields.io/github/workflow/status/redis-om/redis-om-python/python?style=flat-square
-[ci-url]: https://github.com/redis-developer/redis-developer-python/actions/workflows/build.yml
+[ci-url]: https://github.com/redis-om/redis-om-python/actions/workflows/build.yml
 [license-image]: http://img.shields.io/badge/license-MIT-green.svg?style=flat-square
 [license-url]: LICENSE
 
 <!-- Links -->
 
-[redis-developer-website]: https://developer.redis.com
+[redis-om-website]: https://developer.redis.com
-[redis-om-js]: https://github.com/redis-developer/redis-om-js
+[redis-om-js]: https://github.com/redis-om/redis-om-js
-[redis-om-dotnet]: https://github.com/redis-developer/redis-om-dotnet
+[redis-om-dotnet]: https://github.com/redis-om/redis-om-dotnet
-[redis-om-spring]: https://github.com/redis-developer/redis-om-spring
+[redis-om-spring]: https://github.com/redis-om/redis-om-spring
 [redisearch-url]: https://oss.redis.com/redisearch/
 [redis-json-url]: https://oss.redis.com/redisjson/
 [pydantic-url]: https://github.com/samuelcolvin/pydantic
build.py (new file, 10 added lines)

@@ -0,0 +1,10 @@
+import unasync
+
+
+def build(setup_kwargs):
+    setup_kwargs.update(
+        {"cmdclass": {'build_py': unasync.cmdclass_build_py(rules=[
+            unasync.Rule("/aredis_om/", "/redis_om/"),
+            unasync.Rule("/aredis_om/tests/", "/redis_om/tests/", additional_replacements={"aredis_om": "redis_om"}),
+        ])}}
+    )
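The build hook above wires unasync into the package build: the hand-written async modules are copied into a sync tree with async/await stripped and module-path tokens rewritten according to the Rule objects. A rough before/after sketch of what that rewrite amounts to (illustrative only; the function names below are made up, and in reality unasync keeps the original name, it only removes the async keywords):

# Hand-written async source (the tree matched by the first Rule, e.g. aredis_om/):
async def save_hash_async(db, key, mapping):
    # aioredis-style call: must be awaited
    return await db.hset(key, mapping=mapping)

# What unasync emits into the sync tree (redis_om/): async/await stripped, body otherwise unchanged.
def save_hash_sync(db, key, mapping):
    # redis-py-style call: plain blocking method
    return db.hset(key, mapping=mapping)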
poetry.lock (generated, 32 changed lines)

@@ -538,6 +538,20 @@ toml = "*"
 [package.extras]
 testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
 
+[[package]]
+name = "pytest-asyncio"
+version = "0.16.0"
+description = "Pytest support for asyncio."
+category = "dev"
+optional = false
+python-versions = ">= 3.6"
+
+[package.dependencies]
+pytest = ">=5.4.0"
+
+[package.extras]
+testing = ["coverage", "hypothesis (>=5.7.1)"]
+
 [[package]]
 name = "pytest-cov"
 version = "3.0.0"

@@ -707,6 +721,14 @@ category = "main"
 optional = false
 python-versions = "*"
 
+[[package]]
+name = "unasync"
+version = "0.5.0"
+description = "The async transformation code."
+category = "dev"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
+
 [[package]]
 name = "wcwidth"
 version = "0.2.5"

@@ -726,7 +748,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.8"
-content-hash = "56b381dd9b79bd082e978019124176491c63f09dd5ce90e5f8ab642a7f79480f"
+content-hash = "d2d83b8cd3b094879e1aeb058d0036203942143f12fafa8be03fb0c79460028f"
 
 [metadata.files]
 aioredis = [

@@ -1003,6 +1025,10 @@ pytest = [
     {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
     {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
 ]
+pytest-asyncio = [
+    {file = "pytest-asyncio-0.16.0.tar.gz", hash = "sha256:7496c5977ce88c34379df64a66459fe395cd05543f0a2f837016e7144391fcfb"},
+    {file = "pytest_asyncio-0.16.0-py3-none-any.whl", hash = "sha256:5f2a21273c47b331ae6aa5b36087047b4899e40f03f18397c0e65fa5cca54e9b"},
+]
 pytest-cov = [
     {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
     {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},

@@ -1148,6 +1174,10 @@ typing-extensions = [
     {file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"},
     {file = "typing_extensions-3.10.0.2.tar.gz", hash = "sha256:49f75d16ff11f1cd258e1b988ccff82a3ca5570217d7ad8c5f48205dd99a677e"},
 ]
+unasync = [
+    {file = "unasync-0.5.0-py3-none-any.whl", hash = "sha256:8d4536dae85e87b8751dfcc776f7656fd0baf54bb022a7889440dc1b9dc3becb"},
+    {file = "unasync-0.5.0.tar.gz", hash = "sha256:b675d87cf56da68bd065d3b7a67ac71df85591978d84c53083c20d79a7e5096d"},
+]
 wcwidth = [
     {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
     {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"},
@@ -1,9 +1,10 @@
 [tool.poetry]
-name = "redis-developer"
+name = "redis-om"
 version = "0.1.0"
 description = "A high-level library containing useful Redis abstractions and tools, like an ORM and leaderboard."
 authors = ["Andrew Brookins <andrew.brookins@redislabs.com>"]
 license = "MIT"
+build = "build.py"
 
 [tool.poetry.dependencies]
 python = "^3.8"

@@ -30,10 +31,11 @@ bandit = "^1.7.0"
 coverage = "^6.0.2"
 pytest-cov = "^3.0.0"
 pytest-xdist = "^2.4.0"
+unasync = "^0.5.0"
+pytest-asyncio = "^0.16.0"
 
 [tool.poetry.scripts]
-migrate = "redis_developer.orm.cli.migrate:migrate"
+migrate = "redis_om.orm.cli.migrate:migrate"
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
@@ -1,22 +1,28 @@
 import os
+from typing import Union
 
 import dotenv
+import aioredis
 import redis
+
+from redis_om.unasync_util import ASYNC_MODE
 
 dotenv.load_dotenv()
 
 URL = os.environ.get("REDIS_OM_URL", None)
+if ASYNC_MODE:
+    client = aioredis.Redis
+else:
+    client = redis.Redis
 
 
-def get_redis_connection(**kwargs) -> redis.Redis:
+def get_redis_connection(**kwargs) -> Union[aioredis.Redis, redis.Redis]:
     # If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
     # environment variable, we'll create the Redis client from the URL.
     url = kwargs.pop("url", URL)
     if url:
-        return redis.from_url(url, **kwargs)
+        return client.from_url(url, **kwargs)
 
     # Decode from UTF-8 by default
     if "decode_responses" not in kwargs:
         kwargs["decode_responses"] = True
-    return redis.Redis(**kwargs)
+    return client(**kwargs)
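A hedged usage sketch of the new connection helper (the URL and the surrounding script are illustrative, not part of this commit): in the async tree get_redis_connection() hands back an aioredis.Redis, so commands are awaited; in the unasync-generated sync tree the same call returns a blocking redis.Redis and the awaits disappear.

import asyncio

from redis_om.connections import get_redis_connection


async def main():
    # REDIS_OM_URL (or the url kwarg) still wins; the URL below is only an example value.
    db = get_redis_connection(url="redis://localhost:6379")
    print(await db.ping())  # in the generated sync tree this is simply db.ping()


asyncio.run(main())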
@@ -1,10 +1,10 @@
 import click
 
-from redis_developer.model.migrations.migrator import Migrator
+from redis_om.model.migrations.migrator import Migrator
 
 
 @click.command()
-@click.option("--module", default="redis_developer")
+@click.option("--module", default="redis_om")
 def migrate(module):
     migrator = Migrator(module)
 
@@ -2,15 +2,14 @@ import hashlib
 import logging
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union
 
-from redis import ResponseError
+from redis import ResponseError, Redis
+from aioredis import ResponseError as AResponseError, Redis as ARedis
 
-from redis_developer.connections import get_redis_connection
-from redis_developer.model.model import model_registry
+from redis_om.model.model import model_registry
 
 
-redis = get_redis_connection()
 log = logging.getLogger(__name__)
 
 

@@ -43,12 +42,12 @@ def schema_hash_key(index_name):
     return f"{index_name}:hash"
 
 
-def create_index(index_name, schema, current_hash):
+async def create_index(redis: Union[Redis, ARedis], index_name, schema, current_hash):
     try:
-        redis.execute_command(f"ft.info {index_name}")
+        await redis.execute_command(f"ft.info {index_name}")
-    except ResponseError:
+    except (ResponseError, AResponseError):
-        redis.execute_command(f"ft.create {index_name} {schema}")
+        await redis.execute_command(f"ft.create {index_name} {schema}")
-        redis.set(schema_hash_key(index_name), current_hash)
+        await redis.set(schema_hash_key(index_name), current_hash)
     else:
         log.info("Index already exists, skipping. Index hash: %s", index_name)
 

@@ -65,34 +64,38 @@ class IndexMigration:
     schema: str
     hash: str
     action: MigrationAction
+    redis: Union[Redis, ARedis]
     previous_hash: Optional[str] = None
 
-    def run(self):
+    async def run(self):
         if self.action is MigrationAction.CREATE:
-            self.create()
+            await self.create()
         elif self.action is MigrationAction.DROP:
-            self.drop()
+            await self.drop()
 
-    def create(self):
+    async def create(self):
         try:
-            return create_index(self.index_name, self.schema, self.hash)
+            await create_index(self.redis, self.index_name, self.schema, self.hash)
         except ResponseError:
             log.info("Index already exists: %s", self.index_name)
 
-    def drop(self):
+    async def drop(self):
         try:
-            redis.execute_command(f"FT.DROPINDEX {self.index_name}")
+            await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
         except ResponseError:
             log.info("Index does not exist: %s", self.index_name)
 
 
 class Migrator:
-    def __init__(self, module=None):
+    def __init__(self, redis: Union[Redis, ARedis], module=None):
-        # Try to load any modules found under the given path or module name.
+        self.module = module
-        if module:
-            import_submodules(module)
-
         self.migrations = []
+        self.redis = redis
+
+    async def run(self):
+        # Try to load any modules found under the given path or module name.
+        if self.module:
+            import_submodules(self.module)
+
         for name, cls in model_registry.items():
             hash_key = schema_hash_key(cls.Meta.index_name)

@@ -104,8 +107,8 @@ class Migrator:
             current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest()  # nosec
 
             try:
-                redis.execute_command("ft.info", cls.Meta.index_name)
+                await self.redis.execute_command("ft.info", cls.Meta.index_name)
-            except ResponseError:
+            except (ResponseError, AResponseError):
                 self.migrations.append(
                     IndexMigration(
                         name,

@@ -113,11 +116,12 @@ class Migrator:
                         schema,
                         current_hash,
                         MigrationAction.CREATE,
+                        self.redis
                     )
                 )
                 continue
 
-            stored_hash = redis.get(hash_key)
+            stored_hash = self.redis.get(hash_key)
             schema_out_of_date = current_hash != stored_hash
 
             if schema_out_of_date:

@@ -129,7 +133,8 @@ class Migrator:
                         schema,
                         current_hash,
                         MigrationAction.DROP,
-                        stored_hash,
+                        self.redis,
+                        stored_hash
                     )
                 )
                 self.migrations.append(

@@ -139,12 +144,12 @@ class Migrator:
                         schema,
                         current_hash,
                         MigrationAction.CREATE,
-                        stored_hash,
+                        self.redis,
+                        stored_hash
                     )
                 )
 
-    def run(self):
         # TODO: Migration history
         # TODO: Dry run with output
         for migration in self.migrations:
-            migration.run()
+            await migration.run()
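With the module-level connection gone, driving the migrator now looks roughly like this (a sketch, assuming the application's models live in an importable module; the module name "myapp.models" is made up):

import asyncio

from redis_om.connections import get_redis_connection
from redis_om.model.migrations.migrator import Migrator


async def migrate():
    redis = get_redis_connection()
    # The connection is passed in explicitly now; module discovery and the
    # index create/drop work all happen inside the single async run() call.
    migrator = Migrator(redis, module="myapp.models")
    await migrator.run()


asyncio.run(migrate())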
@@ -4,7 +4,7 @@ import decimal
 import json
 import logging
 import operator
-from copy import copy, deepcopy
+from copy import copy
 from enum import Enum
 from functools import reduce
 from typing import (

@@ -27,6 +27,7 @@ from typing import (
     no_type_check,
 )
 
+import aioredis
 import redis
 from pydantic import BaseModel, validator
 from pydantic.fields import FieldInfo as PydanticFieldInfo

@@ -37,11 +38,11 @@ from pydantic.utils import Representation
 from redis.client import Pipeline
 from ulid import ULID
 
-from ..connections import get_redis_connection
+from redis_om.connections import get_redis_connection
 from .encoders import jsonable_encoder
 from .render_tree import render_tree
 from .token_escaper import TokenEscaper
+from ..unasync_util import ASYNC_MODE
 
 model_registry = {}
 _T = TypeVar("_T")

@@ -521,7 +522,7 @@ class FindQuery:
                     # this is not going to work.
                     log.warning(
                         "Your query against the field %s is for a single character, %s, "
-                        "that is used internally by redis-developer-python. We must ignore "
+                        "that is used internally by redis-om-python. We must ignore "
                         "this portion of the query. Please review your query to find "
                         "an alternative query that uses a string containing more than "
                         "just the character %s.",

@@ -680,7 +681,7 @@ class FindQuery:
 
         return result
 
-    def execute(self, exhaust_results=True):
+    async def execute(self, exhaust_results=True):
         args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
         if self.sort_fields:
             args += self.resolve_redisearch_sort_fields()

@@ -691,7 +692,7 @@ class FindQuery:
 
         # If the offset is greater than 0, we're paginating through a result set,
         # so append the new results to results already in the cache.
-        raw_result = self.model.db().execute_command(*args)
+        raw_result = await self.model.db().execute_command(*args)
         count = raw_result[0]
         results = self.model.from_redis(raw_result)
         self._model_cache += results

@@ -710,31 +711,31 @@ class FindQuery:
             # Make a query for each pass of the loop, with a new offset equal to the
             # current offset plus `page_size`, until we stop getting results back.
             query = query.copy(offset=query.offset + query.page_size)
-            _results = query.execute(exhaust_results=False)
+            _results = await query.execute(exhaust_results=False)
             if not _results:
                 break
             self._model_cache += _results
         return self._model_cache
 
-    def first(self):
+    async def first(self):
         query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
-        results = query.execute()
+        results = await query.execute()
         if not results:
             raise NotFoundError()
         return results[0]
 
-    def all(self, batch_size=10):
+    async def all(self, batch_size=10):
         if batch_size != self.page_size:
             query = self.copy(page_size=batch_size, limit=batch_size)
-            return query.execute()
+            return await query.execute()
-        return self.execute()
+        return await self.execute()
 
     def sort_by(self, *fields: str):
         if not fields:
             return self
         return self.copy(sort_fields=list(fields))
 
-    def update(self, use_transaction=True, **field_values):
+    async def update(self, use_transaction=True, **field_values):
         """
         Update models that match this query to the given field-value pairs.
 

@@ -743,31 +744,32 @@ class FindQuery:
         given fields.
         """
         validate_model_fields(self.model, field_values)
-        pipeline = self.model.db().pipeline() if use_transaction else None
+        pipeline = await self.model.db().pipeline() if use_transaction else None
 
-        for model in self.all():
+        # TODO: async for here?
+        for model in await self.all():
             for field, value in field_values.items():
                 setattr(model, field, value)
             # TODO: In the non-transaction case, can we do more to detect
             #  failure responses from Redis?
-            model.save(pipeline=pipeline)
+            await model.save(pipeline=pipeline)
 
         if pipeline:
             # TODO: Response type?
             # TODO: Better error detection for transactions.
             pipeline.execute()
 
-    def delete(self):
+    async def delete(self):
         """Delete all matching records in this query."""
         # TODO: Better response type, error detection
-        return self.model.db().delete(*[m.key() for m in self.all()])
+        return await self.model.db().delete(*[m.key() for m in await self.all()])
 
-    def __iter__(self):
+    async def __aiter__(self):
         if self._model_cache:
             for m in self._model_cache:
                 yield m
         else:
-            for m in self.execute():
+            for m in await self.execute():
                 yield m
 
     def __getitem__(self, item: int):

@@ -784,12 +786,39 @@ class FindQuery:
                that result, then we should clone the current query and
                give it a new offset and limit: offset=n, limit=1.
         """
+        if ASYNC_MODE:
+            raise QuerySyntaxError("Cannot use [] notation with async code. "
+                                   "Use FindQuery.get_item() instead.")
         if self._model_cache and len(self._model_cache) >= item:
             return self._model_cache[item]
 
         query = self.copy(offset=item, limit=1)
 
-        return query.execute()[0]
+        return query.execute()[0]  # noqa
+
+    async def get_item(self, item: int):
+        """
+        Given this code:
+            await Model.find().get_item(1000)
+
+        We should return only the 1000th result.
+
+            1. If the result is loaded in the query cache for this query,
+               we can return it directly from the cache.
+
+            2. If the query cache does not have enough elements to return
+               that result, then we should clone the current query and
+               give it a new offset and limit: offset=n, limit=1.
+
+        NOTE: This method is included specifically for async users, who
+        cannot use the notation Model.find()[1000].
+        """
+        if self._model_cache and len(self._model_cache) >= item:
+            return self._model_cache[item]
+
+        query = self.copy(offset=item, limit=1)
+        result = await query.execute()
+        return result[0]
 
 
 class PrimaryKeyCreator(Protocol):
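With execute(), first(), all(), and update() now coroutines, and [] indexing blocked under ASYNC_MODE, a query reads roughly like this. This is a hedged sketch only: the Member model, its fields, and the query values are illustrative, and a running Redis instance with RediSearch is assumed.

from redis_om.model import Field, HashModel


class Member(HashModel):
    first_name: str
    last_name: str = Field(index=True)


async def query_members():
    # Run the query and materialize every matching model instance.
    members = await Member.find(Member.last_name == "Brookins").all()

    # [] indexing raises QuerySyntaxError in async mode, so use get_item() instead.
    first = await Member.find(Member.last_name == "Brookins").get_item(0)
    return members, first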
@@ -913,7 +942,7 @@ class MetaProtocol(Protocol):
     global_key_prefix: str
     model_key_prefix: str
     primary_key_pattern: str
-    database: redis.Redis
+    database: aioredis.Redis
     primary_key: PrimaryKey
     primary_key_creator_cls: Type[PrimaryKeyCreator]
     index_name: str

@@ -932,7 +961,7 @@ class DefaultMeta:
     global_key_prefix: Optional[str] = None
     model_key_prefix: Optional[str] = None
     primary_key_pattern: Optional[str] = None
-    database: Optional[redis.Redis] = None
+    database: Optional[Union[redis.Redis, aioredis.Redis]] = None
     primary_key: Optional[PrimaryKey] = None
     primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
     index_name: Optional[str] = None

@@ -1049,14 +1078,18 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
         pk = getattr(self, self._meta.primary_key.field.name)
         return self.make_primary_key(pk)
 
-    def delete(self):
+    async def delete(self):
-        return self.db().delete(self.key())
+        return await self.db().delete(self.key())
 
-    def update(self, **field_values):
+    @classmethod
+    async def get(cls, pk: Any) -> 'RedisModel':
+        raise NotImplementedError
+
+    async def update(self, **field_values):
         """Update this model instance with the specified key-value pairs."""
         raise NotImplementedError
 
-    def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
+    async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
         raise NotImplementedError
 
     @validator("pk", always=True)

@@ -1158,9 +1191,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
         return d
 
     @classmethod
-    def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]:
+    async def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]:
         # TODO: Add transaction support
-        return [model.save() for model in models]
+        return [await model.save() for model in models]
 
     @classmethod
     def values(cls):

@@ -1189,17 +1222,18 @@ class HashModel(RedisModel, abc.ABC):
                         f" or mapping fields. Field: {name}"
                     )
 
-    def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
+    async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
         if pipeline is None:
             db = self.db()
         else:
             db = pipeline
         document = jsonable_encoder(self.dict())
-        db.hset(self.key(), mapping=document)
+        # TODO: Wrap any Redis response errors in a custom exception?
+        await db.hset(self.key(), mapping=document)
         return self
 
     @classmethod
-    def get(cls, pk: Any) -> "HashModel":
+    async def get(cls, pk: Any) -> "HashModel":
         document = cls.db().hgetall(cls.make_primary_key(pk))
         if not document:
             raise NotFoundError

@@ -1311,23 +1345,24 @@ class JsonModel(RedisModel, abc.ABC):
         # Generate the RediSearch schema once to validate fields.
         cls.redisearch_schema()
 
-    def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
+    async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
         if pipeline is None:
             db = self.db()
         else:
             db = pipeline
-        db.execute_command("JSON.SET", self.key(), ".", self.json())
+        # TODO: Wrap response errors in a custom exception?
+        await db.execute_command("JSON.SET", self.key(), ".", self.json())
         return self
 
-    def update(self, **field_values):
+    async def update(self, **field_values):
         validate_model_fields(self.__class__, field_values)
         for field, value in field_values.items():
             setattr(self, field, value)
-        self.save()
+        await self.save()
 
     @classmethod
-    def get(cls, pk: Any) -> "JsonModel":
+    async def get(cls, pk: Any) -> "JsonModel":
-        document = cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
+        document = await cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
         if not document:
             raise NotFoundError
         return cls.parse_raw(document)
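The persistence API on HashModel and JsonModel subclasses is now coroutine-based end to end. A hedged sketch of what calling code looks like after this change (the Customer model and its values are illustrative, and a running Redis instance is assumed):

import asyncio

from redis_om.model import Field, HashModel


class Customer(HashModel):
    first_name: str
    last_name: str = Field(index=True)


async def main():
    customer = Customer(first_name="Andrew", last_name="Brookins")
    await customer.save()                      # save() is a coroutine now
    fetched = await Customer.get(customer.pk)  # ...and so is get()
    assert fetched.last_name == "Brookins"


asyncio.run(main())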
@@ -1,17 +1,17 @@
 import abc
 from typing import Optional
 
-from redis_developer.model.model import HashModel, JsonModel
+from redis_om.model.model import HashModel, JsonModel
 
 
 class BaseJsonModel(JsonModel, abc.ABC):
     class Meta:
-        global_key_prefix = "redis-developer"
+        global_key_prefix = "redis-om"
 
 
 class BaseHashModel(HashModel, abc.ABC):
     class Meta:
-        global_key_prefix = "redis-developer"
+        global_key_prefix = "redis-om"
 
 
 # class AddressJson(BaseJsonModel):
@@ -1,7 +1,7 @@
 from collections import Sequence
 from typing import Any, Dict, List, Mapping, Union
 
-from redis_developer.model.model import Expression
+from redis_om.model.model import Expression
 
 
 class LogicalOperatorForListOfExpressions(Expression):
redis_om/unasync_util.py (new file, 40 added lines)

@@ -0,0 +1,40 @@
+"""Set of utility functions for unasync that transform into sync counterparts cleanly"""
+
+import inspect
+
+_original_next = next
+
+
+def is_async_mode():
+    """Tests if we're in the async part of the code or not"""
+
+    async def f():
+        """Unasync transforms async functions in sync functions"""
+        return None
+
+    obj = f()
+    if obj is None:
+        return False
+    else:
+        obj.close()  # prevent unawaited coroutine warning
+        return True
+
+
+ASYNC_MODE = is_async_mode()
+
+
+async def anext(x):
+    return await x.__anext__()
+
+
+async def await_if_coro(x):
+    if inspect.iscoroutine(x):
+        return await x
+    return x
+
+
+next = _original_next
+
+
+def return_non_coro(x):
+    return x
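ASYNC_MODE works because unasync rewrites the nested async def f() into a plain def, so the probe call returns None in the generated sync tree instead of a coroutine object. The helpers below it let shared code read the same in both trees; a small hedged demo of await_if_coro (standalone, no Redis needed, values are illustrative):

import asyncio

from redis_om.unasync_util import ASYNC_MODE, await_if_coro


async def coro_value():
    return 42


async def main():
    print(ASYNC_MODE)                          # True in the async tree, False after unasync
    print(await await_if_coro(coro_value()))   # awaits the coroutine -> 42
    print(await await_if_coro(42))             # a plain value passes straight through -> 42


asyncio.run(main())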
@@ -1,19 +1,18 @@
 import random
 
 import pytest
-from redis import Redis
 
-from redis_developer.connections import get_redis_connection
+from redis_om.connections import get_redis_connection
 
 
 @pytest.fixture
-def redis():
+def redis(event_loop):
     yield get_redis_connection()
 
 
-def _delete_test_keys(prefix: str, conn: Redis):
+async def _delete_test_keys(prefix: str, conn):
     keys = []
-    for key in conn.scan_iter(f"{prefix}:*"):
+    async for key in conn.scan_iter(f"{prefix}:*"):
         keys.append(key)
     if keys:
         conn.delete(*keys)

@@ -21,11 +20,10 @@ def _delete_test_keys(prefix: str, conn: Redis):
 
 @pytest.fixture
 def key_prefix(redis):
-    key_prefix = f"redis-developer:{random.random()}"
+    key_prefix = f"redis-om:{random.random()}"
     yield key_prefix
-    _delete_test_keys(key_prefix, redis)
 
 
 @pytest.fixture(autouse=True)
-def delete_test_keys(redis, request, key_prefix):
+async def delete_test_keys(redis, request, key_prefix):
-    _delete_test_keys(key_prefix, redis)
+    await _delete_test_keys(key_prefix, redis)
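Per the commit message, this is where things still fail with a closed event loop: the redis fixture now depends on pytest-asyncio's event_loop fixture so the connection is created on the loop the tests run on. A hedged sketch of how an async test consumes these fixtures (the test body is illustrative; pytest-asyncio 0.16 needs the explicit asyncio marker):

import pytest


@pytest.mark.asyncio
async def test_ping(redis):
    # `redis` is the fixture above; with aioredis the command must be awaited.
    assert await redis.ping()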
					@ -8,9 +8,9 @@ from unittest import mock
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
from pydantic import ValidationError
 | 
					from pydantic import ValidationError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from redis_developer.model import Field, HashModel
 | 
					from redis_om.model import Field, HashModel
 | 
				
			||||||
from redis_developer.model.migrations.migrator import Migrator
 | 
					from redis_om.model.migrations.migrator import Migrator
 | 
				
			||||||
from redis_developer.model.model import (
 | 
					from redis_om.model.model import (
 | 
				
			||||||
    NotFoundError,
 | 
					    NotFoundError,
 | 
				
			||||||
    QueryNotSupportedError,
 | 
					    QueryNotSupportedError,
 | 
				
			||||||
    RedisModelError,
 | 
					    RedisModelError,
 | 
				
			||||||
| 
						 | 
					
 | 
				
@@ -1,4 +1,5 @@
 import abc
+import asyncio
 import datetime
 import decimal
 from collections import namedtuple
@@ -8,9 +9,9 @@ from unittest import mock
 import pytest
 from pydantic import ValidationError
 
-from redis_developer.model import EmbeddedJsonModel, Field, JsonModel
-from redis_developer.model.migrations.migrator import Migrator
-from redis_developer.model.model import (
+from redis_om.model import EmbeddedJsonModel, Field, JsonModel
+from redis_om.model.migrations.migrator import Migrator
+from redis_om.model.model import (
     NotFoundError,
     QueryNotSupportedError,
     RedisModelError,
@@ -21,7 +22,7 @@ today = datetime.date.today()
 
 
 @pytest.fixture
-def m(key_prefix):
+async def m(key_prefix, redis):
     class BaseJsonModel(JsonModel, abc.ABC):
         class Meta:
             global_key_prefix = key_prefix
@@ -64,7 +65,7 @@ def m(key_prefix):
         # Creates an embedded list of models.
         orders: Optional[List[Order]]
 
-    Migrator().run()
+    await Migrator(redis).run()
 
     return namedtuple(
         "Models", ["BaseJsonModel", "Note", "Address", "Item", "Order", "Member"]
@@ -83,7 +84,7 @@ def address(m):
 
 
 @pytest.fixture()
-def members(address, m):
+async def members(address, m):
     member1 = m.Member(
         first_name="Andrew",
         last_name="Brookins",
@@ -111,14 +112,15 @@ def members(address, m):
         address=address,
     )
 
-    member1.save()
-    member2.save()
-    member3.save()
+    await member1.save()
+    await member2.save()
+    await member3.save()
 
     yield member1, member2, member3
 
 
-def test_validates_required_fields(address, m):
+@pytest.mark.asyncio
+async def test_validates_required_fields(address, m):
     # Raises ValidationError address is required
     with pytest.raises(ValidationError):
         m.Member(
@@ -129,7 +131,8 @@ def test_validates_required_fields(address, m):
         )
 
 
-def test_validates_field(address, m):
+@pytest.mark.asyncio
+async def test_validates_field(address, m):
     # Raises ValidationError: join_date is not a date
     with pytest.raises(ValidationError):
         m.Member(
@@ -141,7 +144,8 @@ def test_validates_field(address, m):
 
 
 # Passes validation
-def test_validation_passes(address, m):
+@pytest.mark.asyncio
+async def test_validation_passes(address, m):
     member = m.Member(
         first_name="Andrew",
         last_name="Brookins",
@@ -153,7 +157,10 @@ def test_validation_passes(address, m):
     assert member.first_name == "Andrew"
 
 
-def test_saves_model_and_creates_pk(address, m):
+@pytest.mark.asyncio
+async def test_saves_model_and_creates_pk(address, m, redis):
+    await Migrator(redis).run()
+
     member = m.Member(
         first_name="Andrew",
         last_name="Brookins",
@@ -163,15 +170,16 @@ def test_saves_model_and_creates_pk(address, m):
         address=address,
     )
     # Save a model instance to Redis
-    member.save()
+    await member.save()
 
-    member2 = m.Member.get(member.pk)
+    member2 = await m.Member.get(member.pk)
     assert member2 == member
     assert member2.address == address
 
 
 @pytest.mark.skip("Not implemented yet")
-def test_saves_many(address, m):
+@pytest.mark.asyncio
+async def test_saves_many(address, m):
     members = [
         m.Member(
             first_name="Andrew",
@@ -193,9 +201,16 @@ def test_saves_many(address, m):
     m.Member.add(members)
 
 
+async def save(members):
+    for m in members:
+        await m.save()
+    return members
+
+
 @pytest.mark.skip("Not ready yet")
-def test_updates_a_model(members, m):
-    member1, member2, member3 = members
+@pytest.mark.asyncio
+async def test_updates_a_model(members, m):
+    member1, member2, member3 = await save(members)
 
     # Or, with an implicit save:
     member1.update(last_name="Smith")
@@ -213,18 +228,20 @@ def test_updates_a_model(members, m):
     )
 
 
-def test_paginate_query(members, m):
+@pytest.mark.asyncio
+async def test_paginate_query(members, m):
     member1, member2, member3 = members
-    actual = m.Member.find().sort_by("age").all(batch_size=1)
+    actual = await m.Member.find().sort_by("age").all(batch_size=1)
     assert actual == [member2, member1, member3]
 
 
-def test_access_result_by_index_cached(members, m):
+@pytest.mark.asyncio
+async def test_access_result_by_index_cached(members, m):
     member1, member2, member3 = members
     query = m.Member.find().sort_by("age")
     # Load the cache, throw away the result.
     assert query._model_cache == []
-    query.execute()
+    await query.execute()
    assert query._model_cache == [member2, member1, member3]
 
     # Access an item that should be in the cache.
@@ -233,21 +250,23 @@ def test_access_result_by_index_cached(members, m):
         assert not mock_db.called
 
 
-def test_access_result_by_index_not_cached(members, m):
+@pytest.mark.asyncio
+async def test_access_result_by_index_not_cached(members, m):
     member1, member2, member3 = members
     query = m.Member.find().sort_by("age")
 
     # Assert that we don't have any models in the cache yet -- we
     # haven't made any requests of Redis.
     assert query._model_cache == []
-    assert query[0] == member2
-    assert query[1] == member1
-    assert query[2] == member3
+    assert query.get_item(0) == member2
+    assert query.get_item(1) == member1
+    assert query.get_item(2) == member3
 
 
-def test_in_query(members, m):
+@pytest.mark.asyncio
+async def test_in_query(members, m):
     member1, member2, member3 = members
-    actual = (
+    actual = await (
         m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
         .sort_by("age")
         .all()
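Note the API change in the hunk above: `query[0]` becomes `query.get_item(0)`. Indexing through `__getitem__` is awkward once query execution is asynchronous, so an explicit method is presumably used instead (the diff calls `get_item` without `await`, which may rely on the already-populated result cache). A toy sketch of that general shape; `FakeQuery` is illustrative, not the library's implementation:

```python
import asyncio


class FakeQuery:
    """Toy stand-in for a query object whose results load asynchronously."""

    def __init__(self, items):
        self._items = items

    async def get_item(self, index):
        # Placeholder for "execute the query, then index into the results."
        await asyncio.sleep(0)
        return self._items[index]


async def main():
    query = FakeQuery(["a", "b", "c"])
    assert await query.get_item(1) == "b"


asyncio.run(main())
```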
@@ -256,12 +275,13 @@ def test_in_query(members, m):
 
 
 @pytest.mark.skip("Not implemented yet")
-def test_update_query(members, m):
+@pytest.mark.asyncio
+async def test_update_query(members, m):
     member1, member2, member3 = members
-    m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update(
+    await m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update(
         first_name="Bobby"
     )
-    actual = (
+    actual = await (
         m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
         .sort_by("age")
         .all()
@@ -270,24 +290,25 @@ def test_update_query(members, m):
     assert all([m.name == "Bobby" for m in actual])
 
 
-def test_exact_match_queries(members, m):
+@pytest.mark.asyncio
+async def test_exact_match_queries(members, m):
     member1, member2, member3 = members
 
-    actual = m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
+    actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
     assert actual == [member2, member1]
 
-    actual = m.Member.find(
+    actual = await m.Member.find(
         (m.Member.last_name == "Brookins") & ~(m.Member.first_name == "Andrew")
     ).all()
     assert actual == [member2]
 
-    actual = m.Member.find(~(m.Member.last_name == "Brookins")).all()
+    actual = await m.Member.find(~(m.Member.last_name == "Brookins")).all()
     assert actual == [member3]
 
-    actual = m.Member.find(m.Member.last_name != "Brookins").all()
+    actual = await m.Member.find(m.Member.last_name != "Brookins").all()
     assert actual == [member3]
 
-    actual = (
+    actual = await (
         m.Member.find(
             (m.Member.last_name == "Brookins") & (m.Member.first_name == "Andrew")
             | (m.Member.first_name == "Kim")
@@ -297,19 +318,20 @@ def test_exact_match_queries(members, m):
     )
     assert actual == [member2, member1]
 
-    actual = m.Member.find(
+    actual = await m.Member.find(
         m.Member.first_name == "Kim", m.Member.last_name == "Brookins"
     ).all()
     assert actual == [member2]
 
-    actual = m.Member.find(m.Member.address.city == "Portland").sort_by("age").all()
+    actual = await m.Member.find(m.Member.address.city == "Portland").sort_by("age").all()
     assert actual == [member2, member1, member3]
 
 
-def test_recursive_query_expression_resolution(members, m):
+@pytest.mark.asyncio
+async def test_recursive_query_expression_resolution(members, m):
     member1, member2, member3 = members
 
-    actual = (
+    actual = await (
         m.Member.find(
             (m.Member.last_name == "Brookins")
             | (m.Member.age == 100) & (m.Member.last_name == "Smith")
@@ -320,13 +342,14 @@ def test_recursive_query_expression_resolution(members, m):
     assert actual == [member2, member1, member3]
 
 
-def test_recursive_query_field_resolution(members, m):
+@pytest.mark.asyncio
+async def test_recursive_query_field_resolution(members, m):
     member1, _, _ = members
     member1.address.note = m.Note(
         description="Weird house", created_on=datetime.datetime.now()
     )
-    member1.save()
-    actual = m.Member.find(m.Member.address.note.description == "Weird house").all()
+    await member1.save()
+    actual = await m.Member.find(m.Member.address.note.description == "Weird house").all()
     assert actual == [member1]
 
     member1.orders = [
@@ -336,29 +359,31 @@ def test_recursive_query_field_resolution(members, m):
             created_on=datetime.datetime.now(),
         )
     ]
-    member1.save()
-    actual = m.Member.find(m.Member.orders.items.name == "Ball").all()
+    await member1.save()
+    actual = await m.Member.find(m.Member.orders.items.name == "Ball").all()
     assert actual == [member1]
     assert actual[0].orders[0].items[0].name == "Ball"
 
 
-def test_full_text_search(members, m):
+@pytest.mark.asyncio
+async def test_full_text_search(members, m):
     member1, member2, _ = members
-    member1.update(bio="Hates sunsets, likes beaches")
-    member2.update(bio="Hates beaches, likes forests")
+    await member1.update(bio="Hates sunsets, likes beaches")
+    await member2.update(bio="Hates beaches, likes forests")
 
-    actual = m.Member.find(m.Member.bio % "beaches").sort_by("age").all()
+    actual = await m.Member.find(m.Member.bio % "beaches").sort_by("age").all()
     assert actual == [member2, member1]
 
-    actual = m.Member.find(m.Member.bio % "forests").all()
+    actual = await m.Member.find(m.Member.bio % "forests").all()
     assert actual == [member2]
 
 
-def test_tag_queries_boolean_logic(members, m):
+@pytest.mark.asyncio
+async def test_tag_queries_boolean_logic(members, m):
     member1, member2, member3 = members
 
     actual = (
-        m.Member.find(
+        await m.Member.find(
             (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
             | (m.Member.last_name == "Smith")
         )
@@ -368,7 +393,8 @@ def test_tag_queries_boolean_logic(members, m):
     assert actual == [member1, member3]
 
 
-def test_tag_queries_punctuation(address, m):
+@pytest.mark.asyncio
+async def test_tag_queries_punctuation(address, m):
     member1 = m.Member(
         first_name="Andrew, the Michael",
         last_name="St. Brookins-on-Pier",
@@ -377,7 +403,7 @@ def test_tag_queries_punctuation(address, m):
         join_date=today,
         address=address,
     )
-    member1.save()
+    await member1.save()
 
     member2 = m.Member(
         first_name="Bob",
@@ -387,24 +413,25 @@ def test_tag_queries_punctuation(address, m):
         join_date=today,
         address=address,
     )
-    member2.save()
+    await member2.save()
 
     assert (
-        m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1
+        await m.Member.find(m.Member.first_name == "Andrew, the Michael").first() == member1
     )
     assert (
-        m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1
+        await m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first() == member1
     )
 
     # Notice that when we index and query multiple values that use the internal
     # TAG separator for single-value exact-match fields, like an indexed string,
     # the queries will succeed. We apply a workaround that queries for the union
     # of the two values separated by the tag separator.
-    assert m.Member.find(m.Member.email == "a|b@example.com").all() == [member1]
-    assert m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2]
+    assert await m.Member.find(m.Member.email == "a|b@example.com").all() == [member1]
+    assert await m.Member.find(m.Member.email == "a|villain@example.com").all() == [member2]
 
 
-def test_tag_queries_negation(members, m):
+@pytest.mark.asyncio
+async def test_tag_queries_negation(members, m):
     member1, member2, member3 = members
 
     """
@@ -414,7 +441,7 @@ def test_tag_queries_negation(members, m):
 
     """
     query = m.Member.find(~(m.Member.first_name == "Andrew"))
-    assert query.all() == [member2]
+    assert await query.all() == [member2]
 
     """
                ┌first_name
@@ -429,7 +456,7 @@ def test_tag_queries_negation(members, m):
     query = m.Member.find(
         ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
     )
-    assert query.all() == [member2]
+    assert await query.all() == [member2]
 
     """
                ┌first_name
@@ -448,7 +475,7 @@ def test_tag_queries_negation(members, m):
         ~(m.Member.first_name == "Andrew")
         & ((m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith"))
     )
-    assert query.all() == [member2]
+    assert await query.all() == [member2]
 
     """
                   ┌first_name
@@ -467,67 +494,71 @@ def test_tag_queries_negation(members, m):
         ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
         | (m.Member.last_name == "Smith")
     )
-    assert query.sort_by("age").all() == [member2, member3]
+    assert await query.sort_by("age").all() == [member2, member3]
 
-    actual = m.Member.find(
+    actual = await m.Member.find(
         (m.Member.first_name == "Andrew") & ~(m.Member.last_name == "Brookins")
     ).all()
     assert actual == [member3]
 
 
-def test_numeric_queries(members, m):
+@pytest.mark.asyncio
+async def test_numeric_queries(members, m):
     member1, member2, member3 = members
 
-    actual = m.Member.find(m.Member.age == 34).all()
+    actual = await m.Member.find(m.Member.age == 34).all()
    assert actual == [member2]
 
-    actual = m.Member.find(m.Member.age > 34).all()
+    actual = await m.Member.find(m.Member.age > 34).all()
     assert actual == [member1, member3]
 
-    actual = m.Member.find(m.Member.age < 35).all()
+    actual = await m.Member.find(m.Member.age < 35).all()
     assert actual == [member2]
 
-    actual = m.Member.find(m.Member.age <= 34).all()
+    actual = await m.Member.find(m.Member.age <= 34).all()
     assert actual == [member2]
 
-    actual = m.Member.find(m.Member.age >= 100).all()
+    actual = await m.Member.find(m.Member.age >= 100).all()
     assert actual == [member3]
 
-    actual = m.Member.find(~(m.Member.age == 100)).sort_by("age").all()
+    actual = await m.Member.find(~(m.Member.age == 100)).sort_by("age").all()
     assert actual == [member2, member1]
 
-    actual = m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all()
+    actual = await m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all()
     assert actual == [member2, member1]
 
-    actual = m.Member.find(m.Member.age != 34).sort_by("age").all()
+    actual = await m.Member.find(m.Member.age != 34).sort_by("age").all()
     assert actual == [member1, member3]
 
 
-def test_sorting(members, m):
+@pytest.mark.asyncio
+async def test_sorting(members, m):
     member1, member2, member3 = members
 
-    actual = m.Member.find(m.Member.age > 34).sort_by("age").all()
+    actual = await m.Member.find(m.Member.age > 34).sort_by("age").all()
     assert actual == [member1, member3]
 
-    actual = m.Member.find(m.Member.age > 34).sort_by("-age").all()
+    actual = await m.Member.find(m.Member.age > 34).sort_by("-age").all()
     assert actual == [member3, member1]
 
     with pytest.raises(QueryNotSupportedError):
         # This field does not exist.
-        m.Member.find().sort_by("not-a-real-field").all()
+        await m.Member.find().sort_by("not-a-real-field").all()
 
     with pytest.raises(QueryNotSupportedError):
         # This field is not sortable.
-        m.Member.find().sort_by("join_date").all()
+        await m.Member.find().sort_by("join_date").all()
 
 
-def test_not_found(m):
+@pytest.mark.asyncio
+async def test_not_found(m):
     with pytest.raises(NotFoundError):
         # This ID does not exist.
-        m.Member.get(1000)
+        await m.Member.get(1000)
 
 
-def test_list_field_limitations(m):
+@pytest.mark.asyncio
+async def test_list_field_limitations(m, redis):
     with pytest.raises(RedisModelError):
 
         class SortableTarotWitch(m.BaseJsonModel):
@@ -571,15 +602,16 @@ def test_list_field_limitations(m):
     # We need to import and run this manually because we defined
     # our model classes within a function that runs after the test
    # suite's migrator has already looked for migrations to run.
-    Migrator().run()
+    await Migrator(redis).run()
 
     witch = TarotWitch(tarot_cards=["death"])
-    witch.save()
-    actual = TarotWitch.find(TarotWitch.tarot_cards << "death").all()
+    await witch.save()
+    actual = await TarotWitch.find(TarotWitch.tarot_cards << "death").all()
     assert actual == [witch]
 
 
-def test_schema(m, key_prefix):
+@pytest.mark.asyncio
+async def test_schema(m, key_prefix):
     assert (
         m.Member.redisearch_schema()
         == f"ON JSON PREFIX 1 {key_prefix}:tests.test_json_model.Member: SCHEMA $.pk AS pk TAG SEPARATOR | $.first_name AS first_name TAG SEPARATOR | $.last_name AS last_name TAG SEPARATOR | $.email AS email TAG SEPARATOR |  $.age AS age NUMERIC $.bio AS bio TAG SEPARATOR | $.bio AS bio_fts TEXT $.address.pk AS address_pk TAG SEPARATOR | $.address.city AS address_city TAG SEPARATOR | $.address.postal_code AS address_postal_code TAG SEPARATOR | $.address.note.pk AS address_note_pk TAG SEPARATOR | $.address.note.description AS address_note_description TAG SEPARATOR | $.orders[*].pk AS orders_pk TAG SEPARATOR | $.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR | $.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |"
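Taken together, the converted tests above all follow the same pattern: mark the test with `@pytest.mark.asyncio`, make it a coroutine, and `await` every call that touches Redis (`save`, `get`, and terminal query methods such as `.all()` and `.first()`). A condensed sketch of that pattern, reusing the `m` and `members` fixtures defined earlier; the test name is illustrative, not part of this commit:

```python
import pytest


@pytest.mark.asyncio
async def test_async_roundtrip(members, m):
    member1, member2, member3 = members

    # Writes and reads are awaited against the async Redis client.
    await member1.save()
    fetched = await m.Member.get(member1.pk)
    assert fetched == member1

    # Query execution is awaited as well.
    brookins = await m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
    assert brookins == [member2, member1]
```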