Handle TAG queries that include the separator

This commit is contained in:
Andrew Brookins 2021-10-05 16:40:02 -07:00
parent b46408ccd2
commit 8f32b359f0
7 changed files with 591 additions and 128 deletions

105
poetry.lock generated
View file

@ -165,6 +165,31 @@ python-versions = ">=3.5"
[package.dependencies] [package.dependencies]
traitlets = "*" traitlets = "*"
[[package]]
name = "mypy"
version = "0.910"
description = "Optional static typing for Python"
category = "main"
optional = false
python-versions = ">=3.5"
[package.dependencies]
mypy-extensions = ">=0.4.3,<0.5.0"
toml = "*"
typing-extensions = ">=3.7.4"
[package.extras]
dmypy = ["psutil (>=4.0)"]
python2 = ["typed-ast (>=1.4.0,<1.5.0)"]
[[package]]
name = "mypy-extensions"
version = "0.4.3"
description = "Experimental type system extensions for programs checked with the mypy typechecker."
category = "main"
optional = false
python-versions = "*"
[[package]] [[package]]
name = "packaging" name = "packaging"
version = "21.0" version = "21.0"
@ -219,6 +244,14 @@ python-versions = ">=3.6"
dev = ["pre-commit", "tox"] dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"] testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pptree"
version = "3.1"
description = "Pretty print trees"
category = "main"
optional = false
python-versions = "*"
[[package]] [[package]]
name = "prompt-toolkit" name = "prompt-toolkit"
version = "3.0.20" version = "3.0.20"
@ -298,6 +331,14 @@ toml = "*"
[package.extras] [package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
[[package]]
name = "python-ulid"
version = "1.0.3"
description = "Universally Unique Lexicographically Sortable Identifier"
category = "main"
optional = false
python-versions = "*"
[[package]] [[package]]
name = "redis" name = "redis"
version = "3.5.3" version = "3.5.3"
@ -321,7 +362,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
name = "toml" name = "toml"
version = "0.10.2" version = "0.10.2"
description = "Python Library for Tom's Obvious, Minimal Language" description = "Python Library for Tom's Obvious, Minimal Language"
category = "dev" category = "main"
optional = false optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
@ -336,6 +377,22 @@ python-versions = ">=3.7"
[package.extras] [package.extras]
test = ["pytest"] test = ["pytest"]
[[package]]
name = "types-redis"
version = "3.5.9"
description = "Typing stubs for redis"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "types-six"
version = "1.16.1"
description = "Typing stubs for six"
category = "main"
optional = false
python-versions = "*"
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "3.10.0.2" version = "3.10.0.2"
@ -355,7 +412,7 @@ python-versions = "*"
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "b3f0c7c5701bb2c317df7f2f42218ef6c38d9b035ed49c9a9469df6a6727973c" content-hash = "baa4bd3c38445c3325bdd317ecbfe99ccaf4bef438970ed31f5c49cc782d575e"
[metadata.files] [metadata.files]
aioredis = [ aioredis = [
@ -413,6 +470,35 @@ matplotlib-inline = [
{file = "matplotlib-inline-0.1.3.tar.gz", hash = "sha256:a04bfba22e0d1395479f866853ec1ee28eea1485c1d69a6faf00dc3e24ff34ee"}, {file = "matplotlib-inline-0.1.3.tar.gz", hash = "sha256:a04bfba22e0d1395479f866853ec1ee28eea1485c1d69a6faf00dc3e24ff34ee"},
{file = "matplotlib_inline-0.1.3-py3-none-any.whl", hash = "sha256:aed605ba3b72462d64d475a21a9296f400a19c4f74a31b59103d2a99ffd5aa5c"}, {file = "matplotlib_inline-0.1.3-py3-none-any.whl", hash = "sha256:aed605ba3b72462d64d475a21a9296f400a19c4f74a31b59103d2a99ffd5aa5c"},
] ]
mypy = [
{file = "mypy-0.910-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457"},
{file = "mypy-0.910-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:b94e4b785e304a04ea0828759172a15add27088520dc7e49ceade7834275bedb"},
{file = "mypy-0.910-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:088cd9c7904b4ad80bec811053272986611b84221835e079be5bcad029e79dd9"},
{file = "mypy-0.910-cp35-cp35m-win_amd64.whl", hash = "sha256:adaeee09bfde366d2c13fe6093a7df5df83c9a2ba98638c7d76b010694db760e"},
{file = "mypy-0.910-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:ecd2c3fe726758037234c93df7e98deb257fd15c24c9180dacf1ef829da5f921"},
{file = "mypy-0.910-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d9dd839eb0dc1bbe866a288ba3c1afc33a202015d2ad83b31e875b5905a079b6"},
{file = "mypy-0.910-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:3e382b29f8e0ccf19a2df2b29a167591245df90c0b5a2542249873b5c1d78212"},
{file = "mypy-0.910-cp36-cp36m-win_amd64.whl", hash = "sha256:53fd2eb27a8ee2892614370896956af2ff61254c275aaee4c230ae771cadd885"},
{file = "mypy-0.910-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b6fb13123aeef4a3abbcfd7e71773ff3ff1526a7d3dc538f3929a49b42be03f0"},
{file = "mypy-0.910-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e4dab234478e3bd3ce83bac4193b2ecd9cf94e720ddd95ce69840273bf44f6de"},
{file = "mypy-0.910-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:7df1ead20c81371ccd6091fa3e2878559b5c4d4caadaf1a484cf88d93ca06703"},
{file = "mypy-0.910-cp37-cp37m-win_amd64.whl", hash = "sha256:0aadfb2d3935988ec3815952e44058a3100499f5be5b28c34ac9d79f002a4a9a"},
{file = "mypy-0.910-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec4e0cd079db280b6bdabdc807047ff3e199f334050db5cbb91ba3e959a67504"},
{file = "mypy-0.910-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:119bed3832d961f3a880787bf621634ba042cb8dc850a7429f643508eeac97b9"},
{file = "mypy-0.910-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:866c41f28cee548475f146aa4d39a51cf3b6a84246969f3759cb3e9c742fc072"},
{file = "mypy-0.910-cp38-cp38-win_amd64.whl", hash = "sha256:ceb6e0a6e27fb364fb3853389607cf7eb3a126ad335790fa1e14ed02fba50811"},
{file = "mypy-0.910-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a85e280d4d217150ce8cb1a6dddffd14e753a4e0c3cf90baabb32cefa41b59e"},
{file = "mypy-0.910-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:42c266ced41b65ed40a282c575705325fa7991af370036d3f134518336636f5b"},
{file = "mypy-0.910-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3c4b8ca36877fc75339253721f69603a9c7fdb5d4d5a95a1a1b899d8b86a4de2"},
{file = "mypy-0.910-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:c0df2d30ed496a08de5daed2a9ea807d07c21ae0ab23acf541ab88c24b26ab97"},
{file = "mypy-0.910-cp39-cp39-win_amd64.whl", hash = "sha256:c6c2602dffb74867498f86e6129fd52a2770c48b7cd3ece77ada4fa38f94eba8"},
{file = "mypy-0.910-py3-none-any.whl", hash = "sha256:ef565033fa5a958e62796867b1df10c40263ea9ded87164d67572834e57a174d"},
{file = "mypy-0.910.tar.gz", hash = "sha256:704098302473cb31a218f1775a873b376b30b4c18229421e9e9dc8916fd16150"},
]
mypy-extensions = [
{file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
{file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
]
packaging = [ packaging = [
{file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"}, {file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"},
{file = "packaging-21.0.tar.gz", hash = "sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7"}, {file = "packaging-21.0.tar.gz", hash = "sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7"},
@ -433,6 +519,9 @@ pluggy = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
] ]
pptree = [
{file = "pptree-3.1.tar.gz", hash = "sha256:4dd0ba2f58000cbd29d68a5b64bac29bcb5a663642f79404877c0059668a69f6"},
]
prompt-toolkit = [ prompt-toolkit = [
{file = "prompt_toolkit-3.0.20-py3-none-any.whl", hash = "sha256:6076e46efae19b1e0ca1ec003ed37a933dc94b4d20f486235d436e64771dcd5c"}, {file = "prompt_toolkit-3.0.20-py3-none-any.whl", hash = "sha256:6076e46efae19b1e0ca1ec003ed37a933dc94b4d20f486235d436e64771dcd5c"},
{file = "prompt_toolkit-3.0.20.tar.gz", hash = "sha256:eb71d5a6b72ce6db177af4a7d4d7085b99756bf656d98ffcc4fecd36850eea6c"}, {file = "prompt_toolkit-3.0.20.tar.gz", hash = "sha256:eb71d5a6b72ce6db177af4a7d4d7085b99756bf656d98ffcc4fecd36850eea6c"},
@ -481,6 +570,10 @@ pytest = [
{file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
{file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
] ]
python-ulid = [
{file = "python-ulid-1.0.3.tar.gz", hash = "sha256:5dd8b969312a40e2212cec9c1ad63f25d4b6eafd92ee3195883e0287b6e9d19e"},
{file = "python_ulid-1.0.3-py3-none-any.whl", hash = "sha256:8704dc20f547f531fe3a41d4369842d737a0f275403b909d0872e7ea0fe8d6f2"},
]
redis = [ redis = [
{file = "redis-3.5.3-py2.py3-none-any.whl", hash = "sha256:432b788c4530cfe16d8d943a09d40ca6c16149727e4afe8c2c9d5580c59d9f24"}, {file = "redis-3.5.3-py2.py3-none-any.whl", hash = "sha256:432b788c4530cfe16d8d943a09d40ca6c16149727e4afe8c2c9d5580c59d9f24"},
{file = "redis-3.5.3.tar.gz", hash = "sha256:0e7e0cfca8660dea8b7d5cd8c4f6c5e29e11f31158c0b0ae91a397f00e5a05a2"}, {file = "redis-3.5.3.tar.gz", hash = "sha256:0e7e0cfca8660dea8b7d5cd8c4f6c5e29e11f31158c0b0ae91a397f00e5a05a2"},
@ -497,6 +590,14 @@ traitlets = [
{file = "traitlets-5.1.0-py3-none-any.whl", hash = "sha256:03f172516916220b58c9f19d7f854734136dd9528103d04e9bf139a92c9f54c4"}, {file = "traitlets-5.1.0-py3-none-any.whl", hash = "sha256:03f172516916220b58c9f19d7f854734136dd9528103d04e9bf139a92c9f54c4"},
{file = "traitlets-5.1.0.tar.gz", hash = "sha256:bd382d7ea181fbbcce157c133db9a829ce06edffe097bcf3ab945b435452b46d"}, {file = "traitlets-5.1.0.tar.gz", hash = "sha256:bd382d7ea181fbbcce157c133db9a829ce06edffe097bcf3ab945b435452b46d"},
] ]
types-redis = [
{file = "types-redis-3.5.9.tar.gz", hash = "sha256:f142c48f4080757ca2a9441ec40213bda3b1535eebebfc4f3519e5aa46498076"},
{file = "types_redis-3.5.9-py3-none-any.whl", hash = "sha256:5f5648ffc025708858097173cf695164c20f2b5e3f57177de14e352cae8cc335"},
]
types-six = [
{file = "types-six-1.16.1.tar.gz", hash = "sha256:a9e6769cb0808f920958ac95f75c5191f49e21e041eac127fa62e286e1005616"},
{file = "types_six-1.16.1-py2.py3-none-any.whl", hash = "sha256:b14f5abe26c0997bd41a1a32d6816af25932f7bfbc54246dfdc8f6f6404fd1d4"},
]
typing-extensions = [ typing-extensions = [
{file = "typing_extensions-3.10.0.2-py2-none-any.whl", hash = "sha256:d8226d10bc02a29bcc81df19a26e56a9647f8b0a6d4a83924139f4a8b01f17b7"}, {file = "typing_extensions-3.10.0.2-py2-none-any.whl", hash = "sha256:d8226d10bc02a29bcc81df19a26e56a9647f8b0a6d4a83924139f4a8b01f17b7"},
{file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"}, {file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"},

View file

@ -12,6 +12,11 @@ aioredis = "^2.0.0"
pydantic = "^1.8.2" pydantic = "^1.8.2"
click = "^8.0.1" click = "^8.0.1"
six = "^1.16.0" six = "^1.16.0"
pptree = "^3.1"
mypy = "^0.910"
types-redis = "^3.5.9"
types-six = "^1.16.1"
python-ulid = "^1.0.3"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^6.2.4" pytest = "^6.2.4"

View file

@ -1,9 +1,9 @@
import abc import abc
import dataclasses import dataclasses
import decimal import decimal
import json
import logging import logging
import operator import operator
import re
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
@ -22,10 +22,9 @@ from typing import (
no_type_check, no_type_check,
Protocol, Protocol,
List, List,
Type, get_origin,
Pattern, get_origin, get_args get_args, Type
) )
import uuid
import redis import redis
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
@ -34,46 +33,43 @@ from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass from pydantic.main import ModelMetaclass
from pydantic.typing import NoArgAnyCallable, resolve_annotations from pydantic.typing import NoArgAnyCallable, resolve_annotations
from pydantic.utils import Representation from pydantic.utils import Representation
from ulid import ULID
from .encoders import jsonable_encoder from .encoders import jsonable_encoder
from .render_tree import render_tree
from .token_escaper import TokenEscaper
model_registry = {} model_registry = {}
_T = TypeVar("_T") _T = TypeVar("_T")
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, string):
def escape_symbol(match):
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, string)
escaper = TokenEscaper() escaper = TokenEscaper()
# For basic exact-match field types like an indexed string, we create a TAG
# field in the RediSearch index. TAG is designed for multi-value fields
# separated by a "separator" character. We're using the field for single values
# (multi-value TAGs will be exposed as a separate field type), and we use the
# pipe character (|) as the separator. There is no way to escape this character
# in hash fields or JSON objects, so if someone indexes a value that includes
# the pipe, we'll warn but allow, and then warn again if they try to query for
# values that contain this separator.
SINGLE_VALUE_TAG_FIELD_SEPARATOR = "|"
# This is the default field separator in RediSearch. We need it to determine if
# someone has accidentally passed in the field separator with string value of a
# multi-value field lookup, like a IN or NOT_IN.
DEFAULT_REDISEARCH_FIELD_SEPARATOR = ","
class RedisModelError(Exception): class RedisModelError(Exception):
pass """Raised when a problem exists in the definition of a RedisModel."""
class QuerySyntaxError(Exception):
"""Raised when a query is constructed improperly."""
class NotFoundError(Exception): class NotFoundError(Exception):
"""A query found no results.""" """Raised when a query found no results."""
class Operators(Enum): class Operators(Enum):
@ -91,9 +87,45 @@ class Operators(Enum):
LIKE = 12 LIKE = 12
ALL = 13 ALL = 13
def __str__(self):
return str(self.name)
ExpressionOrModelField = Union['Expression', 'NegatedExpression', ModelField]
class ExpressionProtocol(Protocol):
op: Operators
left: ExpressionOrModelField
right: ExpressionOrModelField
def __invert__(self) -> 'Expression':
pass
def __and__(self, other: ExpressionOrModelField):
pass
def __or__(self, other: ExpressionOrModelField):
pass
@property
def name(self) -> str:
raise NotImplementedError
@property
def tree(self) -> str:
raise NotImplementedError
@dataclasses.dataclass @dataclasses.dataclass
class NegatedExpression: class NegatedExpression:
"""A negated Expression object.
For now, this is a separate dataclass from Expression that acts as a facade
around an Expression, indicating to model code (specifically, code
responsible for querying) to negate the logic in the wrapped Expression. A
better design is probably possible, maybe at least an ExpressionProtocol?
"""
expression: 'Expression' expression: 'Expression'
def __invert__(self): def __invert__(self):
@ -105,22 +137,53 @@ class NegatedExpression:
def __or__(self, other): def __or__(self, other):
return Expression(left=self, op=Operators.OR, right=other) return Expression(left=self, op=Operators.OR, right=other)
@property
def left(self):
return self.expression.left
@property
def right(self):
return self.expression.right
@property
def op(self):
return self.expression.op
@property
def name(self):
if self.expression.op is Operators.EQ:
return f"NOT {self.expression.name}"
else:
return f"{self.expression.name} NOT"
@property
def tree(self):
return render_tree(self)
@dataclasses.dataclass @dataclasses.dataclass
class Expression: class Expression:
op: Operators op: Operators
left: Any left: ExpressionOrModelField
right: Any right: ExpressionOrModelField
def __invert__(self): def __invert__(self):
return NegatedExpression(self) return NegatedExpression(self)
def __and__(self, other): def __and__(self, other: ExpressionOrModelField):
return Expression(left=self, op=Operators.AND, right=other) return Expression(left=self, op=Operators.AND, right=other)
def __or__(self, other): def __or__(self, other: ExpressionOrModelField):
return Expression(left=self, op=Operators.OR, right=other) return Expression(left=self, op=Operators.OR, right=other)
@property
def name(self):
return str(self.op)
@property
def tree(self):
return render_tree(self)
ExpressionOrNegated = Union[Expression, NegatedExpression] ExpressionOrNegated = Union[Expression, NegatedExpression]
@ -129,22 +192,22 @@ class ExpressionProxy:
def __init__(self, field: ModelField): def __init__(self, field: ModelField):
self.field = field self.field = field
def __eq__(self, other: Any) -> Expression: def __eq__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.EQ, right=other) return Expression(left=self.field, op=Operators.EQ, right=other)
def __ne__(self, other: Any) -> Expression: def __ne__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.NE, right=other) return Expression(left=self.field, op=Operators.NE, right=other)
def __lt__(self, other: Any) -> Expression: def __lt__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.LT, right=other) return Expression(left=self.field, op=Operators.LT, right=other)
def __le__(self, other: Any) -> Expression: def __le__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.LE, right=other) return Expression(left=self.field, op=Operators.LE, right=other)
def __gt__(self, other: Any) -> Expression: def __gt__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.GT, right=other) return Expression(left=self.field, op=Operators.GT, right=other)
def __ge__(self, other: Any) -> Expression: def __ge__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.GE, right=other) return Expression(left=self.field, op=Operators.GE, right=other)
@ -184,9 +247,9 @@ class FindQuery:
self.sort_fields = [] self.sort_fields = []
self._expression = None self._expression = None
self._query = None self._query: Optional[str] = None
self._pagination = [] self._pagination: list[str] = []
self._model_cache = [] self._model_cache: list[RedisModel] = []
@property @property
def pagination(self): def pagination(self):
@ -236,24 +299,24 @@ class FindQuery:
else: else:
# TAG fields are the default field type. # TAG fields are the default field type.
# TODO: A ListField or ArrayField that supports multiple values # TODO: A ListField or ArrayField that supports multiple values
# and contains logic. # and contains logic should allow IN and NOT_IN queries.
return RediSearchFieldTypes.TAG return RediSearchFieldTypes.TAG
@staticmethod @staticmethod
def expand_tag_value(value): def expand_tag_value(value):
err = RedisModelError(f"Using the IN operator requires passing a sequence of "
"possible values. You passed: {value}")
if isinstance(str, value): if isinstance(str, value):
raise err return value
try: try:
expanded_value = "|".join([escaper.escape(v) for v in value]) expanded_value = "|".join([escaper.escape(v) for v in value])
except TypeError: except TypeError:
raise err raise QuerySyntaxError("Values passed to an IN query must be iterables,"
"like a list of strings. For more information, see:"
"TODO: doc.")
return expanded_value return expanded_value
@classmethod @classmethod
def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes, def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
op: Operators, value: Any) -> str: field_info: PydanticFieldInfo, op: Operators, value: Any) -> str:
result = "" result = ""
if field_type is RediSearchFieldTypes.TEXT: if field_type is RediSearchFieldTypes.TEXT:
result = f"@{field_name}:" result = f"@{field_name}:"
@ -282,17 +345,41 @@ class FindQuery:
result += f"@{field_name}:[{value} +inf]" result += f"@{field_name}:[{value} +inf]"
elif op is Operators.LE: elif op is Operators.LE:
result += f"@{field_name}:[-inf {value}]" result += f"@{field_name}:[-inf {value}]"
# TODO: How will we know the difference between a multi-value use of a TAG
# field and our hidden use of TAG for exact-match queries?
elif field_type is RediSearchFieldTypes.TAG: elif field_type is RediSearchFieldTypes.TAG:
if op is Operators.EQ: if op is Operators.EQ:
value = escaper.escape(value) separator_char = getattr(field_info, 'separator',
result += f"@{field_name}:{{{value}}}" SINGLE_VALUE_TAG_FIELD_SEPARATOR)
if value == separator_char:
# The value is ONLY the TAG field separator character --
# this is not going to work.
log.warning("Your query against the field %s is for a single character, %s, "
"that is used internally by redis-developer-python. We must ignore "
"this portion of the query. Please review your query to find "
"an alternative query that uses a string containing more than "
"just the character %s.", field_name, separator_char, separator_char)
return ""
if separator_char in value:
# The value contains the TAG field separator. We can work
# around this by breaking apart the values and unioning them
# with multiple field:{} queries.
values = filter(None, value.split(separator_char))
for value in values:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
else:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
elif op is Operators.NE: elif op is Operators.NE:
value = escaper.escape(value) value = escaper.escape(value)
result += f"-(@{field_name}:{{{value}}})" result += f"-(@{field_name}:{{{value}}})"
elif op is Operators.IN: elif op is Operators.IN:
# TODO: Implement IN, test this...
expanded_value = cls.expand_tag_value(value) expanded_value = cls.expand_tag_value(value)
result += f"(@{field_name}:{{{expanded_value}}})" result += f"(@{field_name}:{{{expanded_value}}})"
elif op is Operators.NOT_IN: elif op is Operators.NOT_IN:
# TODO: Implement NOT_IN, test this...
expanded_value = cls.expand_tag_value(value) expanded_value = cls.expand_tag_value(value)
result += f"-(@{field_name}:{{{expanded_value}}})" result += f"-(@{field_name}:{{{expanded_value}}})"
@ -314,10 +401,11 @@ class FindQuery:
return ["SORTBY", *fields] return ["SORTBY", *fields]
@classmethod @classmethod
def resolve_redisearch_query(cls, expression: ExpressionOrNegated): def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
"""Resolve an expression to a string RediSearch query.""" """Resolve an expression to a string RediSearch query."""
field_type = None field_type = None
field_name = None field_name = None
field_info = None
encompassing_expression_is_negated = False encompassing_expression_is_negated = False
result = "" result = ""
@ -328,7 +416,7 @@ class FindQuery:
if expression.op is Operators.ALL: if expression.op is Operators.ALL:
if encompassing_expression_is_negated: if encompassing_expression_is_negated:
# TODO: Is there a use case for this, perhaps for dynamic # TODO: Is there a use case for this, perhaps for dynamic
# scoring purposes? # scoring purposes?
raise QueryNotSupportedError("You cannot negate a query for all results.") raise QueryNotSupportedError("You cannot negate a query for all results.")
return "*" return "*"
@ -338,6 +426,7 @@ class FindQuery:
elif isinstance(expression.left, ModelField): elif isinstance(expression.left, ModelField):
field_type = cls.resolve_field_type(expression.left) field_type = cls.resolve_field_type(expression.left)
field_name = expression.left.name field_name = expression.left.name
field_info = expression.left.field_info
else: else:
raise QueryNotSupportedError(f"A query expression should start with either a field " raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: " f"or an expression enclosed in parenthesis. See docs: "
@ -365,8 +454,7 @@ class FindQuery:
if isinstance(right, ModelField): if isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO") raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else: else:
# TODO: Optionals causing IDE errors here result += cls.resolve_value(field_name, field_type, field_info, expression.op, right)
result += cls.resolve_value(field_name, field_type, expression.op, right)
if encompassing_expression_is_negated: if encompassing_expression_is_negated:
result = f"-({result})" result = f"-({result})"
@ -416,7 +504,10 @@ class FindQuery:
def first(self): def first(self):
query = FindQuery(expressions=self.expressions, model=self.model, query = FindQuery(expressions=self.expressions, model=self.model,
offset=0, limit=1, sort_fields=self.sort_fields) offset=0, limit=1, sort_fields=self.sort_fields)
return query.execute()[0] results = query.execute()
if not results:
raise NotFoundError()
return results[0]
def all(self, batch_size=10): def all(self, batch_size=10):
if batch_size != self.page_size: if batch_size != self.page_size:
@ -494,9 +585,13 @@ class PrimaryKeyCreator(Protocol):
"""Create a new primary key""" """Create a new primary key"""
class Uuid4PrimaryKey: class UlidPrimaryKey:
def create_pk(self, *args, **kwargs) -> str: """A client-side generated primary key that follows the ULID spec.
return str(uuid.uuid4()) https://github.com/ulid/javascript#specification
"""
@staticmethod
def create_pk(*args, **kwargs) -> str:
return str(ULID())
def __dataclass_transform__( def __dataclass_transform__(
@ -601,8 +696,24 @@ class PrimaryKey:
field: ModelField field: ModelField
class MetaProtocol(Protocol):
global_key_prefix: str
model_key_prefix: str
primary_key_pattern: str
database: redis.Redis
primary_key: PrimaryKey
primary_key_creator_cls: Type[PrimaryKeyCreator]
index_name: str
abstract: bool
@dataclasses.dataclass
class DefaultMeta: class DefaultMeta:
# TODO: Should this really be optional here? """A default placeholder Meta object.
TODO: Revisit whether this is really necessary, and whether making
these all optional here is the right choice.
"""
global_key_prefix: Optional[str] = None global_key_prefix: Optional[str] = None
model_key_prefix: Optional[str] = None model_key_prefix: Optional[str] = None
primary_key_pattern: Optional[str] = None primary_key_pattern: Optional[str] = None
@ -614,6 +725,8 @@ class DefaultMeta:
class ModelMeta(ModelMetaclass): class ModelMeta(ModelMetaclass):
_meta: MetaProtocol
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
meta = attrs.pop('Meta', None) meta = attrs.pop('Meta', None)
new_class = super().__new__(cls, name, bases, attrs, **kwargs) new_class = super().__new__(cls, name, bases, attrs, **kwargs)
@ -656,11 +769,16 @@ class ModelMeta(ModelMetaclass):
redis.Redis(decode_responses=True)) redis.Redis(decode_responses=True))
if not getattr(new_class._meta, 'primary_key_creator_cls', None): if not getattr(new_class._meta, 'primary_key_creator_cls', None):
new_class._meta.primary_key_creator_cls = getattr(base_meta, "primary_key_creator_cls", new_class._meta.primary_key_creator_cls = getattr(base_meta, "primary_key_creator_cls",
Uuid4PrimaryKey) UlidPrimaryKey)
if not getattr(new_class._meta, 'index_name', None): if not getattr(new_class._meta, 'index_name', None):
new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \ new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \
f"{new_class._meta.model_key_prefix}:index" f"{new_class._meta.model_key_prefix}:index"
# Not an abstract model class
if abc.ABC not in bases:
key = f"{new_class.__module__}.{new_class.__qualname__}"
model_registry[key] = new_class
return new_class return new_class
@ -680,12 +798,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
__pydantic_self__.validate_primary_key() __pydantic_self__.validate_primary_key()
def __lt__(self, other): def __lt__(self, other):
"""Default sort: compare all shared model fields.""" """Default sort: compare primary key of models."""
my_keys = set(self.__fields__.keys()) return self.pk < other.pk
other_keys = set(other.__fields__.keys())
shared_keys = list(my_keys & other_keys)
lt = [getattr(self, k) < getattr(other, k) for k in shared_keys]
return len(lt) > len(shared_keys) / 2
@validator("pk", always=True) @validator("pk", always=True)
def validate_pk(cls, v): def validate_pk(cls, v):
@ -726,7 +840,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
return cls._meta.database return cls._meta.database
@classmethod @classmethod
def find(cls, *expressions: Union[Any, Expression]): # TODO: How to type annotate this? def find(cls, *expressions: Union[Any, Expression]) -> FindQuery: # TODO: How to type annotate this?
return FindQuery(expressions=expressions, model=cls) return FindQuery(expressions=expressions, model=cls)
@classmethod @classmethod
@ -760,7 +874,17 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
except KeyError: except KeyError:
pass pass
doc = cls(**fields) try:
fields['json'] = fields['$']
del fields['$']
except KeyError:
pass
if 'json' in fields:
json_fields = json.loads(fields['json'])
doc = cls(**json_fields)
else:
doc = cls(**fields)
docs.append(doc) docs.append(doc)
return docs return docs
@ -847,7 +971,7 @@ class HashModel(RedisModel, abc.ABC):
_type = field.outer_type_ _type = field.outer_type_
if getattr(field.field_info, 'primary_key', None): if getattr(field.field_info, 'primary_key', None):
if issubclass(_type, str): if issubclass(_type, str):
redisearch_field = f"{name} TAG" redisearch_field = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
else: else:
redisearch_field = cls.schema_for_type(name, _type, field.field_info) redisearch_field = cls.schema_for_type(name, _type, field.field_info)
schema_parts.append(redisearch_field) schema_parts.append(redisearch_field)
@ -872,7 +996,7 @@ class HashModel(RedisModel, abc.ABC):
return schema_parts return schema_parts
@classmethod @classmethod
def schema_for_type(cls, name, typ: Type, field_info: FieldInfo): def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
if get_origin(typ) == list: if get_origin(typ) == list:
embedded_cls = get_args(typ) embedded_cls = get_args(typ)
if not embedded_cls: if not embedded_cls:
@ -885,9 +1009,10 @@ class HashModel(RedisModel, abc.ABC):
return f"{name} NUMERIC" return f"{name} NUMERIC"
elif issubclass(typ, str): elif issubclass(typ, str):
if getattr(field_info, 'full_text_search', False) is True: if getattr(field_info, 'full_text_search', False) is True:
return f"{name} TAG {name}_fts TEXT" return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " \
f"{name}_fts TEXT"
else: else:
return f"{name} TAG" return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
elif issubclass(typ, RedisModel): elif issubclass(typ, RedisModel):
sub_fields = [] sub_fields = []
for embedded_name, field in typ.__fields__.items(): for embedded_name, field in typ.__fields__.items():
@ -895,7 +1020,7 @@ class HashModel(RedisModel, abc.ABC):
field.field_info)) field.field_info))
return " ".join(sub_fields) return " ".join(sub_fields)
else: else:
return f"{name} TAG" return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
class JsonModel(RedisModel, abc.ABC): class JsonModel(RedisModel, abc.ABC):
@ -927,7 +1052,7 @@ class JsonModel(RedisModel, abc.ABC):
_type = field.outer_type_ _type = field.outer_type_
if getattr(field.field_info, 'primary_key', None): if getattr(field.field_info, 'primary_key', None):
if issubclass(_type, str): if issubclass(_type, str):
redisearch_field = f"{json_path}.{name} AS {name} TAG" redisearch_field = f"{json_path}.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
else: else:
redisearch_field = cls.schema_for_type(f"{json_path}.{name}", name, "", _type, field.field_info) redisearch_field = cls.schema_for_type(f"{json_path}.{name}", name, "", _type, field.field_info)
schema_parts.append(redisearch_field) schema_parts.append(redisearch_field)
@ -957,8 +1082,8 @@ class JsonModel(RedisModel, abc.ABC):
# find it in the JSON document, AND the name of the field as it should # find it in the JSON document, AND the name of the field as it should
# be in the redisearch schema (address_address_line_1). Maybe both "name" # be in the redisearch schema (address_address_line_1). Maybe both "name"
# and "name_prefix"? # and "name_prefix"?
def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Type, def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Any,
field_info: FieldInfo) -> str: field_info: PydanticFieldInfo) -> str:
index_field_name = f"{name_prefix}{name}" index_field_name = f"{name_prefix}{name}"
should_index = getattr(field_info, 'index', False) should_index = getattr(field_info, 'index', False)
@ -986,10 +1111,10 @@ class JsonModel(RedisModel, abc.ABC):
return f"{json_path} AS {index_field_name} NUMERIC" return f"{json_path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str): elif issubclass(typ, str):
if getattr(field_info, 'full_text_search', False) is True: if getattr(field_info, 'full_text_search', False) is True:
return f"{json_path} AS {index_field_name} TAG " \ return f"{json_path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " \
f"{json_path} AS {index_field_name}_fts TEXT" f"{json_path} AS {index_field_name}_fts TEXT"
else: else:
return f"{json_path} AS {index_field_name} TAG" return f"{json_path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
else: else:
return f"{json_path} AS {index_field_name} TAG" return f"{json_path} AS {index_field_name} TAG"

View file

@ -0,0 +1,59 @@
"""
This code adapted from the library "pptree," Copyright (c) 2017 Clément Michard
and released under the MIT license: https://github.com/clemtoy/pptree
"""
import io
def render_tree(current_node, nameattr='name', left_child='left',
right_child='right', indent='', last='updown',
buffer=None):
"""Print a tree-like structure, `current_node`.
This is a mostly-direct-copy of the print_tree() function from the ppbtree
module of the pptree library, but instead of printing to standard out, we
write to a StringIO buffer, then use that buffer to accumulate written lines
during recursive calls to render_tree().
"""
if buffer is None:
buffer = io.StringIO()
if hasattr(current_node, nameattr):
name = lambda node: getattr(node, nameattr)
else:
name = lambda node: str(node)
up = getattr(current_node, left_child, None)
down = getattr(current_node, right_child, None)
if up is not None:
next_last = 'up'
next_indent = '{0}{1}{2}'.format(indent, ' ' if 'up' in last else '|', ' ' * len(str(name(current_node))))
render_tree(up, nameattr, left_child, right_child, next_indent, next_last, buffer)
if last == 'up':
start_shape = ''
elif last == 'down':
start_shape = ''
elif last == 'updown':
start_shape = ' '
else:
start_shape = ''
if up is not None and down is not None:
end_shape = ''
elif up:
end_shape = ''
elif down:
end_shape = ''
else:
end_shape = ''
print('{0}{1}{2}{3}'.format(indent, start_shape, name(current_node), end_shape),
file=buffer)
if down is not None:
next_last = 'down'
next_indent = '{0}{1}{2}'.format(indent, ' ' if 'down' in last else '|', ' ' * len(str(name(current_node))))
render_tree(down, nameattr, left_child, right_child, next_indent, next_last, buffer)
return f"\n{buffer.getvalue()}"

View file

@ -0,0 +1,24 @@
import re
from typing import Optional, Pattern
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, string: str) -> str:
def escape_symbol(match):
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, string)

View file

@ -163,13 +163,13 @@ def test_updates_a_model(members):
# Or, affecting multiple model instances with an implicit save: # Or, affecting multiple model instances with an implicit save:
Member.find(Member.last_name == "Brookins").update(last_name="Smith") Member.find(Member.last_name == "Brookins").update(last_name="Smith")
results = Member.find(Member.last_name == "Smith") results = Member.find(Member.last_name == "Smith")
assert sorted(results) == members assert results == members
def test_paginate_query(members): def test_paginate_query(members):
member1, member2, member3 = members member1, member2, member3 = members
actual = Member.find().all(batch_size=1) actual = Member.find().all(batch_size=1)
assert sorted(actual) == [member1, member2, member3] assert actual == [member1, member2, member3]
def test_access_result_by_index_cached(members): def test_access_result_by_index_cached(members):
@ -202,7 +202,7 @@ def test_exact_match_queries(members):
member1, member2, member3 = members member1, member2, member3 = members
actual = Member.find(Member.last_name == "Brookins").all() actual = Member.find(Member.last_name == "Brookins").all()
assert sorted(actual) == [member1, member2] assert actual == [member1, member2]
actual = Member.find( actual = Member.find(
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all() (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all()
@ -218,7 +218,7 @@ def test_exact_match_queries(members):
(Member.last_name == "Brookins") & (Member.first_name == "Andrew") (Member.last_name == "Brookins") & (Member.first_name == "Andrew")
| (Member.first_name == "Kim") | (Member.first_name == "Kim")
).all() ).all()
assert actual == [member2, member1] assert actual == [member1, member2]
actual = Member.find(Member.first_name == "Kim", Member.last_name == "Brookins").all() actual = Member.find(Member.first_name == "Kim", Member.last_name == "Brookins").all()
assert actual == [member2] assert actual == [member2]
@ -230,7 +230,7 @@ def test_recursive_query_resolution(members):
actual = Member.find((Member.last_name == "Brookins") | ( actual = Member.find((Member.last_name == "Brookins") | (
Member.age == 100 Member.age == 100
) & (Member.last_name == "Smith")).all() ) & (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member2, member3] assert actual == [member1, member2, member3]
def test_tag_queries_boolean_logic(members): def test_tag_queries_boolean_logic(members):
@ -239,35 +239,107 @@ def test_tag_queries_boolean_logic(members):
actual = Member.find( actual = Member.find(
(Member.first_name == "Andrew") & (Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all() (Member.last_name == "Brookins") | (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member3] assert actual == [member1, member3]
def test_tag_queries_punctuation(): def test_tag_queries_punctuation():
member = Member( member1 = Member(
first_name="Andrew the Michael", first_name="Andrew, the Michael",
last_name="St. Brookins-on-Pier", last_name="St. Brookins-on-Pier",
email="a@example.com", email="a|b@example.com", # NOTE: This string uses the TAG field separator.
age=38, age=38,
join_date=today join_date=today,
) )
member.save() member1.save()
assert Member.find(Member.first_name == "Andrew the Michael").first() == member member2 = Member(
assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member first_name="Bob",
assert Member.find(Member.email == "a@example.com").first() == member last_name="the Villain",
email="a|villain@example.com", # NOTE: This string uses the TAG field separator.
age=38,
join_date=today,
)
member2.save()
assert Member.find(Member.first_name == "Andrew, the Michael").first() == member1
assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member1
# Notice that when we index and query multiple values that use the internal
# TAG separator for single-value exact-match fields, like an indexed string,
# the queries will succeed. We apply a workaround that queries for the union
# of the two values separated by the tag separator.
assert Member.find(Member.email == "a|b@example.com").all() == [member1]
assert Member.find(Member.email == "a|villain@example.com").all() == [member2]
def test_tag_queries_negation(members): def test_tag_queries_negation(members):
member1, member2, member3 = members member1, member2, member3 = members
actual = Member.find( """
first_name
NOT EQ
Andrew
"""
query = Member.find(
~(Member.first_name == "Andrew")
)
assert query.all() == [member2]
"""
first_name
NOT EQ
| Andrew
AND
| last_name
EQ
Brookins
"""
query = Member.find(
~(Member.first_name == "Andrew") & (Member.last_name == "Brookins")
)
assert query.all() == [member2]
"""
first_name
NOT EQ
| Andrew
AND
| last_name
| EQ
| | Brookins
OR
| last_name
EQ
Smith
"""
query = Member.find(
~(Member.first_name == "Andrew") & ~(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all() ((Member.last_name == "Brookins") | (Member.last_name == "Smith")))
assert sorted(actual) == [member2, member3] assert query.all() == [member2]
"""
first_name
NOT EQ
| Andrew
AND
| | last_name
| EQ
| Brookins
OR
| last_name
EQ
Smith
"""
query = Member.find(
~(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
assert query.all() == [member2, member3]
actual = Member.find( actual = Member.find(
(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")).all() (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")).all()
assert sorted(actual) == [member3] assert actual == [member3]
def test_numeric_queries(members): def test_numeric_queries(members):
@ -277,7 +349,7 @@ def test_numeric_queries(members):
assert actual == [member2] assert actual == [member2]
actual = Member.find(Member.age > 34).all() actual = Member.find(Member.age > 34).all()
assert sorted(actual) == [member1, member3] assert actual == [member1, member3]
actual = Member.find(Member.age < 35).all() actual = Member.find(Member.age < 35).all()
assert actual == [member2] assert actual == [member2]
@ -289,17 +361,17 @@ def test_numeric_queries(members):
assert actual == [member3] assert actual == [member3]
actual = Member.find(~(Member.age == 100)).all() actual = Member.find(~(Member.age == 100)).all()
assert sorted(actual) == [member1, member2] assert actual == [member1, member2]
def test_sorting(members): def test_sorting(members):
member1, member2, member3 = members member1, member2, member3 = members
actual = Member.find(Member.age > 34).sort_by('age').all() actual = Member.find(Member.age > 34).sort_by('age').all()
assert sorted(actual) == [member3, member1] assert actual == [member1, member3]
actual = Member.find(Member.age > 34).sort_by('-age').all() actual = Member.find(Member.age > 34).sort_by('-age').all()
assert sorted(actual) == [member1, member3] assert actual == [member3, member1]
with pytest.raises(QueryNotSupportedError): with pytest.raises(QueryNotSupportedError):
# This field does not exist. # This field does not exist.

View file

@ -44,8 +44,8 @@ class Order(BaseJsonModel):
class Member(BaseJsonModel): class Member(BaseJsonModel):
first_name: str first_name: str = Field(index=True)
last_name: str last_name: str = Field(index=True)
email: str = Field(index=True) email: str = Field(index=True)
join_date: datetime.date join_date: datetime.date
age: int = Field(index=True) age: int = Field(index=True)
@ -190,7 +190,7 @@ def test_updates_a_model(members):
# Or, affecting multiple model instances with an implicit save: # Or, affecting multiple model instances with an implicit save:
Member.find(Member.last_name == "Brookins").update(last_name="Smith") Member.find(Member.last_name == "Brookins").update(last_name="Smith")
results = Member.find(Member.last_name == "Smith") results = Member.find(Member.last_name == "Smith")
assert sorted(results) == members assert results == members
# Or, updating a field in an embedded model: # Or, updating a field in an embedded model:
member2.update(address__city="Happy Valley") member2.update(address__city="Happy Valley")
@ -200,7 +200,7 @@ def test_updates_a_model(members):
def test_paginate_query(members): def test_paginate_query(members):
member1, member2, member3 = members member1, member2, member3 = members
actual = Member.find().all(batch_size=1) actual = Member.find().all(batch_size=1)
assert sorted(actual) == [member1, member2, member3] assert actual == [member1, member2, member3]
def test_access_result_by_index_cached(members): def test_access_result_by_index_cached(members):
@ -233,7 +233,7 @@ def test_exact_match_queries(members):
member1, member2, member3 = members member1, member2, member3 = members
actual = Member.find(Member.last_name == "Brookins").all() actual = Member.find(Member.last_name == "Brookins").all()
assert sorted(actual) == [member1, member2] assert actual == [member1, member2]
actual = Member.find( actual = Member.find(
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all() (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all()
@ -264,7 +264,7 @@ def test_recursive_query_resolution(members):
actual = Member.find((Member.last_name == "Brookins") | ( actual = Member.find((Member.last_name == "Brookins") | (
Member.age == 100 Member.age == 100
) & (Member.last_name == "Smith")).all() ) & (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member2, member3] assert actual == [member1, member2, member3]
def test_tag_queries_boolean_logic(members): def test_tag_queries_boolean_logic(members):
@ -273,35 +273,109 @@ def test_tag_queries_boolean_logic(members):
actual = Member.find( actual = Member.find(
(Member.first_name == "Andrew") & (Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all() (Member.last_name == "Brookins") | (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member3] assert actual == [member1, member3]
def test_tag_queries_punctuation(): def test_tag_queries_punctuation(address):
member = Member( member1 = Member(
first_name="Andrew, the Michael", # This string uses the TAG field separator. first_name="Andrew, the Michael",
last_name="St. Brookins-on-Pier", last_name="St. Brookins-on-Pier",
email="a@example.com", email="a|b@example.com", # NOTE: This string uses the TAG field separator.
age=38, age=38,
join_date=today join_date=today,
address=address
) )
member.save() member1.save()
assert Member.find(Member.first_name == "Andrew the Michael").first() == member member2 = Member(
assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member first_name="Bob",
assert Member.find(Member.email == "a@example.com").first() == member last_name="the Villain",
email="a|villain@example.com", # NOTE: This string uses the TAG field separator.
age=38,
join_date=today,
address=address
)
member2.save()
assert Member.find(Member.first_name == "Andrew, the Michael").first() == member1
assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member1
# Notice that when we index and query multiple values that use the internal
# TAG separator for single-value exact-match fields, like an indexed string,
# the queries will succeed. We apply a workaround that queries for the union
# of the two values separated by the tag separator.
assert Member.find(Member.email == "a|b@example.com").all() == [member1]
assert Member.find(Member.email == "a|villain@example.com").all() == [member2]
def test_tag_queries_negation(members): def test_tag_queries_negation(members):
member1, member2, member3 = members member1, member2, member3 = members
actual = Member.find( """
first_name
NOT EQ
Andrew
"""
query = Member.find(
~(Member.first_name == "Andrew")
)
assert query.all() == [member2]
"""
first_name
NOT EQ
| Andrew
AND
| last_name
EQ
Brookins
"""
query = Member.find(
~(Member.first_name == "Andrew") & (Member.last_name == "Brookins")
)
assert query.all() == [member2]
"""
first_name
NOT EQ
| Andrew
AND
| last_name
| EQ
| | Brookins
OR
| last_name
EQ
Smith
"""
query = Member.find(
~(Member.first_name == "Andrew") & ~(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all() ((Member.last_name == "Brookins") | (Member.last_name == "Smith")))
assert sorted(actual) == [member2, member3] assert query.all() == [member2]
"""
first_name
NOT EQ
| Andrew
AND
| | last_name
| EQ
| Brookins
OR
| last_name
EQ
Smith
"""
query = Member.find(
~(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
assert query.all() == [member2, member3]
actual = Member.find( actual = Member.find(
(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")).all() (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")).all()
assert sorted(actual) == [member3] assert actual == [member3]
def test_numeric_queries(members): def test_numeric_queries(members):
@ -311,7 +385,7 @@ def test_numeric_queries(members):
assert actual == [member2] assert actual == [member2]
actual = Member.find(Member.age > 34).all() actual = Member.find(Member.age > 34).all()
assert sorted(actual) == [member1, member3] assert actual == [member1, member3]
actual = Member.find(Member.age < 35).all() actual = Member.find(Member.age < 35).all()
assert actual == [member2] assert actual == [member2]
@ -323,17 +397,17 @@ def test_numeric_queries(members):
assert actual == [member3] assert actual == [member3]
actual = Member.find(~(Member.age == 100)).all() actual = Member.find(~(Member.age == 100)).all()
assert sorted(actual) == [member1, member2] assert actual == [member1, member2]
def test_sorting(members): def test_sorting(members):
member1, member2, member3 = members member1, member2, member3 = members
actual = Member.find(Member.age > 34).sort_by('age').all() actual = Member.find(Member.age > 34).sort_by('age').all()
assert sorted(actual) == [member3, member1] assert actual == [member1, member3]
actual = Member.find(Member.age > 34).sort_by('-age').all() actual = Member.find(Member.age > 34).sort_by('-age').all()
assert sorted(actual) == [member1, member3] assert actual == [member3, member1]
with pytest.raises(QueryNotSupportedError): with pytest.raises(QueryNotSupportedError):
# This field does not exist. # This field does not exist.
@ -354,7 +428,10 @@ def test_schema():
assert Member.redisearch_schema() == "ON JSON PREFIX 1 " \ assert Member.redisearch_schema() == "ON JSON PREFIX 1 " \
"redis-developer:tests.test_json_model.Member: " \ "redis-developer:tests.test_json_model.Member: " \
"SCHEMA $.pk AS pk TAG " \ "SCHEMA $.pk AS pk TAG " \
"$.first_name AS first_name TAG " \
"$.last_name AS last_name TAG " \
"$.email AS email TAG " \ "$.email AS email TAG " \
"$.age AS age NUMERIC " \
"$.address.pk AS address_pk TAG " \ "$.address.pk AS address_pk TAG " \
"$.address.postal_code AS address_postal_code TAG " \ "$.address.postal_code AS address_postal_code TAG " \
"$.orders[].pk AS orders_pk TAG " \ "$.orders[].pk AS orders_pk TAG " \