From 8f32b359f06096d9c3b3181660b05402501f41c5 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Tue, 5 Oct 2021 16:40:02 -0700 Subject: [PATCH] Handle TAG queries that include the separator --- poetry.lock | 105 +++++++++- pyproject.toml | 5 + redis_developer/orm/model.py | 285 +++++++++++++++++++-------- redis_developer/orm/render_tree.py | 59 ++++++ redis_developer/orm/token_escaper.py | 24 +++ tests/test_hash_model.py | 116 ++++++++--- tests/test_json_model.py | 125 +++++++++--- 7 files changed, 591 insertions(+), 128 deletions(-) create mode 100644 redis_developer/orm/render_tree.py create mode 100644 redis_developer/orm/token_escaper.py diff --git a/poetry.lock b/poetry.lock index e6b534d..070e4ac 100644 --- a/poetry.lock +++ b/poetry.lock @@ -165,6 +165,31 @@ python-versions = ">=3.5" [package.dependencies] traitlets = "*" +[[package]] +name = "mypy" +version = "0.910" +description = "Optional static typing for Python" +category = "main" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +mypy-extensions = ">=0.4.3,<0.5.0" +toml = "*" +typing-extensions = ">=3.7.4" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +python2 = ["typed-ast (>=1.4.0,<1.5.0)"] + +[[package]] +name = "mypy-extensions" +version = "0.4.3" +description = "Experimental type system extensions for programs checked with the mypy typechecker." 
+category = "main" +optional = false +python-versions = "*" + [[package]] name = "packaging" version = "21.0" @@ -219,6 +244,14 @@ python-versions = ">=3.6" dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pptree" +version = "3.1" +description = "Pretty print trees" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "prompt-toolkit" version = "3.0.20" @@ -298,6 +331,14 @@ toml = "*" [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +[[package]] +name = "python-ulid" +version = "1.0.3" +description = "Universally Unique Lexicographically Sortable Identifier" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "redis" version = "3.5.3" @@ -321,7 +362,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" name = "toml" version = "0.10.2" description = "Python Library for Tom's Obvious, Minimal Language" -category = "dev" +category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" @@ -336,6 +377,22 @@ python-versions = ">=3.7" [package.extras] test = ["pytest"] +[[package]] +name = "types-redis" +version = "3.5.9" +description = "Typing stubs for redis" +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "types-six" +version = "1.16.1" +description = "Typing stubs for six" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "typing-extensions" version = "3.10.0.2" @@ -355,7 +412,7 @@ python-versions = "*" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "b3f0c7c5701bb2c317df7f2f42218ef6c38d9b035ed49c9a9469df6a6727973c" +content-hash = "baa4bd3c38445c3325bdd317ecbfe99ccaf4bef438970ed31f5c49cc782d575e" [metadata.files] aioredis = [ @@ -413,6 +470,35 @@ matplotlib-inline = [ {file = "matplotlib-inline-0.1.3.tar.gz", hash = "sha256:a04bfba22e0d1395479f866853ec1ee28eea1485c1d69a6faf00dc3e24ff34ee"}, {file 
= "matplotlib_inline-0.1.3-py3-none-any.whl", hash = "sha256:aed605ba3b72462d64d475a21a9296f400a19c4f74a31b59103d2a99ffd5aa5c"}, ] +mypy = [ + {file = "mypy-0.910-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457"}, + {file = "mypy-0.910-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:b94e4b785e304a04ea0828759172a15add27088520dc7e49ceade7834275bedb"}, + {file = "mypy-0.910-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:088cd9c7904b4ad80bec811053272986611b84221835e079be5bcad029e79dd9"}, + {file = "mypy-0.910-cp35-cp35m-win_amd64.whl", hash = "sha256:adaeee09bfde366d2c13fe6093a7df5df83c9a2ba98638c7d76b010694db760e"}, + {file = "mypy-0.910-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:ecd2c3fe726758037234c93df7e98deb257fd15c24c9180dacf1ef829da5f921"}, + {file = "mypy-0.910-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d9dd839eb0dc1bbe866a288ba3c1afc33a202015d2ad83b31e875b5905a079b6"}, + {file = "mypy-0.910-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:3e382b29f8e0ccf19a2df2b29a167591245df90c0b5a2542249873b5c1d78212"}, + {file = "mypy-0.910-cp36-cp36m-win_amd64.whl", hash = "sha256:53fd2eb27a8ee2892614370896956af2ff61254c275aaee4c230ae771cadd885"}, + {file = "mypy-0.910-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b6fb13123aeef4a3abbcfd7e71773ff3ff1526a7d3dc538f3929a49b42be03f0"}, + {file = "mypy-0.910-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e4dab234478e3bd3ce83bac4193b2ecd9cf94e720ddd95ce69840273bf44f6de"}, + {file = "mypy-0.910-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:7df1ead20c81371ccd6091fa3e2878559b5c4d4caadaf1a484cf88d93ca06703"}, + {file = "mypy-0.910-cp37-cp37m-win_amd64.whl", hash = "sha256:0aadfb2d3935988ec3815952e44058a3100499f5be5b28c34ac9d79f002a4a9a"}, + {file = "mypy-0.910-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec4e0cd079db280b6bdabdc807047ff3e199f334050db5cbb91ba3e959a67504"}, + {file = 
"mypy-0.910-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:119bed3832d961f3a880787bf621634ba042cb8dc850a7429f643508eeac97b9"}, + {file = "mypy-0.910-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:866c41f28cee548475f146aa4d39a51cf3b6a84246969f3759cb3e9c742fc072"}, + {file = "mypy-0.910-cp38-cp38-win_amd64.whl", hash = "sha256:ceb6e0a6e27fb364fb3853389607cf7eb3a126ad335790fa1e14ed02fba50811"}, + {file = "mypy-0.910-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a85e280d4d217150ce8cb1a6dddffd14e753a4e0c3cf90baabb32cefa41b59e"}, + {file = "mypy-0.910-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:42c266ced41b65ed40a282c575705325fa7991af370036d3f134518336636f5b"}, + {file = "mypy-0.910-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3c4b8ca36877fc75339253721f69603a9c7fdb5d4d5a95a1a1b899d8b86a4de2"}, + {file = "mypy-0.910-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:c0df2d30ed496a08de5daed2a9ea807d07c21ae0ab23acf541ab88c24b26ab97"}, + {file = "mypy-0.910-cp39-cp39-win_amd64.whl", hash = "sha256:c6c2602dffb74867498f86e6129fd52a2770c48b7cd3ece77ada4fa38f94eba8"}, + {file = "mypy-0.910-py3-none-any.whl", hash = "sha256:ef565033fa5a958e62796867b1df10c40263ea9ded87164d67572834e57a174d"}, + {file = "mypy-0.910.tar.gz", hash = "sha256:704098302473cb31a218f1775a873b376b30b4c18229421e9e9dc8916fd16150"}, +] +mypy-extensions = [ + {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, + {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, +] packaging = [ {file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"}, {file = "packaging-21.0.tar.gz", hash = "sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7"}, @@ -433,6 +519,9 @@ pluggy = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = 
"sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, ] +pptree = [ + {file = "pptree-3.1.tar.gz", hash = "sha256:4dd0ba2f58000cbd29d68a5b64bac29bcb5a663642f79404877c0059668a69f6"}, +] prompt-toolkit = [ {file = "prompt_toolkit-3.0.20-py3-none-any.whl", hash = "sha256:6076e46efae19b1e0ca1ec003ed37a933dc94b4d20f486235d436e64771dcd5c"}, {file = "prompt_toolkit-3.0.20.tar.gz", hash = "sha256:eb71d5a6b72ce6db177af4a7d4d7085b99756bf656d98ffcc4fecd36850eea6c"}, @@ -481,6 +570,10 @@ pytest = [ {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, ] +python-ulid = [ + {file = "python-ulid-1.0.3.tar.gz", hash = "sha256:5dd8b969312a40e2212cec9c1ad63f25d4b6eafd92ee3195883e0287b6e9d19e"}, + {file = "python_ulid-1.0.3-py3-none-any.whl", hash = "sha256:8704dc20f547f531fe3a41d4369842d737a0f275403b909d0872e7ea0fe8d6f2"}, +] redis = [ {file = "redis-3.5.3-py2.py3-none-any.whl", hash = "sha256:432b788c4530cfe16d8d943a09d40ca6c16149727e4afe8c2c9d5580c59d9f24"}, {file = "redis-3.5.3.tar.gz", hash = "sha256:0e7e0cfca8660dea8b7d5cd8c4f6c5e29e11f31158c0b0ae91a397f00e5a05a2"}, @@ -497,6 +590,14 @@ traitlets = [ {file = "traitlets-5.1.0-py3-none-any.whl", hash = "sha256:03f172516916220b58c9f19d7f854734136dd9528103d04e9bf139a92c9f54c4"}, {file = "traitlets-5.1.0.tar.gz", hash = "sha256:bd382d7ea181fbbcce157c133db9a829ce06edffe097bcf3ab945b435452b46d"}, ] +types-redis = [ + {file = "types-redis-3.5.9.tar.gz", hash = "sha256:f142c48f4080757ca2a9441ec40213bda3b1535eebebfc4f3519e5aa46498076"}, + {file = "types_redis-3.5.9-py3-none-any.whl", hash = "sha256:5f5648ffc025708858097173cf695164c20f2b5e3f57177de14e352cae8cc335"}, +] +types-six = [ + {file = "types-six-1.16.1.tar.gz", hash 
= "sha256:a9e6769cb0808f920958ac95f75c5191f49e21e041eac127fa62e286e1005616"}, + {file = "types_six-1.16.1-py2.py3-none-any.whl", hash = "sha256:b14f5abe26c0997bd41a1a32d6816af25932f7bfbc54246dfdc8f6f6404fd1d4"}, +] typing-extensions = [ {file = "typing_extensions-3.10.0.2-py2-none-any.whl", hash = "sha256:d8226d10bc02a29bcc81df19a26e56a9647f8b0a6d4a83924139f4a8b01f17b7"}, {file = "typing_extensions-3.10.0.2-py3-none-any.whl", hash = "sha256:f1d25edafde516b146ecd0613dabcc61409817af4766fbbcfb8d1ad4ec441a34"}, diff --git a/pyproject.toml b/pyproject.toml index 206bc5a..dd3def5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,11 @@ aioredis = "^2.0.0" pydantic = "^1.8.2" click = "^8.0.1" six = "^1.16.0" +pptree = "^3.1" +mypy = "^0.910" +types-redis = "^3.5.9" +types-six = "^1.16.1" +python-ulid = "^1.0.3" [tool.poetry.dev-dependencies] pytest = "^6.2.4" diff --git a/redis_developer/orm/model.py b/redis_developer/orm/model.py index bb54327..b69a678 100644 --- a/redis_developer/orm/model.py +++ b/redis_developer/orm/model.py @@ -1,9 +1,9 @@ import abc import dataclasses import decimal +import json import logging import operator -import re from copy import deepcopy from enum import Enum from functools import reduce @@ -22,10 +22,9 @@ from typing import ( no_type_check, Protocol, List, - Type, - Pattern, get_origin, get_args + get_origin, + get_args, Type ) -import uuid import redis from pydantic import BaseModel, validator @@ -34,46 +33,43 @@ from pydantic.fields import ModelField, Undefined, UndefinedType from pydantic.main import ModelMetaclass from pydantic.typing import NoArgAnyCallable, resolve_annotations from pydantic.utils import Representation +from ulid import ULID from .encoders import jsonable_encoder +from .render_tree import render_tree +from .token_escaper import TokenEscaper model_registry = {} _T = TypeVar("_T") log = logging.getLogger(__name__) - - -class TokenEscaper: - """ - Escape punctuation within an input string. 
- """ - - # Characters that RediSearch requires us to escape during queries. - # Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization - DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]" - - def __init__(self, escape_chars_re: Optional[Pattern] = None): - if escape_chars_re: - self.escaped_chars_re = escape_chars_re - else: - self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS) - - def escape(self, string): - def escape_symbol(match): - value = match.group(0) - return f"\\{value}" - - return self.escaped_chars_re.sub(escape_symbol, string) - - escaper = TokenEscaper() +# For basic exact-match field types like an indexed string, we create a TAG +# field in the RediSearch index. TAG is designed for multi-value fields +# separated by a "separator" character. We're using the field for single values +# (multi-value TAGs will be exposed as a separate field type), and we use the +# pipe character (|) as the separator. There is no way to escape this character +# in hash fields or JSON objects, so if someone indexes a value that includes +# the pipe, we'll warn but allow, and then warn again if they try to query for +# values that contain this separator. +SINGLE_VALUE_TAG_FIELD_SEPARATOR = "|" + +# This is the default field separator in RediSearch. We need it to determine if +# someone has accidentally passed in the field separator with string value of a +# multi-value field lookup, like a IN or NOT_IN. 
+DEFAULT_REDISEARCH_FIELD_SEPARATOR = "," + class RedisModelError(Exception): - pass + """Raised when a problem exists in the definition of a RedisModel.""" + + +class QuerySyntaxError(Exception): + """Raised when a query is constructed improperly.""" class NotFoundError(Exception): - """A query found no results.""" + """Raised when a query found no results.""" class Operators(Enum): @@ -91,9 +87,45 @@ class Operators(Enum): LIKE = 12 ALL = 13 + def __str__(self): + return str(self.name) + + +ExpressionOrModelField = Union['Expression', 'NegatedExpression', ModelField] + + +class ExpressionProtocol(Protocol): + op: Operators + left: ExpressionOrModelField + right: ExpressionOrModelField + + def __invert__(self) -> 'Expression': + pass + + def __and__(self, other: ExpressionOrModelField): + pass + + def __or__(self, other: ExpressionOrModelField): + pass + + @property + def name(self) -> str: + raise NotImplementedError + + @property + def tree(self) -> str: + raise NotImplementedError + @dataclasses.dataclass class NegatedExpression: + """A negated Expression object. + + For now, this is a separate dataclass from Expression that acts as a facade + around an Expression, indicating to model code (specifically, code + responsible for querying) to negate the logic in the wrapped Expression. A + better design is probably possible, maybe at least an ExpressionProtocol? 
+ """ expression: 'Expression' def __invert__(self): @@ -105,22 +137,53 @@ class NegatedExpression: def __or__(self, other): return Expression(left=self, op=Operators.OR, right=other) + @property + def left(self): + return self.expression.left + + @property + def right(self): + return self.expression.right + + @property + def op(self): + return self.expression.op + + @property + def name(self): + if self.expression.op is Operators.EQ: + return f"NOT {self.expression.name}" + else: + return f"{self.expression.name} NOT" + + @property + def tree(self): + return render_tree(self) + @dataclasses.dataclass class Expression: op: Operators - left: Any - right: Any + left: ExpressionOrModelField + right: ExpressionOrModelField def __invert__(self): return NegatedExpression(self) - def __and__(self, other): + def __and__(self, other: ExpressionOrModelField): return Expression(left=self, op=Operators.AND, right=other) - def __or__(self, other): + def __or__(self, other: ExpressionOrModelField): return Expression(left=self, op=Operators.OR, right=other) + @property + def name(self): + return str(self.op) + + @property + def tree(self): + return render_tree(self) + ExpressionOrNegated = Union[Expression, NegatedExpression] @@ -129,22 +192,22 @@ class ExpressionProxy: def __init__(self, field: ModelField): self.field = field - def __eq__(self, other: Any) -> Expression: + def __eq__(self, other: Any) -> Expression: # type: ignore[override] return Expression(left=self.field, op=Operators.EQ, right=other) - def __ne__(self, other: Any) -> Expression: + def __ne__(self, other: Any) -> Expression: # type: ignore[override] return Expression(left=self.field, op=Operators.NE, right=other) - def __lt__(self, other: Any) -> Expression: + def __lt__(self, other: Any) -> Expression: # type: ignore[override] return Expression(left=self.field, op=Operators.LT, right=other) - def __le__(self, other: Any) -> Expression: + def __le__(self, other: Any) -> Expression: # type: ignore[override] 
return Expression(left=self.field, op=Operators.LE, right=other) - def __gt__(self, other: Any) -> Expression: + def __gt__(self, other: Any) -> Expression: # type: ignore[override] return Expression(left=self.field, op=Operators.GT, right=other) - def __ge__(self, other: Any) -> Expression: + def __ge__(self, other: Any) -> Expression: # type: ignore[override] return Expression(left=self.field, op=Operators.GE, right=other) @@ -184,9 +247,9 @@ class FindQuery: self.sort_fields = [] self._expression = None - self._query = None - self._pagination = [] - self._model_cache = [] + self._query: Optional[str] = None + self._pagination: list[str] = [] + self._model_cache: list[RedisModel] = [] @property def pagination(self): @@ -236,24 +299,24 @@ class FindQuery: else: # TAG fields are the default field type. # TODO: A ListField or ArrayField that supports multiple values - # and contains logic. + # and contains logic should allow IN and NOT_IN queries. return RediSearchFieldTypes.TAG @staticmethod def expand_tag_value(value): - err = RedisModelError(f"Using the IN operator requires passing a sequence of " - "possible values. You passed: {value}") if isinstance(str, value): - raise err + return value try: expanded_value = "|".join([escaper.escape(v) for v in value]) except TypeError: - raise err + raise QuerySyntaxError("Values passed to an IN query must be iterables," + "like a list of strings. 
For more information, see:" + "TODO: doc.") return expanded_value @classmethod def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes, - op: Operators, value: Any) -> str: + field_info: PydanticFieldInfo, op: Operators, value: Any) -> str: result = "" if field_type is RediSearchFieldTypes.TEXT: result = f"@{field_name}:" @@ -282,17 +345,41 @@ class FindQuery: result += f"@{field_name}:[{value} +inf]" elif op is Operators.LE: result += f"@{field_name}:[-inf {value}]" + # TODO: How will we know the difference between a multi-value use of a TAG + # field and our hidden use of TAG for exact-match queries? elif field_type is RediSearchFieldTypes.TAG: if op is Operators.EQ: - value = escaper.escape(value) - result += f"@{field_name}:{{{value}}}" + separator_char = getattr(field_info, 'separator', + SINGLE_VALUE_TAG_FIELD_SEPARATOR) + if value == separator_char: + # The value is ONLY the TAG field separator character -- + # this is not going to work. + log.warning("Your query against the field %s is for a single character, %s, " + "that is used internally by redis-developer-python. We must ignore " + "this portion of the query. Please review your query to find " + "an alternative query that uses a string containing more than " + "just the character %s.", field_name, separator_char, separator_char) + return "" + if separator_char in value: + # The value contains the TAG field separator. We can work + # around this by breaking apart the values and unioning them + # with multiple field:{} queries. + values = filter(None, value.split(separator_char)) + for value in values: + value = escaper.escape(value) + result += f"@{field_name}:{{{value}}}" + else: + value = escaper.escape(value) + result += f"@{field_name}:{{{value}}}" elif op is Operators.NE: value = escaper.escape(value) result += f"-(@{field_name}:{{{value}}})" elif op is Operators.IN: + # TODO: Implement IN, test this... 
expanded_value = cls.expand_tag_value(value) result += f"(@{field_name}:{{{expanded_value}}})" elif op is Operators.NOT_IN: + # TODO: Implement NOT_IN, test this... expanded_value = cls.expand_tag_value(value) result += f"-(@{field_name}:{{{expanded_value}}})" @@ -314,10 +401,11 @@ class FindQuery: return ["SORTBY", *fields] @classmethod - def resolve_redisearch_query(cls, expression: ExpressionOrNegated): + def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: """Resolve an expression to a string RediSearch query.""" field_type = None field_name = None + field_info = None encompassing_expression_is_negated = False result = "" @@ -328,7 +416,7 @@ class FindQuery: if expression.op is Operators.ALL: if encompassing_expression_is_negated: # TODO: Is there a use case for this, perhaps for dynamic - # scoring purposes? + # scoring purposes? raise QueryNotSupportedError("You cannot negate a query for all results.") return "*" @@ -338,6 +426,7 @@ class FindQuery: elif isinstance(expression.left, ModelField): field_type = cls.resolve_field_type(expression.left) field_name = expression.left.name + field_info = expression.left.field_info else: raise QueryNotSupportedError(f"A query expression should start with either a field " f"or an expression enclosed in parenthesis. See docs: " @@ -365,8 +454,7 @@ class FindQuery: if isinstance(right, ModelField): raise QueryNotSupportedError("Comparing fields is not supported. 
See docs: TODO") else: - # TODO: Optionals causing IDE errors here - result += cls.resolve_value(field_name, field_type, expression.op, right) + result += cls.resolve_value(field_name, field_type, field_info, expression.op, right) if encompassing_expression_is_negated: result = f"-({result})" @@ -416,7 +504,10 @@ class FindQuery: def first(self): query = FindQuery(expressions=self.expressions, model=self.model, offset=0, limit=1, sort_fields=self.sort_fields) - return query.execute()[0] + results = query.execute() + if not results: + raise NotFoundError() + return results[0] def all(self, batch_size=10): if batch_size != self.page_size: @@ -494,9 +585,13 @@ class PrimaryKeyCreator(Protocol): """Create a new primary key""" -class Uuid4PrimaryKey: - def create_pk(self, *args, **kwargs) -> str: - return str(uuid.uuid4()) +class UlidPrimaryKey: + """A client-side generated primary key that follows the ULID spec. + https://github.com/ulid/javascript#specification + """ + @staticmethod + def create_pk(*args, **kwargs) -> str: + return str(ULID()) def __dataclass_transform__( @@ -601,8 +696,24 @@ class PrimaryKey: field: ModelField +class MetaProtocol(Protocol): + global_key_prefix: str + model_key_prefix: str + primary_key_pattern: str + database: redis.Redis + primary_key: PrimaryKey + primary_key_creator_cls: Type[PrimaryKeyCreator] + index_name: str + abstract: bool + + +@dataclasses.dataclass class DefaultMeta: - # TODO: Should this really be optional here? + """A default placeholder Meta object. + + TODO: Revisit whether this is really necessary, and whether making + these all optional here is the right choice. 
+ """ global_key_prefix: Optional[str] = None model_key_prefix: Optional[str] = None primary_key_pattern: Optional[str] = None @@ -614,6 +725,8 @@ class DefaultMeta: class ModelMeta(ModelMetaclass): + _meta: MetaProtocol + def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 meta = attrs.pop('Meta', None) new_class = super().__new__(cls, name, bases, attrs, **kwargs) @@ -656,11 +769,16 @@ class ModelMeta(ModelMetaclass): redis.Redis(decode_responses=True)) if not getattr(new_class._meta, 'primary_key_creator_cls', None): new_class._meta.primary_key_creator_cls = getattr(base_meta, "primary_key_creator_cls", - Uuid4PrimaryKey) + UlidPrimaryKey) if not getattr(new_class._meta, 'index_name', None): new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \ f"{new_class._meta.model_key_prefix}:index" + # Not an abstract model class + if abc.ABC not in bases: + key = f"{new_class.__module__}.{new_class.__qualname__}" + model_registry[key] = new_class + return new_class @@ -680,12 +798,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): __pydantic_self__.validate_primary_key() def __lt__(self, other): - """Default sort: compare all shared model fields.""" - my_keys = set(self.__fields__.keys()) - other_keys = set(other.__fields__.keys()) - shared_keys = list(my_keys & other_keys) - lt = [getattr(self, k) < getattr(other, k) for k in shared_keys] - return len(lt) > len(shared_keys) / 2 + """Default sort: compare primary key of models.""" + return self.pk < other.pk @validator("pk", always=True) def validate_pk(cls, v): @@ -726,7 +840,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): return cls._meta.database @classmethod - def find(cls, *expressions: Union[Any, Expression]): # TODO: How to type annotate this? + def find(cls, *expressions: Union[Any, Expression]) -> FindQuery: # TODO: How to type annotate this? 
return FindQuery(expressions=expressions, model=cls) @classmethod @@ -760,7 +874,17 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): except KeyError: pass - doc = cls(**fields) + try: + fields['json'] = fields['$'] + del fields['$'] + except KeyError: + pass + + if 'json' in fields: + json_fields = json.loads(fields['json']) + doc = cls(**json_fields) + else: + doc = cls(**fields) docs.append(doc) return docs @@ -847,7 +971,7 @@ class HashModel(RedisModel, abc.ABC): _type = field.outer_type_ if getattr(field.field_info, 'primary_key', None): if issubclass(_type, str): - redisearch_field = f"{name} TAG" + redisearch_field = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" else: redisearch_field = cls.schema_for_type(name, _type, field.field_info) schema_parts.append(redisearch_field) @@ -872,7 +996,7 @@ class HashModel(RedisModel, abc.ABC): return schema_parts @classmethod - def schema_for_type(cls, name, typ: Type, field_info: FieldInfo): + def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): if get_origin(typ) == list: embedded_cls = get_args(typ) if not embedded_cls: @@ -885,9 +1009,10 @@ class HashModel(RedisModel, abc.ABC): return f"{name} NUMERIC" elif issubclass(typ, str): if getattr(field_info, 'full_text_search', False) is True: - return f"{name} TAG {name}_fts TEXT" + return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " \ + f"{name}_fts TEXT" else: - return f"{name} TAG" + return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" elif issubclass(typ, RedisModel): sub_fields = [] for embedded_name, field in typ.__fields__.items(): @@ -895,7 +1020,7 @@ class HashModel(RedisModel, abc.ABC): field.field_info)) return " ".join(sub_fields) else: - return f"{name} TAG" + return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" class JsonModel(RedisModel, abc.ABC): @@ -927,7 +1052,7 @@ class JsonModel(RedisModel, abc.ABC): _type = field.outer_type_ if getattr(field.field_info, 'primary_key', 
None): if issubclass(_type, str): - redisearch_field = f"{json_path}.{name} AS {name} TAG" + redisearch_field = f"{json_path}.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" else: redisearch_field = cls.schema_for_type(f"{json_path}.{name}", name, "", _type, field.field_info) schema_parts.append(redisearch_field) @@ -957,8 +1082,8 @@ class JsonModel(RedisModel, abc.ABC): # find it in the JSON document, AND the name of the field as it should # be in the redisearch schema (address_address_line_1). Maybe both "name" # and "name_prefix"? - def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Type, - field_info: FieldInfo) -> str: + def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Any, + field_info: PydanticFieldInfo) -> str: index_field_name = f"{name_prefix}{name}" should_index = getattr(field_info, 'index', False) @@ -986,10 +1111,10 @@ class JsonModel(RedisModel, abc.ABC): return f"{json_path} AS {index_field_name} NUMERIC" elif issubclass(typ, str): if getattr(field_info, 'full_text_search', False) is True: - return f"{json_path} AS {index_field_name} TAG " \ + return f"{json_path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " \ f"{json_path} AS {index_field_name}_fts TEXT" else: - return f"{json_path} AS {index_field_name} TAG" + return f"{json_path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" else: return f"{json_path} AS {index_field_name} TAG" diff --git a/redis_developer/orm/render_tree.py b/redis_developer/orm/render_tree.py new file mode 100644 index 0000000..7bc0c91 --- /dev/null +++ b/redis_developer/orm/render_tree.py @@ -0,0 +1,59 @@ +""" +This code adapted from the library "pptree," Copyright (c) 2017 Clément Michard +and released under the MIT license: https://github.com/clemtoy/pptree +""" +import io + + +def render_tree(current_node, nameattr='name', left_child='left', + right_child='right', indent='', last='updown', + 
buffer=None): + """Print a tree-like structure, `current_node`. + + This is a mostly-direct-copy of the print_tree() function from the ppbtree + module of the pptree library, but instead of printing to standard out, we + write to a StringIO buffer, then use that buffer to accumulate written lines + during recursive calls to render_tree(). + """ + if buffer is None: + buffer = io.StringIO() + if hasattr(current_node, nameattr): + name = lambda node: getattr(node, nameattr) + else: + name = lambda node: str(node) + + up = getattr(current_node, left_child, None) + down = getattr(current_node, right_child, None) + + if up is not None: + next_last = 'up' + next_indent = '{0}{1}{2}'.format(indent, ' ' if 'up' in last else '|', ' ' * len(str(name(current_node)))) + render_tree(up, nameattr, left_child, right_child, next_indent, next_last, buffer) + + if last == 'up': + start_shape = '┌' + elif last == 'down': + start_shape = '└' + elif last == 'updown': + start_shape = ' ' + else: + start_shape = '├' + + if up is not None and down is not None: + end_shape = '┤' + elif up: + end_shape = '┘' + elif down: + end_shape = '┐' + else: + end_shape = '' + + print('{0}{1}{2}{3}'.format(indent, start_shape, name(current_node), end_shape), + file=buffer) + + if down is not None: + next_last = 'down' + next_indent = '{0}{1}{2}'.format(indent, ' ' if 'down' in last else '|', ' ' * len(str(name(current_node)))) + render_tree(down, nameattr, left_child, right_child, next_indent, next_last, buffer) + + return f"\n{buffer.getvalue()}" \ No newline at end of file diff --git a/redis_developer/orm/token_escaper.py b/redis_developer/orm/token_escaper.py new file mode 100644 index 0000000..f623648 --- /dev/null +++ b/redis_developer/orm/token_escaper.py @@ -0,0 +1,24 @@ +import re +from typing import Optional, Pattern + + +class TokenEscaper: + """ + Escape punctuation within an input string. + """ + # Characters that RediSearch requires us to escape during queries. 
+ # Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization + DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]" + + def __init__(self, escape_chars_re: Optional[Pattern] = None): + if escape_chars_re: + self.escaped_chars_re = escape_chars_re + else: + self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS) + + def escape(self, string: str) -> str: + def escape_symbol(match): + value = match.group(0) + return f"\\{value}" + + return self.escaped_chars_re.sub(escape_symbol, string) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 898321d..55962c4 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -163,13 +163,13 @@ def test_updates_a_model(members): # Or, affecting multiple model instances with an implicit save: Member.find(Member.last_name == "Brookins").update(last_name="Smith") results = Member.find(Member.last_name == "Smith") - assert sorted(results) == members + assert results == members def test_paginate_query(members): member1, member2, member3 = members actual = Member.find().all(batch_size=1) - assert sorted(actual) == [member1, member2, member3] + assert actual == [member1, member2, member3] def test_access_result_by_index_cached(members): @@ -202,7 +202,7 @@ def test_exact_match_queries(members): member1, member2, member3 = members actual = Member.find(Member.last_name == "Brookins").all() - assert sorted(actual) == [member1, member2] + assert actual == [member1, member2] actual = Member.find( (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all() @@ -218,7 +218,7 @@ def test_exact_match_queries(members): (Member.last_name == "Brookins") & (Member.first_name == "Andrew") | (Member.first_name == "Kim") ).all() - assert actual == [member2, member1] + assert actual == [member1, member2] actual = Member.find(Member.first_name == "Kim", Member.last_name == "Brookins").all() assert actual == [member2] @@ -230,7 +230,7 @@ def 
def test_tag_queries_punctuation():
    """TAG (exact-match) queries must match values containing punctuation,
    including the internal RediSearch TAG separator character ('|')."""
    member1 = Member(
        first_name="Andrew, the Michael",
        last_name="St. Brookins-on-Pier",
        email="a|b@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
    )
    member2 = Member(
        first_name="Bob",
        last_name="the Villain",
        email="a|villain@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
    )
    member1.save()
    member2.save()

    assert Member.find(Member.first_name == "Andrew, the Michael").first() == member1
    assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member1

    # Indexed strings that embed the TAG separator still match exactly:
    # the query layer applies a workaround that queries for the union of
    # the values split on the separator.
    assert Member.find(Member.email == "a|b@example.com").all() == [member1]
    assert Member.find(Member.email == "a|villain@example.com").all() == [member2]


def test_tag_queries_negation(members):
    """Negated TAG queries, alone and combined with AND/OR, resolve with the
    expected precedence."""
    member1, member2, member3 = members

    # NOT (first_name == "Andrew")
    assert Member.find(~(Member.first_name == "Andrew")).all() == [member2]

    # NOT (first_name == "Andrew") AND (last_name == "Brookins")
    assert Member.find(
        ~(Member.first_name == "Andrew") & (Member.last_name == "Brookins")
    ).all() == [member2]

    # NOT (first_name == "Andrew") AND (last_name == "Brookins" OR last_name == "Smith")
    assert Member.find(
        ~(Member.first_name == "Andrew")
        & ((Member.last_name == "Brookins") | (Member.last_name == "Smith"))
    ).all() == [member2]

    # (NOT (first_name == "Andrew") AND last_name == "Brookins") OR last_name == "Smith"
    assert Member.find(
        ~(Member.first_name == "Andrew")
        & (Member.last_name == "Brookins") | (Member.last_name == "Smith")
    ).all() == [member2, member3]

    # first_name == "Andrew" AND NOT (last_name == "Brookins")
    assert Member.find(
        (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")
    ).all() == [member3]
member2] + assert actual == [member1, member2] def test_sorting(members): member1, member2, member3 = members actual = Member.find(Member.age > 34).sort_by('age').all() - assert sorted(actual) == [member3, member1] + assert actual == [member1, member3] actual = Member.find(Member.age > 34).sort_by('-age').all() - assert sorted(actual) == [member1, member3] + assert actual == [member3, member1] with pytest.raises(QueryNotSupportedError): # This field does not exist. diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 847239c..1382ea4 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -44,8 +44,8 @@ class Order(BaseJsonModel): class Member(BaseJsonModel): - first_name: str - last_name: str + first_name: str = Field(index=True) + last_name: str = Field(index=True) email: str = Field(index=True) join_date: datetime.date age: int = Field(index=True) @@ -190,7 +190,7 @@ def test_updates_a_model(members): # Or, affecting multiple model instances with an implicit save: Member.find(Member.last_name == "Brookins").update(last_name="Smith") results = Member.find(Member.last_name == "Smith") - assert sorted(results) == members + assert results == members # Or, updating a field in an embedded model: member2.update(address__city="Happy Valley") @@ -200,7 +200,7 @@ def test_updates_a_model(members): def test_paginate_query(members): member1, member2, member3 = members actual = Member.find().all(batch_size=1) - assert sorted(actual) == [member1, member2, member3] + assert actual == [member1, member2, member3] def test_access_result_by_index_cached(members): @@ -233,7 +233,7 @@ def test_exact_match_queries(members): member1, member2, member3 = members actual = Member.find(Member.last_name == "Brookins").all() - assert sorted(actual) == [member1, member2] + assert actual == [member1, member2] actual = Member.find( (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all() @@ -264,7 +264,7 @@ def 
def test_tag_queries_punctuation(address):
    """TAG (exact-match) queries on JSON models must match values containing
    punctuation, including the internal RediSearch TAG separator ('|')."""
    member1 = Member(
        first_name="Andrew, the Michael",
        last_name="St. Brookins-on-Pier",
        email="a|b@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
        address=address
    )
    member2 = Member(
        first_name="Bob",
        last_name="the Villain",
        email="a|villain@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
        address=address
    )
    member1.save()
    member2.save()

    assert Member.find(Member.first_name == "Andrew, the Michael").first() == member1
    assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member1

    # Indexed strings that embed the TAG separator still match exactly:
    # the query layer applies a workaround that queries for the union of
    # the values split on the separator.
    assert Member.find(Member.email == "a|b@example.com").all() == [member1]
    assert Member.find(Member.email == "a|villain@example.com").all() == [member2]


def test_tag_queries_negation(members):
    """Negated TAG queries on JSON models, alone and combined with AND/OR,
    resolve with the expected precedence."""
    member1, member2, member3 = members

    # NOT (first_name == "Andrew")
    assert Member.find(~(Member.first_name == "Andrew")).all() == [member2]

    # NOT (first_name == "Andrew") AND (last_name == "Brookins")
    assert Member.find(
        ~(Member.first_name == "Andrew") & (Member.last_name == "Brookins")
    ).all() == [member2]

    # NOT (first_name == "Andrew") AND (last_name == "Brookins" OR last_name == "Smith")
    assert Member.find(
        ~(Member.first_name == "Andrew")
        & ((Member.last_name == "Brookins") | (Member.last_name == "Smith"))
    ).all() == [member2]

    # (NOT (first_name == "Andrew") AND last_name == "Brookins") OR last_name == "Smith"
    assert Member.find(
        ~(Member.first_name == "Andrew")
        & (Member.last_name == "Brookins") | (Member.last_name == "Smith")
    ).all() == [member2, member3]

    # first_name == "Andrew" AND NOT (last_name == "Brookins")
    assert Member.find(
        (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")
    ).all() == [member3]
member2] + assert actual == [member1, member2] def test_sorting(members): member1, member2, member3 = members actual = Member.find(Member.age > 34).sort_by('age').all() - assert sorted(actual) == [member3, member1] + assert actual == [member1, member3] actual = Member.find(Member.age > 34).sort_by('-age').all() - assert sorted(actual) == [member1, member3] + assert actual == [member3, member1] with pytest.raises(QueryNotSupportedError): # This field does not exist. @@ -354,7 +428,10 @@ def test_schema(): assert Member.redisearch_schema() == "ON JSON PREFIX 1 " \ "redis-developer:tests.test_json_model.Member: " \ "SCHEMA $.pk AS pk TAG " \ + "$.first_name AS first_name TAG " \ + "$.last_name AS last_name TAG " \ "$.email AS email TAG " \ + "$.age AS age NUMERIC " \ "$.address.pk AS address_pk TAG " \ "$.address.postal_code AS address_postal_code TAG " \ "$.orders[].pk AS orders_pk TAG " \