From 389a6ea878ad9f739a3f3113f76c2e6bcc349444 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 13 Oct 2021 17:16:20 -0700 Subject: [PATCH] Add support for IN queries --- redis_developer/orm/model.py | 133 +++++++++++++++++++++-------------- tests/test_hash_model.py | 6 ++ tests/test_json_model.py | 26 ++++++- 3 files changed, 111 insertions(+), 54 deletions(-) diff --git a/redis_developer/orm/model.py b/redis_developer/orm/model.py index 58db040..3b17abc 100644 --- a/redis_developer/orm/model.py +++ b/redis_developer/orm/model.py @@ -4,7 +4,7 @@ import decimal import json import logging import operator -from copy import deepcopy +from copy import deepcopy, copy from enum import Enum from functools import reduce from typing import ( @@ -31,7 +31,7 @@ from pydantic import BaseModel, validator from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import ModelField, Undefined, UndefinedType from pydantic.main import ModelMetaclass -from pydantic.typing import NoArgAnyCallable, resolve_annotations +from pydantic.typing import NoArgAnyCallable from pydantic.utils import Representation from ulid import ULID @@ -172,8 +172,8 @@ class NegatedExpression: @dataclasses.dataclass class Expression: op: Operators - left: ExpressionOrModelField - right: ExpressionOrModelField + left: Optional[ExpressionOrModelField] + right: Optional[ExpressionOrModelField] parents: List[Tuple[str, 'RedisModel']] def __invert__(self): @@ -208,26 +208,28 @@ class ExpressionProxy: def __ne__(self, other: Any) -> Expression: # type: ignore[override] return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents) - def __lt__(self, other: Any) -> Expression: # type: ignore[override] + def __lt__(self, other: Any) -> Expression: return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents) - def __le__(self, other: Any) -> Expression: # type: ignore[override] + def __le__(self, other: Any) -> Expression: return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents) - def __gt__(self, other: Any) -> Expression: # type: ignore[override] + def __gt__(self, other: Any) -> Expression: return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents) - def __ge__(self, other: Any) -> Expression: # type: ignore[override] + def __ge__(self, other: Any) -> Expression: return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents) - def __mod__(self, other: Any) -> Expression: # type: ignore[override] + def __mod__(self, other: Any) -> Expression: return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents) + def __lshift__(self, other: Any) -> Expression: + return Expression(left=self.field, op=Operators.IN, right=other, parents=self.parents) + def __getattr__(self, item): if get_origin(self.field.outer_type_) == list: embedded_cls = get_args(self.field.outer_type_) if not embedded_cls: - # TODO: Is this even possible? raise QuerySyntaxError("In order to query on a list field, you must define " "the contents of the list with a type annotation, like: " "orders: List[Order]. Docs: TODO") @@ -237,7 +239,7 @@ class ExpressionProxy: attr = getattr(self.field.outer_type_, item) if isinstance(attr, self.__class__): new_parent = (self.field.name, self.field.outer_type_) - if not new_parent in attr.parents: + if new_parent not in attr.parents: attr.parents.append(new_parent) new_parents = list(set(self.parents) - set(attr.parents)) if new_parents: @@ -285,6 +287,21 @@ class FindQuery: self._pagination: list[str] = [] self._model_cache: list[RedisModel] = [] + def dict(self) -> dict[str, Any]: + return dict( + model=self.model, + offset=self.offset, + page_size=self.page_size, + limit=self.limit, + expressions=copy(self.expressions), + sort_fields=copy(self.sort_fields) + ) + + def copy(self, **kwargs): + original = self.dict() + original.update(**kwargs) + return FindQuery(**original) + @property def pagination(self): if self._pagination: @@ -299,15 +316,21 @@ class FindQuery: if self.expressions: self._expression = reduce(operator.and_, self.expressions) else: - # TODO: Is there a better way to support the "give me all records" query? - # Also -- if we do it this way, we need different type annotations. - self._expression = Expression(left=None, right=None, op=Operators.ALL, - parents=[]) + self._expression = Expression(left=None, right=None, op=Operators.ALL, parents=[]) return self._expression @property def query(self): - return self.resolve_redisearch_query(self.expression) + """ + Resolve and return the RediSearch query for this FindQuery. + + NOTE: We cache the resolved query string after generating it. This should be OK + because all mutations of FindQuery through public APIs return a new FindQuery instance. + """ + if self._query: + return self._query + self._query = self.resolve_redisearch_query(self.expression) + return self._query def validate_sort_fields(self, sort_fields): for sort_field in sort_fields: @@ -322,10 +345,10 @@ class FindQuery: return sort_fields @staticmethod - def resolve_field_type(field: ModelField, operator: Operators) -> RediSearchFieldTypes: + def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes: if getattr(field.field_info, 'primary_key', None) is True: return RediSearchFieldTypes.TAG - elif operator is Operators.LIKE: + elif op is Operators.LIKE: fts = getattr(field.field_info, 'full_text_search', None) if fts is not True: # Could be PydanticUndefined raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', " @@ -346,7 +369,7 @@ class FindQuery: @staticmethod def expand_tag_value(value): - if isinstance(str, value): + if isinstance(value, str): return value try: expanded_value = "|".join([escaper.escape(v) for v in value]) @@ -380,8 +403,6 @@ class FindQuery: if op is Operators.EQ: result += f"@{field_name}:[{value} {value}]" elif op is Operators.NE: - # TODO: Is this enough or do we also need a clause for all values - # ([-inf +inf]) from which we then subtract the undesirable value? result += f"-(@{field_name}:[{value} {value}])" elif op is Operators.GT: result += f"@{field_name}:[({value} +inf]" @@ -410,7 +431,7 @@ class FindQuery: # The value contains the TAG field separator. We can work # around this by breaking apart the values and unioning them # with multiple field:{} queries. - values = filter(None, value.split(separator_char)) + values: filter = filter(None, value.split(separator_char)) for value in values: value = escaper.escape(value) result += f"@{field_name}:{{{value}}}" @@ -448,7 +469,29 @@ class FindQuery: @classmethod def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: - """Resolve an expression to a string RediSearch query.""" + """ + Resolve an arbitrarily deep expression into a single RediSearch query string. + + This method is complex. Note the following: + + 1. This method makes a recursive call to itself when it finds that + either the left or right operand contains another expression. + + 2. An expression might be in a "negated" form, which means that the user + gave us an expression like ~(Member.age == 30), or in other words, + "Members whose age is NOT 30." Thus, a negated expression is one in + which the meaning of an expression is inverted. If we find a negated + expression, we need to add the appropriate "NOT" syntax but can + otherwise use the resolved RediSearch query for the expression as-is. + + 3. The final resolution of an expression should be a left operand that's + a ModelField, an operator, and a right operand that's NOT a ModelField. + With an IN or NOT_IN operator, the right operand can be a sequence + type, but otherwise, sequence types are converted to strings. + + TODO: When the operator is not IN or NOT_IN, detect a sequence type (other + than strings, which are allowed) and raise an exception. + """ field_type = None field_name = None field_info = None @@ -462,7 +505,7 @@ class FindQuery: if expression.op is Operators.ALL: if encompassing_expression_is_negated: # TODO: Is there a use case for this, perhaps for dynamic - # scoring purposes? + # scoring purposes with full-text search? raise QueryNotSupportedError("You cannot negate a query for all results.") return "*" @@ -500,7 +543,13 @@ class FindQuery: result += f"({cls.resolve_redisearch_query(right)})" else: - if isinstance(right, ModelField): + if not field_name: + raise QuerySyntaxError("Could not resolve field name. See docs: TODO") + elif not field_type: + raise QuerySyntaxError("Could not resolve field type. See docs: TODO") + elif not field_info: + raise QuerySyntaxError("Could not resolve field info. See docs: TODO") + elif isinstance(right, ModelField): raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO") else: result += cls.resolve_value(field_name, field_type, field_info, @@ -540,11 +589,7 @@ class FindQuery: while True: # Make a query for each pass of the loop, with a new offset equal to the # current offset plus `page_size`, until we stop getting results back. - query = FindQuery(expressions=query.expressions, - model=query.model, - offset=query.offset + query.page_size, - page_size=query.page_size, - limit=query.limit) + query = query.copy(offset=query.offset + query.page_size) _results = query.execute(exhaust_results=False) if not _results: break @@ -552,8 +597,7 @@ class FindQuery: return self._model_cache def first(self): - query = FindQuery(expressions=self.expressions, model=self.model, - offset=0, limit=1, sort_fields=self.sort_fields) + query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields) results = query.execute() if not results: raise NotFoundError() @@ -561,26 +605,14 @@ class FindQuery: def all(self, batch_size=10): if batch_size != self.page_size: - # TODO: There's probably a copy-with-change mechanism in Pydantic, - # or can we use one from dataclasses? - query = FindQuery(expressions=self.expressions, - model=self.model, - offset=self.offset, - page_size=batch_size, - limit=batch_size, - sort_fields=self.sort_fields) + query = self.copy(page_size=batch_size, limit=batch_size) return query.execute() return self.execute() def sort_by(self, *fields: str): if not fields: return self - return FindQuery(expressions=self.expressions, - model=self.model, - offset=self.offset, - page_size=self.page_size, - limit=self.limit, - sort_fields=list(fields)) + return self.copy(sort_fields=list(fields)) def update(self, **kwargs): """Update all matching records in this query.""" @@ -621,11 +653,7 @@ class FindQuery: if self._model_cache and len(self._model_cache) >= item: return self._model_cache[item] - query = FindQuery(expressions=self.expressions, - model=self.model, - offset=item, - sort_fields=self.sort_fields, - limit=1) + query = self.copy(offset=item, limit=1) return query.execute()[0] @@ -845,7 +873,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): pk: Optional[str] = Field(default=None, primary_key=True) Meta = DefaultMeta - # TODO: Missing _meta here is causing IDE warnings. class Config: orm_mode = True @@ -899,7 +926,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): return cls._meta.database @classmethod - def find(cls, *expressions: Union[Any, Expression]) -> FindQuery: # TODO: How to type annotate this? + def find(cls, *expressions: Union[Any, Expression]) -> FindQuery: return FindQuery(expressions=expressions, model=cls) @classmethod diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 1d55458..49a13eb 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -360,9 +360,15 @@ def test_numeric_queries(members): actual = Member.find(Member.age >= 100).all() assert actual == [member3] + actual = Member.find(Member.age != 34).all() + assert actual == [member1, member3] + actual = Member.find(~(Member.age == 100)).all() assert actual == [member1, member2] + actual = Member.find(Member.age > 30, Member.age < 40).all() + assert actual == [member1, member2] + def test_sorting(members): member1, member2, member3 = members diff --git a/tests/test_json_model.py b/tests/test_json_model.py index fe89db2..7778330 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -12,7 +12,7 @@ from redis_developer.orm import ( JsonModel, Field, ) -from redis_developer.orm.model import RedisModelError, QueryNotSupportedError, NotFoundError, embedded +from redis_developer.orm.model import QueryNotSupportedError, NotFoundError r = redis.Redis() today = datetime.date.today() @@ -240,6 +240,24 @@ def test_access_result_by_index_not_cached(members): assert query[2] == member3 +def test_in_query(members): + member1, member2, member3 = members + actual = Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).all() + assert actual == [member1, member2, member3] + + +@pytest.mark.skip("Not implemented yet") +def test_update_query(members): + member1, member2, member3 = members + Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).update( + first_name="Bobby" + ) + actual = Member.find( + Member.pk << [member1.pk, member2.pk, member3.pk]).sort_by('age').all() + assert actual == [member1, member2, member3] + assert all([m.name == "Bobby" for m in actual]) + + def test_exact_match_queries(members): member1, member2, member3 = members @@ -458,6 +476,12 @@ def test_numeric_queries(members): actual = Member.find(~(Member.age == 100)).all() assert actual == [member1, member2] + actual = Member.find(Member.age > 30, Member.age < 40).all() + assert actual == [member1, member2] + + actual = Member.find(Member.age != 34).all() + assert actual == [member1, member3] + def test_sorting(members): member1, member2, member3 = members