Add support for IN queries

This commit is contained in:
Andrew Brookins 2021-10-13 17:16:20 -07:00
parent bb08fb9eb5
commit 389a6ea878
3 changed files with 111 additions and 54 deletions

View file

@ -4,7 +4,7 @@ import decimal
import json
import logging
import operator
from copy import deepcopy
from copy import deepcopy, copy
from enum import Enum
from functools import reduce
from typing import (
@ -31,7 +31,7 @@ from pydantic import BaseModel, validator
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass
from pydantic.typing import NoArgAnyCallable, resolve_annotations
from pydantic.typing import NoArgAnyCallable
from pydantic.utils import Representation
from ulid import ULID
@ -172,8 +172,8 @@ class NegatedExpression:
@dataclasses.dataclass
class Expression:
op: Operators
left: ExpressionOrModelField
right: ExpressionOrModelField
left: Optional[ExpressionOrModelField]
right: Optional[ExpressionOrModelField]
parents: List[Tuple[str, 'RedisModel']]
def __invert__(self):
@ -208,26 +208,28 @@ class ExpressionProxy:
def __ne__(self, other: Any) -> Expression: # type: ignore[override]
return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)
def __lt__(self, other: Any) -> Expression: # type: ignore[override]
def __lt__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
def __le__(self, other: Any) -> Expression: # type: ignore[override]
def __le__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
def __gt__(self, other: Any) -> Expression: # type: ignore[override]
def __gt__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
def __ge__(self, other: Any) -> Expression: # type: ignore[override]
def __ge__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
def __mod__(self, other: Any) -> Expression: # type: ignore[override]
def __mod__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents)
def __lshift__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.IN, right=other, parents=self.parents)
def __getattr__(self, item):
if get_origin(self.field.outer_type_) == list:
embedded_cls = get_args(self.field.outer_type_)
if not embedded_cls:
# TODO: Is this even possible?
raise QuerySyntaxError("In order to query on a list field, you must define "
"the contents of the list with a type annotation, like: "
"orders: List[Order]. Docs: TODO")
@ -237,7 +239,7 @@ class ExpressionProxy:
attr = getattr(self.field.outer_type_, item)
if isinstance(attr, self.__class__):
new_parent = (self.field.name, self.field.outer_type_)
if not new_parent in attr.parents:
if new_parent not in attr.parents:
attr.parents.append(new_parent)
new_parents = list(set(self.parents) - set(attr.parents))
if new_parents:
@ -285,6 +287,21 @@ class FindQuery:
self._pagination: list[str] = []
self._model_cache: list[RedisModel] = []
def dict(self) -> dict[str, Any]:
return dict(
model=self.model,
offset=self.offset,
page_size=self.page_size,
limit=self.limit,
expressions=copy(self.expressions),
sort_fields=copy(self.sort_fields)
)
def copy(self, **kwargs):
original = self.dict()
original.update(**kwargs)
return FindQuery(**original)
@property
def pagination(self):
if self._pagination:
@ -299,15 +316,21 @@ class FindQuery:
if self.expressions:
self._expression = reduce(operator.and_, self.expressions)
else:
# TODO: Is there a better way to support the "give me all records" query?
# Also -- if we do it this way, we need different type annotations.
self._expression = Expression(left=None, right=None, op=Operators.ALL,
parents=[])
self._expression = Expression(left=None, right=None, op=Operators.ALL, parents=[])
return self._expression
@property
def query(self):
return self.resolve_redisearch_query(self.expression)
"""
Resolve and return the RediSearch query for this FindQuery.
NOTE: We cache the resolved query string after generating it. This should be OK
because all mutations of FindQuery through public APIs return a new FindQuery instance.
"""
if self._query:
return self._query
self._query = self.resolve_redisearch_query(self.expression)
return self._query
def validate_sort_fields(self, sort_fields):
for sort_field in sort_fields:
@ -322,10 +345,10 @@ class FindQuery:
return sort_fields
@staticmethod
def resolve_field_type(field: ModelField, operator: Operators) -> RediSearchFieldTypes:
def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes:
if getattr(field.field_info, 'primary_key', None) is True:
return RediSearchFieldTypes.TAG
elif operator is Operators.LIKE:
elif op is Operators.LIKE:
fts = getattr(field.field_info, 'full_text_search', None)
if fts is not True: # Could be PydanticUndefined
raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', "
@ -346,7 +369,7 @@ class FindQuery:
@staticmethod
def expand_tag_value(value):
if isinstance(str, value):
if isinstance(value, str):
return value
try:
expanded_value = "|".join([escaper.escape(v) for v in value])
@ -380,8 +403,6 @@ class FindQuery:
if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]"
elif op is Operators.NE:
# TODO: Is this enough or do we also need a clause for all values
# ([-inf +inf]) from which we then subtract the undesirable value?
result += f"-(@{field_name}:[{value} {value}])"
elif op is Operators.GT:
result += f"@{field_name}:[({value} +inf]"
@ -410,7 +431,7 @@ class FindQuery:
# The value contains the TAG field separator. We can work
# around this by breaking apart the values and unioning them
# with multiple field:{} queries.
values = filter(None, value.split(separator_char))
values: filter = filter(None, value.split(separator_char))
for value in values:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
@ -448,7 +469,29 @@ class FindQuery:
@classmethod
def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
"""Resolve an expression to a string RediSearch query."""
"""
Resolve an arbitrarily deep expression into a single RediSearch query string.
This method is complex. Note the following:
1. This method makes a recursive call to itself when it finds that
either the left or right operand contains another expression.
2. An expression might be in a "negated" form, which means that the user
gave us an expression like ~(Member.age == 30), or in other words,
"Members whose age is NOT 30." Thus, a negated expression is one in
which the meaning of an expression is inverted. If we find a negated
expression, we need to add the appropriate "NOT" syntax but can
otherwise use the resolved RediSearch query for the expression as-is.
3. The final resolution of an expression should be a left operand that's
a ModelField, an operator, and a right operand that's NOT a ModelField.
With an IN or NOT_IN operator, the right operand can be a sequence
type, but otherwise, sequence types are converted to strings.
TODO: When the operator is not IN or NOT_IN, detect a sequence type (other
than strings, which are allowed) and raise an exception.
"""
field_type = None
field_name = None
field_info = None
@ -462,7 +505,7 @@ class FindQuery:
if expression.op is Operators.ALL:
if encompassing_expression_is_negated:
# TODO: Is there a use case for this, perhaps for dynamic
# scoring purposes?
# scoring purposes with full-text search?
raise QueryNotSupportedError("You cannot negate a query for all results.")
return "*"
@ -500,7 +543,13 @@ class FindQuery:
result += f"({cls.resolve_redisearch_query(right)})"
else:
if isinstance(right, ModelField):
if not field_name:
raise QuerySyntaxError("Could not resolve field name. See docs: TODO")
elif not field_type:
raise QuerySyntaxError("Could not resolve field type. See docs: TODO")
elif not field_info:
raise QuerySyntaxError("Could not resolve field info. See docs: TODO")
elif isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else:
result += cls.resolve_value(field_name, field_type, field_info,
@ -540,11 +589,7 @@ class FindQuery:
while True:
# Make a query for each pass of the loop, with a new offset equal to the
# current offset plus `page_size`, until we stop getting results back.
query = FindQuery(expressions=query.expressions,
model=query.model,
offset=query.offset + query.page_size,
page_size=query.page_size,
limit=query.limit)
query = query.copy(offset=query.offset + query.page_size)
_results = query.execute(exhaust_results=False)
if not _results:
break
@ -552,8 +597,7 @@ class FindQuery:
return self._model_cache
def first(self):
query = FindQuery(expressions=self.expressions, model=self.model,
offset=0, limit=1, sort_fields=self.sort_fields)
query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
results = query.execute()
if not results:
raise NotFoundError()
@ -561,26 +605,14 @@ class FindQuery:
def all(self, batch_size=10):
if batch_size != self.page_size:
# TODO: There's probably a copy-with-change mechanism in Pydantic,
# or can we use one from dataclasses?
query = FindQuery(expressions=self.expressions,
model=self.model,
offset=self.offset,
page_size=batch_size,
limit=batch_size,
sort_fields=self.sort_fields)
query = self.copy(page_size=batch_size, limit=batch_size)
return query.execute()
return self.execute()
def sort_by(self, *fields: str):
if not fields:
return self
return FindQuery(expressions=self.expressions,
model=self.model,
offset=self.offset,
page_size=self.page_size,
limit=self.limit,
sort_fields=list(fields))
return self.copy(sort_fields=list(fields))
def update(self, **kwargs):
"""Update all matching records in this query."""
@ -621,11 +653,7 @@ class FindQuery:
if self._model_cache and len(self._model_cache) >= item:
return self._model_cache[item]
query = FindQuery(expressions=self.expressions,
model=self.model,
offset=item,
sort_fields=self.sort_fields,
limit=1)
query = self.copy(offset=item, limit=1)
return query.execute()[0]
@ -845,7 +873,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
pk: Optional[str] = Field(default=None, primary_key=True)
Meta = DefaultMeta
# TODO: Missing _meta here is causing IDE warnings.
class Config:
orm_mode = True
@ -899,7 +926,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
return cls._meta.database
@classmethod
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery: # TODO: How to type annotate this?
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
return FindQuery(expressions=expressions, model=cls)
@classmethod

View file

@ -360,9 +360,15 @@ def test_numeric_queries(members):
actual = Member.find(Member.age >= 100).all()
assert actual == [member3]
actual = Member.find(Member.age != 34).all()
assert actual == [member1, member3]
actual = Member.find(~(Member.age == 100)).all()
assert actual == [member1, member2]
actual = Member.find(Member.age > 30, Member.age < 40).all()
assert actual == [member1, member2]
def test_sorting(members):
member1, member2, member3 = members

View file

@ -12,7 +12,7 @@ from redis_developer.orm import (
JsonModel,
Field,
)
from redis_developer.orm.model import RedisModelError, QueryNotSupportedError, NotFoundError, embedded
from redis_developer.orm.model import QueryNotSupportedError, NotFoundError
r = redis.Redis()
today = datetime.date.today()
@ -240,6 +240,24 @@ def test_access_result_by_index_not_cached(members):
assert query[2] == member3
def test_in_query(members):
member1, member2, member3 = members
actual = Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).all()
assert actual == [member1, member2, member3]
@pytest.mark.skip("Not implemented yet")
def test_update_query(members):
member1, member2, member3 = members
Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).update(
first_name="Bobby"
)
actual = Member.find(
Member.pk << [member1.pk, member2.pk, member3.pk]).sort_by('age').all()
assert actual == [member1, member2, member3]
assert all([m.name == "Bobby" for m in actual])
def test_exact_match_queries(members):
member1, member2, member3 = members
@ -458,6 +476,12 @@ def test_numeric_queries(members):
actual = Member.find(~(Member.age == 100)).all()
assert actual == [member1, member2]
actual = Member.find(Member.age > 30, Member.age < 40).all()
assert actual == [member1, member2]
actual = Member.find(Member.age != 34).all()
assert actual == [member1, member3]
def test_sorting(members):
member1, member2, member3 = members