Add support for IN queries
This commit is contained in:
parent
bb08fb9eb5
commit
389a6ea878
3 changed files with 111 additions and 54 deletions
|
@ -4,7 +4,7 @@ import decimal
|
|||
import json
|
||||
import logging
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from copy import deepcopy, copy
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import (
|
||||
|
@ -31,7 +31,7 @@ from pydantic import BaseModel, validator
|
|||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||
from pydantic.fields import ModelField, Undefined, UndefinedType
|
||||
from pydantic.main import ModelMetaclass
|
||||
from pydantic.typing import NoArgAnyCallable, resolve_annotations
|
||||
from pydantic.typing import NoArgAnyCallable
|
||||
from pydantic.utils import Representation
|
||||
from ulid import ULID
|
||||
|
||||
|
@ -172,8 +172,8 @@ class NegatedExpression:
|
|||
@dataclasses.dataclass
|
||||
class Expression:
|
||||
op: Operators
|
||||
left: ExpressionOrModelField
|
||||
right: ExpressionOrModelField
|
||||
left: Optional[ExpressionOrModelField]
|
||||
right: Optional[ExpressionOrModelField]
|
||||
parents: List[Tuple[str, 'RedisModel']]
|
||||
|
||||
def __invert__(self):
|
||||
|
@ -208,26 +208,28 @@ class ExpressionProxy:
|
|||
def __ne__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)
|
||||
|
||||
def __lt__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
def __lt__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
|
||||
|
||||
def __le__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
def __le__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
|
||||
|
||||
def __gt__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
def __gt__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
|
||||
|
||||
def __ge__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
def __ge__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
|
||||
|
||||
def __mod__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
def __mod__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents)
|
||||
|
||||
def __lshift__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.IN, right=other, parents=self.parents)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if get_origin(self.field.outer_type_) == list:
|
||||
embedded_cls = get_args(self.field.outer_type_)
|
||||
if not embedded_cls:
|
||||
# TODO: Is this even possible?
|
||||
raise QuerySyntaxError("In order to query on a list field, you must define "
|
||||
"the contents of the list with a type annotation, like: "
|
||||
"orders: List[Order]. Docs: TODO")
|
||||
|
@ -237,7 +239,7 @@ class ExpressionProxy:
|
|||
attr = getattr(self.field.outer_type_, item)
|
||||
if isinstance(attr, self.__class__):
|
||||
new_parent = (self.field.name, self.field.outer_type_)
|
||||
if not new_parent in attr.parents:
|
||||
if new_parent not in attr.parents:
|
||||
attr.parents.append(new_parent)
|
||||
new_parents = list(set(self.parents) - set(attr.parents))
|
||||
if new_parents:
|
||||
|
@ -285,6 +287,21 @@ class FindQuery:
|
|||
self._pagination: list[str] = []
|
||||
self._model_cache: list[RedisModel] = []
|
||||
|
||||
def dict(self) -> dict[str, Any]:
|
||||
return dict(
|
||||
model=self.model,
|
||||
offset=self.offset,
|
||||
page_size=self.page_size,
|
||||
limit=self.limit,
|
||||
expressions=copy(self.expressions),
|
||||
sort_fields=copy(self.sort_fields)
|
||||
)
|
||||
|
||||
def copy(self, **kwargs):
|
||||
original = self.dict()
|
||||
original.update(**kwargs)
|
||||
return FindQuery(**original)
|
||||
|
||||
@property
|
||||
def pagination(self):
|
||||
if self._pagination:
|
||||
|
@ -299,15 +316,21 @@ class FindQuery:
|
|||
if self.expressions:
|
||||
self._expression = reduce(operator.and_, self.expressions)
|
||||
else:
|
||||
# TODO: Is there a better way to support the "give me all records" query?
|
||||
# Also -- if we do it this way, we need different type annotations.
|
||||
self._expression = Expression(left=None, right=None, op=Operators.ALL,
|
||||
parents=[])
|
||||
self._expression = Expression(left=None, right=None, op=Operators.ALL, parents=[])
|
||||
return self._expression
|
||||
|
||||
@property
|
||||
def query(self):
|
||||
return self.resolve_redisearch_query(self.expression)
|
||||
"""
|
||||
Resolve and return the RediSearch query for this FindQuery.
|
||||
|
||||
NOTE: We cache the resolved query string after generating it. This should be OK
|
||||
because all mutations of FindQuery through public APIs return a new FindQuery instance.
|
||||
"""
|
||||
if self._query:
|
||||
return self._query
|
||||
self._query = self.resolve_redisearch_query(self.expression)
|
||||
return self._query
|
||||
|
||||
def validate_sort_fields(self, sort_fields):
|
||||
for sort_field in sort_fields:
|
||||
|
@ -322,10 +345,10 @@ class FindQuery:
|
|||
return sort_fields
|
||||
|
||||
@staticmethod
|
||||
def resolve_field_type(field: ModelField, operator: Operators) -> RediSearchFieldTypes:
|
||||
def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes:
|
||||
if getattr(field.field_info, 'primary_key', None) is True:
|
||||
return RediSearchFieldTypes.TAG
|
||||
elif operator is Operators.LIKE:
|
||||
elif op is Operators.LIKE:
|
||||
fts = getattr(field.field_info, 'full_text_search', None)
|
||||
if fts is not True: # Could be PydanticUndefined
|
||||
raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', "
|
||||
|
@ -346,7 +369,7 @@ class FindQuery:
|
|||
|
||||
@staticmethod
|
||||
def expand_tag_value(value):
|
||||
if isinstance(str, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
expanded_value = "|".join([escaper.escape(v) for v in value])
|
||||
|
@ -380,8 +403,6 @@ class FindQuery:
|
|||
if op is Operators.EQ:
|
||||
result += f"@{field_name}:[{value} {value}]"
|
||||
elif op is Operators.NE:
|
||||
# TODO: Is this enough or do we also need a clause for all values
|
||||
# ([-inf +inf]) from which we then subtract the undesirable value?
|
||||
result += f"-(@{field_name}:[{value} {value}])"
|
||||
elif op is Operators.GT:
|
||||
result += f"@{field_name}:[({value} +inf]"
|
||||
|
@ -410,7 +431,7 @@ class FindQuery:
|
|||
# The value contains the TAG field separator. We can work
|
||||
# around this by breaking apart the values and unioning them
|
||||
# with multiple field:{} queries.
|
||||
values = filter(None, value.split(separator_char))
|
||||
values: filter = filter(None, value.split(separator_char))
|
||||
for value in values:
|
||||
value = escaper.escape(value)
|
||||
result += f"@{field_name}:{{{value}}}"
|
||||
|
@ -448,7 +469,29 @@ class FindQuery:
|
|||
|
||||
@classmethod
|
||||
def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
|
||||
"""Resolve an expression to a string RediSearch query."""
|
||||
"""
|
||||
Resolve an arbitrarily deep expression into a single RediSearch query string.
|
||||
|
||||
This method is complex. Note the following:
|
||||
|
||||
1. This method makes a recursive call to itself when it finds that
|
||||
either the left or right operand contains another expression.
|
||||
|
||||
2. An expression might be in a "negated" form, which means that the user
|
||||
gave us an expression like ~(Member.age == 30), or in other words,
|
||||
"Members whose age is NOT 30." Thus, a negated expression is one in
|
||||
which the meaning of an expression is inverted. If we find a negated
|
||||
expression, we need to add the appropriate "NOT" syntax but can
|
||||
otherwise use the resolved RediSearch query for the expression as-is.
|
||||
|
||||
3. The final resolution of an expression should be a left operand that's
|
||||
a ModelField, an operator, and a right operand that's NOT a ModelField.
|
||||
With an IN or NOT_IN operator, the right operand can be a sequence
|
||||
type, but otherwise, sequence types are converted to strings.
|
||||
|
||||
TODO: When the operator is not IN or NOT_IN, detect a sequence type (other
|
||||
than strings, which are allowed) and raise an exception.
|
||||
"""
|
||||
field_type = None
|
||||
field_name = None
|
||||
field_info = None
|
||||
|
@ -462,7 +505,7 @@ class FindQuery:
|
|||
if expression.op is Operators.ALL:
|
||||
if encompassing_expression_is_negated:
|
||||
# TODO: Is there a use case for this, perhaps for dynamic
|
||||
# scoring purposes?
|
||||
# scoring purposes with full-text search?
|
||||
raise QueryNotSupportedError("You cannot negate a query for all results.")
|
||||
return "*"
|
||||
|
||||
|
@ -500,7 +543,13 @@ class FindQuery:
|
|||
|
||||
result += f"({cls.resolve_redisearch_query(right)})"
|
||||
else:
|
||||
if isinstance(right, ModelField):
|
||||
if not field_name:
|
||||
raise QuerySyntaxError("Could not resolve field name. See docs: TODO")
|
||||
elif not field_type:
|
||||
raise QuerySyntaxError("Could not resolve field type. See docs: TODO")
|
||||
elif not field_info:
|
||||
raise QuerySyntaxError("Could not resolve field info. See docs: TODO")
|
||||
elif isinstance(right, ModelField):
|
||||
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
|
||||
else:
|
||||
result += cls.resolve_value(field_name, field_type, field_info,
|
||||
|
@ -540,11 +589,7 @@ class FindQuery:
|
|||
while True:
|
||||
# Make a query for each pass of the loop, with a new offset equal to the
|
||||
# current offset plus `page_size`, until we stop getting results back.
|
||||
query = FindQuery(expressions=query.expressions,
|
||||
model=query.model,
|
||||
offset=query.offset + query.page_size,
|
||||
page_size=query.page_size,
|
||||
limit=query.limit)
|
||||
query = query.copy(offset=query.offset + query.page_size)
|
||||
_results = query.execute(exhaust_results=False)
|
||||
if not _results:
|
||||
break
|
||||
|
@ -552,8 +597,7 @@ class FindQuery:
|
|||
return self._model_cache
|
||||
|
||||
def first(self):
|
||||
query = FindQuery(expressions=self.expressions, model=self.model,
|
||||
offset=0, limit=1, sort_fields=self.sort_fields)
|
||||
query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
|
||||
results = query.execute()
|
||||
if not results:
|
||||
raise NotFoundError()
|
||||
|
@ -561,26 +605,14 @@ class FindQuery:
|
|||
|
||||
def all(self, batch_size=10):
|
||||
if batch_size != self.page_size:
|
||||
# TODO: There's probably a copy-with-change mechanism in Pydantic,
|
||||
# or can we use one from dataclasses?
|
||||
query = FindQuery(expressions=self.expressions,
|
||||
model=self.model,
|
||||
offset=self.offset,
|
||||
page_size=batch_size,
|
||||
limit=batch_size,
|
||||
sort_fields=self.sort_fields)
|
||||
query = self.copy(page_size=batch_size, limit=batch_size)
|
||||
return query.execute()
|
||||
return self.execute()
|
||||
|
||||
def sort_by(self, *fields: str):
|
||||
if not fields:
|
||||
return self
|
||||
return FindQuery(expressions=self.expressions,
|
||||
model=self.model,
|
||||
offset=self.offset,
|
||||
page_size=self.page_size,
|
||||
limit=self.limit,
|
||||
sort_fields=list(fields))
|
||||
return self.copy(sort_fields=list(fields))
|
||||
|
||||
def update(self, **kwargs):
|
||||
"""Update all matching records in this query."""
|
||||
|
@ -621,11 +653,7 @@ class FindQuery:
|
|||
if self._model_cache and len(self._model_cache) >= item:
|
||||
return self._model_cache[item]
|
||||
|
||||
query = FindQuery(expressions=self.expressions,
|
||||
model=self.model,
|
||||
offset=item,
|
||||
sort_fields=self.sort_fields,
|
||||
limit=1)
|
||||
query = self.copy(offset=item, limit=1)
|
||||
|
||||
return query.execute()[0]
|
||||
|
||||
|
@ -845,7 +873,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
pk: Optional[str] = Field(default=None, primary_key=True)
|
||||
|
||||
Meta = DefaultMeta
|
||||
# TODO: Missing _meta here is causing IDE warnings.
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
@ -899,7 +926,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
return cls._meta.database
|
||||
|
||||
@classmethod
|
||||
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery: # TODO: How to type annotate this?
|
||||
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
|
||||
return FindQuery(expressions=expressions, model=cls)
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -360,9 +360,15 @@ def test_numeric_queries(members):
|
|||
actual = Member.find(Member.age >= 100).all()
|
||||
assert actual == [member3]
|
||||
|
||||
actual = Member.find(Member.age != 34).all()
|
||||
assert actual == [member1, member3]
|
||||
|
||||
actual = Member.find(~(Member.age == 100)).all()
|
||||
assert actual == [member1, member2]
|
||||
|
||||
actual = Member.find(Member.age > 30, Member.age < 40).all()
|
||||
assert actual == [member1, member2]
|
||||
|
||||
|
||||
def test_sorting(members):
|
||||
member1, member2, member3 = members
|
||||
|
|
|
@ -12,7 +12,7 @@ from redis_developer.orm import (
|
|||
JsonModel,
|
||||
Field,
|
||||
)
|
||||
from redis_developer.orm.model import RedisModelError, QueryNotSupportedError, NotFoundError, embedded
|
||||
from redis_developer.orm.model import QueryNotSupportedError, NotFoundError
|
||||
|
||||
r = redis.Redis()
|
||||
today = datetime.date.today()
|
||||
|
@ -240,6 +240,24 @@ def test_access_result_by_index_not_cached(members):
|
|||
assert query[2] == member3
|
||||
|
||||
|
||||
def test_in_query(members):
|
||||
member1, member2, member3 = members
|
||||
actual = Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).all()
|
||||
assert actual == [member1, member2, member3]
|
||||
|
||||
|
||||
@pytest.mark.skip("Not implemented yet")
|
||||
def test_update_query(members):
|
||||
member1, member2, member3 = members
|
||||
Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).update(
|
||||
first_name="Bobby"
|
||||
)
|
||||
actual = Member.find(
|
||||
Member.pk << [member1.pk, member2.pk, member3.pk]).sort_by('age').all()
|
||||
assert actual == [member1, member2, member3]
|
||||
assert all([m.name == "Bobby" for m in actual])
|
||||
|
||||
|
||||
def test_exact_match_queries(members):
|
||||
member1, member2, member3 = members
|
||||
|
||||
|
@ -458,6 +476,12 @@ def test_numeric_queries(members):
|
|||
actual = Member.find(~(Member.age == 100)).all()
|
||||
assert actual == [member1, member2]
|
||||
|
||||
actual = Member.find(Member.age > 30, Member.age < 40).all()
|
||||
assert actual == [member1, member2]
|
||||
|
||||
actual = Member.find(Member.age != 34).all()
|
||||
assert actual == [member1, member3]
|
||||
|
||||
|
||||
def test_sorting(members):
|
||||
member1, member2, member3 = members
|
||||
|
|
Loading…
Reference in a new issue