Add support for IN queries

This commit is contained in:
Andrew Brookins 2021-10-13 17:16:20 -07:00
parent bb08fb9eb5
commit 389a6ea878
3 changed files with 111 additions and 54 deletions

View file

@ -4,7 +4,7 @@ import decimal
import json import json
import logging import logging
import operator import operator
from copy import deepcopy from copy import deepcopy, copy
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from typing import ( from typing import (
@ -31,7 +31,7 @@ from pydantic import BaseModel, validator
from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass from pydantic.main import ModelMetaclass
from pydantic.typing import NoArgAnyCallable, resolve_annotations from pydantic.typing import NoArgAnyCallable
from pydantic.utils import Representation from pydantic.utils import Representation
from ulid import ULID from ulid import ULID
@ -172,8 +172,8 @@ class NegatedExpression:
@dataclasses.dataclass @dataclasses.dataclass
class Expression: class Expression:
op: Operators op: Operators
left: ExpressionOrModelField left: Optional[ExpressionOrModelField]
right: ExpressionOrModelField right: Optional[ExpressionOrModelField]
parents: List[Tuple[str, 'RedisModel']] parents: List[Tuple[str, 'RedisModel']]
def __invert__(self): def __invert__(self):
@ -208,26 +208,28 @@ class ExpressionProxy:
def __ne__(self, other: Any) -> Expression: # type: ignore[override] def __ne__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents) return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)
def __lt__(self, other: Any) -> Expression: # type: ignore[override] def __lt__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents) return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
def __le__(self, other: Any) -> Expression: # type: ignore[override] def __le__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents) return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
def __gt__(self, other: Any) -> Expression: # type: ignore[override] def __gt__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents) return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
def __ge__(self, other: Any) -> Expression: # type: ignore[override] def __ge__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents) return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
def __mod__(self, other: Any) -> Expression: # type: ignore[override] def __mod__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents) return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents)
def __lshift__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.IN, right=other, parents=self.parents)
def __getattr__(self, item): def __getattr__(self, item):
if get_origin(self.field.outer_type_) == list: if get_origin(self.field.outer_type_) == list:
embedded_cls = get_args(self.field.outer_type_) embedded_cls = get_args(self.field.outer_type_)
if not embedded_cls: if not embedded_cls:
# TODO: Is this even possible?
raise QuerySyntaxError("In order to query on a list field, you must define " raise QuerySyntaxError("In order to query on a list field, you must define "
"the contents of the list with a type annotation, like: " "the contents of the list with a type annotation, like: "
"orders: List[Order]. Docs: TODO") "orders: List[Order]. Docs: TODO")
@ -237,7 +239,7 @@ class ExpressionProxy:
attr = getattr(self.field.outer_type_, item) attr = getattr(self.field.outer_type_, item)
if isinstance(attr, self.__class__): if isinstance(attr, self.__class__):
new_parent = (self.field.name, self.field.outer_type_) new_parent = (self.field.name, self.field.outer_type_)
if not new_parent in attr.parents: if new_parent not in attr.parents:
attr.parents.append(new_parent) attr.parents.append(new_parent)
new_parents = list(set(self.parents) - set(attr.parents)) new_parents = list(set(self.parents) - set(attr.parents))
if new_parents: if new_parents:
@ -285,6 +287,21 @@ class FindQuery:
self._pagination: list[str] = [] self._pagination: list[str] = []
self._model_cache: list[RedisModel] = [] self._model_cache: list[RedisModel] = []
def dict(self) -> dict[str, Any]:
return dict(
model=self.model,
offset=self.offset,
page_size=self.page_size,
limit=self.limit,
expressions=copy(self.expressions),
sort_fields=copy(self.sort_fields)
)
def copy(self, **kwargs):
original = self.dict()
original.update(**kwargs)
return FindQuery(**original)
@property @property
def pagination(self): def pagination(self):
if self._pagination: if self._pagination:
@ -299,15 +316,21 @@ class FindQuery:
if self.expressions: if self.expressions:
self._expression = reduce(operator.and_, self.expressions) self._expression = reduce(operator.and_, self.expressions)
else: else:
# TODO: Is there a better way to support the "give me all records" query? self._expression = Expression(left=None, right=None, op=Operators.ALL, parents=[])
# Also -- if we do it this way, we need different type annotations.
self._expression = Expression(left=None, right=None, op=Operators.ALL,
parents=[])
return self._expression return self._expression
@property @property
def query(self): def query(self):
return self.resolve_redisearch_query(self.expression) """
Resolve and return the RediSearch query for this FindQuery.
NOTE: We cache the resolved query string after generating it. This should be OK
because all mutations of FindQuery through public APIs return a new FindQuery instance.
"""
if self._query:
return self._query
self._query = self.resolve_redisearch_query(self.expression)
return self._query
def validate_sort_fields(self, sort_fields): def validate_sort_fields(self, sort_fields):
for sort_field in sort_fields: for sort_field in sort_fields:
@ -322,10 +345,10 @@ class FindQuery:
return sort_fields return sort_fields
@staticmethod @staticmethod
def resolve_field_type(field: ModelField, operator: Operators) -> RediSearchFieldTypes: def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes:
if getattr(field.field_info, 'primary_key', None) is True: if getattr(field.field_info, 'primary_key', None) is True:
return RediSearchFieldTypes.TAG return RediSearchFieldTypes.TAG
elif operator is Operators.LIKE: elif op is Operators.LIKE:
fts = getattr(field.field_info, 'full_text_search', None) fts = getattr(field.field_info, 'full_text_search', None)
if fts is not True: # Could be PydanticUndefined if fts is not True: # Could be PydanticUndefined
raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', " raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', "
@ -346,7 +369,7 @@ class FindQuery:
@staticmethod @staticmethod
def expand_tag_value(value): def expand_tag_value(value):
if isinstance(str, value): if isinstance(value, str):
return value return value
try: try:
expanded_value = "|".join([escaper.escape(v) for v in value]) expanded_value = "|".join([escaper.escape(v) for v in value])
@ -380,8 +403,6 @@ class FindQuery:
if op is Operators.EQ: if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]" result += f"@{field_name}:[{value} {value}]"
elif op is Operators.NE: elif op is Operators.NE:
# TODO: Is this enough or do we also need a clause for all values
# ([-inf +inf]) from which we then subtract the undesirable value?
result += f"-(@{field_name}:[{value} {value}])" result += f"-(@{field_name}:[{value} {value}])"
elif op is Operators.GT: elif op is Operators.GT:
result += f"@{field_name}:[({value} +inf]" result += f"@{field_name}:[({value} +inf]"
@ -410,7 +431,7 @@ class FindQuery:
# The value contains the TAG field separator. We can work # The value contains the TAG field separator. We can work
# around this by breaking apart the values and unioning them # around this by breaking apart the values and unioning them
# with multiple field:{} queries. # with multiple field:{} queries.
values = filter(None, value.split(separator_char)) values: filter = filter(None, value.split(separator_char))
for value in values: for value in values:
value = escaper.escape(value) value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}" result += f"@{field_name}:{{{value}}}"
@ -448,7 +469,29 @@ class FindQuery:
@classmethod @classmethod
def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
"""Resolve an expression to a string RediSearch query.""" """
Resolve an arbitrarily deep expression into a single RediSearch query string.
This method is complex. Note the following:
1. This method makes a recursive call to itself when it finds that
either the left or right operand contains another expression.
2. An expression might be in a "negated" form, which means that the user
gave us an expression like ~(Member.age == 30), or in other words,
"Members whose age is NOT 30." Thus, a negated expression is one in
which the meaning of an expression is inverted. If we find a negated
expression, we need to add the appropriate "NOT" syntax but can
otherwise use the resolved RediSearch query for the expression as-is.
3. The final resolution of an expression should be a left operand that's
a ModelField, an operator, and a right operand that's NOT a ModelField.
With an IN or NOT_IN operator, the right operand can be a sequence
type, but otherwise, sequence types are converted to strings.
TODO: When the operator is not IN or NOT_IN, detect a sequence type (other
than strings, which are allowed) and raise an exception.
"""
field_type = None field_type = None
field_name = None field_name = None
field_info = None field_info = None
@ -462,7 +505,7 @@ class FindQuery:
if expression.op is Operators.ALL: if expression.op is Operators.ALL:
if encompassing_expression_is_negated: if encompassing_expression_is_negated:
# TODO: Is there a use case for this, perhaps for dynamic # TODO: Is there a use case for this, perhaps for dynamic
# scoring purposes? # scoring purposes with full-text search?
raise QueryNotSupportedError("You cannot negate a query for all results.") raise QueryNotSupportedError("You cannot negate a query for all results.")
return "*" return "*"
@ -500,7 +543,13 @@ class FindQuery:
result += f"({cls.resolve_redisearch_query(right)})" result += f"({cls.resolve_redisearch_query(right)})"
else: else:
if isinstance(right, ModelField): if not field_name:
raise QuerySyntaxError("Could not resolve field name. See docs: TODO")
elif not field_type:
raise QuerySyntaxError("Could not resolve field type. See docs: TODO")
elif not field_info:
raise QuerySyntaxError("Could not resolve field info. See docs: TODO")
elif isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO") raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else: else:
result += cls.resolve_value(field_name, field_type, field_info, result += cls.resolve_value(field_name, field_type, field_info,
@ -540,11 +589,7 @@ class FindQuery:
while True: while True:
# Make a query for each pass of the loop, with a new offset equal to the # Make a query for each pass of the loop, with a new offset equal to the
# current offset plus `page_size`, until we stop getting results back. # current offset plus `page_size`, until we stop getting results back.
query = FindQuery(expressions=query.expressions, query = query.copy(offset=query.offset + query.page_size)
model=query.model,
offset=query.offset + query.page_size,
page_size=query.page_size,
limit=query.limit)
_results = query.execute(exhaust_results=False) _results = query.execute(exhaust_results=False)
if not _results: if not _results:
break break
@ -552,8 +597,7 @@ class FindQuery:
return self._model_cache return self._model_cache
def first(self): def first(self):
query = FindQuery(expressions=self.expressions, model=self.model, query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
offset=0, limit=1, sort_fields=self.sort_fields)
results = query.execute() results = query.execute()
if not results: if not results:
raise NotFoundError() raise NotFoundError()
@ -561,26 +605,14 @@ class FindQuery:
def all(self, batch_size=10): def all(self, batch_size=10):
if batch_size != self.page_size: if batch_size != self.page_size:
# TODO: There's probably a copy-with-change mechanism in Pydantic, query = self.copy(page_size=batch_size, limit=batch_size)
# or can we use one from dataclasses?
query = FindQuery(expressions=self.expressions,
model=self.model,
offset=self.offset,
page_size=batch_size,
limit=batch_size,
sort_fields=self.sort_fields)
return query.execute() return query.execute()
return self.execute() return self.execute()
def sort_by(self, *fields: str): def sort_by(self, *fields: str):
if not fields: if not fields:
return self return self
return FindQuery(expressions=self.expressions, return self.copy(sort_fields=list(fields))
model=self.model,
offset=self.offset,
page_size=self.page_size,
limit=self.limit,
sort_fields=list(fields))
def update(self, **kwargs): def update(self, **kwargs):
"""Update all matching records in this query.""" """Update all matching records in this query."""
@ -621,11 +653,7 @@ class FindQuery:
if self._model_cache and len(self._model_cache) >= item: if self._model_cache and len(self._model_cache) >= item:
return self._model_cache[item] return self._model_cache[item]
query = FindQuery(expressions=self.expressions, query = self.copy(offset=item, limit=1)
model=self.model,
offset=item,
sort_fields=self.sort_fields,
limit=1)
return query.execute()[0] return query.execute()[0]
@ -845,7 +873,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
pk: Optional[str] = Field(default=None, primary_key=True) pk: Optional[str] = Field(default=None, primary_key=True)
Meta = DefaultMeta Meta = DefaultMeta
# TODO: Missing _meta here is causing IDE warnings.
class Config: class Config:
orm_mode = True orm_mode = True
@ -899,7 +926,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
return cls._meta.database return cls._meta.database
@classmethod @classmethod
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery: # TODO: How to type annotate this? def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
return FindQuery(expressions=expressions, model=cls) return FindQuery(expressions=expressions, model=cls)
@classmethod @classmethod

View file

@ -360,9 +360,15 @@ def test_numeric_queries(members):
actual = Member.find(Member.age >= 100).all() actual = Member.find(Member.age >= 100).all()
assert actual == [member3] assert actual == [member3]
actual = Member.find(Member.age != 34).all()
assert actual == [member1, member3]
actual = Member.find(~(Member.age == 100)).all() actual = Member.find(~(Member.age == 100)).all()
assert actual == [member1, member2] assert actual == [member1, member2]
actual = Member.find(Member.age > 30, Member.age < 40).all()
assert actual == [member1, member2]
def test_sorting(members): def test_sorting(members):
member1, member2, member3 = members member1, member2, member3 = members

View file

@ -12,7 +12,7 @@ from redis_developer.orm import (
JsonModel, JsonModel,
Field, Field,
) )
from redis_developer.orm.model import RedisModelError, QueryNotSupportedError, NotFoundError, embedded from redis_developer.orm.model import QueryNotSupportedError, NotFoundError
r = redis.Redis() r = redis.Redis()
today = datetime.date.today() today = datetime.date.today()
@ -240,6 +240,24 @@ def test_access_result_by_index_not_cached(members):
assert query[2] == member3 assert query[2] == member3
def test_in_query(members):
member1, member2, member3 = members
actual = Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).all()
assert actual == [member1, member2, member3]
@pytest.mark.skip("Not implemented yet")
def test_update_query(members):
member1, member2, member3 = members
Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).update(
first_name="Bobby"
)
actual = Member.find(
Member.pk << [member1.pk, member2.pk, member3.pk]).sort_by('age').all()
assert actual == [member1, member2, member3]
assert all([m.name == "Bobby" for m in actual])
def test_exact_match_queries(members): def test_exact_match_queries(members):
member1, member2, member3 = members member1, member2, member3 = members
@ -458,6 +476,12 @@ def test_numeric_queries(members):
actual = Member.find(~(Member.age == 100)).all() actual = Member.find(~(Member.age == 100)).all()
assert actual == [member1, member2] assert actual == [member1, member2]
actual = Member.find(Member.age > 30, Member.age < 40).all()
assert actual == [member1, member2]
actual = Member.find(Member.age != 34).all()
assert actual == [member1, member3]
def test_sorting(members): def test_sorting(members):
member1, member2, member3 = members member1, member2, member3 = members