Add support for IN queries
This commit is contained in:
parent
bb08fb9eb5
commit
389a6ea878
3 changed files with 111 additions and 54 deletions
|
@ -4,7 +4,7 @@ import decimal
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
from copy import deepcopy
|
from copy import deepcopy, copy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import (
|
from typing import (
|
||||||
|
@ -31,7 +31,7 @@ from pydantic import BaseModel, validator
|
||||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||||
from pydantic.fields import ModelField, Undefined, UndefinedType
|
from pydantic.fields import ModelField, Undefined, UndefinedType
|
||||||
from pydantic.main import ModelMetaclass
|
from pydantic.main import ModelMetaclass
|
||||||
from pydantic.typing import NoArgAnyCallable, resolve_annotations
|
from pydantic.typing import NoArgAnyCallable
|
||||||
from pydantic.utils import Representation
|
from pydantic.utils import Representation
|
||||||
from ulid import ULID
|
from ulid import ULID
|
||||||
|
|
||||||
|
@ -172,8 +172,8 @@ class NegatedExpression:
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Expression:
|
class Expression:
|
||||||
op: Operators
|
op: Operators
|
||||||
left: ExpressionOrModelField
|
left: Optional[ExpressionOrModelField]
|
||||||
right: ExpressionOrModelField
|
right: Optional[ExpressionOrModelField]
|
||||||
parents: List[Tuple[str, 'RedisModel']]
|
parents: List[Tuple[str, 'RedisModel']]
|
||||||
|
|
||||||
def __invert__(self):
|
def __invert__(self):
|
||||||
|
@ -208,26 +208,28 @@ class ExpressionProxy:
|
||||||
def __ne__(self, other: Any) -> Expression: # type: ignore[override]
|
def __ne__(self, other: Any) -> Expression: # type: ignore[override]
|
||||||
return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)
|
return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)
|
||||||
|
|
||||||
def __lt__(self, other: Any) -> Expression: # type: ignore[override]
|
def __lt__(self, other: Any) -> Expression:
|
||||||
return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
|
return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
|
||||||
|
|
||||||
def __le__(self, other: Any) -> Expression: # type: ignore[override]
|
def __le__(self, other: Any) -> Expression:
|
||||||
return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
|
return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
|
||||||
|
|
||||||
def __gt__(self, other: Any) -> Expression: # type: ignore[override]
|
def __gt__(self, other: Any) -> Expression:
|
||||||
return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
|
return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
|
||||||
|
|
||||||
def __ge__(self, other: Any) -> Expression: # type: ignore[override]
|
def __ge__(self, other: Any) -> Expression:
|
||||||
return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
|
return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
|
||||||
|
|
||||||
def __mod__(self, other: Any) -> Expression: # type: ignore[override]
|
def __mod__(self, other: Any) -> Expression:
|
||||||
return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents)
|
return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents)
|
||||||
|
|
||||||
|
def __lshift__(self, other: Any) -> Expression:
|
||||||
|
return Expression(left=self.field, op=Operators.IN, right=other, parents=self.parents)
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
if get_origin(self.field.outer_type_) == list:
|
if get_origin(self.field.outer_type_) == list:
|
||||||
embedded_cls = get_args(self.field.outer_type_)
|
embedded_cls = get_args(self.field.outer_type_)
|
||||||
if not embedded_cls:
|
if not embedded_cls:
|
||||||
# TODO: Is this even possible?
|
|
||||||
raise QuerySyntaxError("In order to query on a list field, you must define "
|
raise QuerySyntaxError("In order to query on a list field, you must define "
|
||||||
"the contents of the list with a type annotation, like: "
|
"the contents of the list with a type annotation, like: "
|
||||||
"orders: List[Order]. Docs: TODO")
|
"orders: List[Order]. Docs: TODO")
|
||||||
|
@ -237,7 +239,7 @@ class ExpressionProxy:
|
||||||
attr = getattr(self.field.outer_type_, item)
|
attr = getattr(self.field.outer_type_, item)
|
||||||
if isinstance(attr, self.__class__):
|
if isinstance(attr, self.__class__):
|
||||||
new_parent = (self.field.name, self.field.outer_type_)
|
new_parent = (self.field.name, self.field.outer_type_)
|
||||||
if not new_parent in attr.parents:
|
if new_parent not in attr.parents:
|
||||||
attr.parents.append(new_parent)
|
attr.parents.append(new_parent)
|
||||||
new_parents = list(set(self.parents) - set(attr.parents))
|
new_parents = list(set(self.parents) - set(attr.parents))
|
||||||
if new_parents:
|
if new_parents:
|
||||||
|
@ -285,6 +287,21 @@ class FindQuery:
|
||||||
self._pagination: list[str] = []
|
self._pagination: list[str] = []
|
||||||
self._model_cache: list[RedisModel] = []
|
self._model_cache: list[RedisModel] = []
|
||||||
|
|
||||||
|
def dict(self) -> dict[str, Any]:
|
||||||
|
return dict(
|
||||||
|
model=self.model,
|
||||||
|
offset=self.offset,
|
||||||
|
page_size=self.page_size,
|
||||||
|
limit=self.limit,
|
||||||
|
expressions=copy(self.expressions),
|
||||||
|
sort_fields=copy(self.sort_fields)
|
||||||
|
)
|
||||||
|
|
||||||
|
def copy(self, **kwargs):
|
||||||
|
original = self.dict()
|
||||||
|
original.update(**kwargs)
|
||||||
|
return FindQuery(**original)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pagination(self):
|
def pagination(self):
|
||||||
if self._pagination:
|
if self._pagination:
|
||||||
|
@ -299,15 +316,21 @@ class FindQuery:
|
||||||
if self.expressions:
|
if self.expressions:
|
||||||
self._expression = reduce(operator.and_, self.expressions)
|
self._expression = reduce(operator.and_, self.expressions)
|
||||||
else:
|
else:
|
||||||
# TODO: Is there a better way to support the "give me all records" query?
|
self._expression = Expression(left=None, right=None, op=Operators.ALL, parents=[])
|
||||||
# Also -- if we do it this way, we need different type annotations.
|
|
||||||
self._expression = Expression(left=None, right=None, op=Operators.ALL,
|
|
||||||
parents=[])
|
|
||||||
return self._expression
|
return self._expression
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def query(self):
|
def query(self):
|
||||||
return self.resolve_redisearch_query(self.expression)
|
"""
|
||||||
|
Resolve and return the RediSearch query for this FindQuery.
|
||||||
|
|
||||||
|
NOTE: We cache the resolved query string after generating it. This should be OK
|
||||||
|
because all mutations of FindQuery through public APIs return a new FindQuery instance.
|
||||||
|
"""
|
||||||
|
if self._query:
|
||||||
|
return self._query
|
||||||
|
self._query = self.resolve_redisearch_query(self.expression)
|
||||||
|
return self._query
|
||||||
|
|
||||||
def validate_sort_fields(self, sort_fields):
|
def validate_sort_fields(self, sort_fields):
|
||||||
for sort_field in sort_fields:
|
for sort_field in sort_fields:
|
||||||
|
@ -322,10 +345,10 @@ class FindQuery:
|
||||||
return sort_fields
|
return sort_fields
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_field_type(field: ModelField, operator: Operators) -> RediSearchFieldTypes:
|
def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes:
|
||||||
if getattr(field.field_info, 'primary_key', None) is True:
|
if getattr(field.field_info, 'primary_key', None) is True:
|
||||||
return RediSearchFieldTypes.TAG
|
return RediSearchFieldTypes.TAG
|
||||||
elif operator is Operators.LIKE:
|
elif op is Operators.LIKE:
|
||||||
fts = getattr(field.field_info, 'full_text_search', None)
|
fts = getattr(field.field_info, 'full_text_search', None)
|
||||||
if fts is not True: # Could be PydanticUndefined
|
if fts is not True: # Could be PydanticUndefined
|
||||||
raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', "
|
raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', "
|
||||||
|
@ -346,7 +369,7 @@ class FindQuery:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def expand_tag_value(value):
|
def expand_tag_value(value):
|
||||||
if isinstance(str, value):
|
if isinstance(value, str):
|
||||||
return value
|
return value
|
||||||
try:
|
try:
|
||||||
expanded_value = "|".join([escaper.escape(v) for v in value])
|
expanded_value = "|".join([escaper.escape(v) for v in value])
|
||||||
|
@ -380,8 +403,6 @@ class FindQuery:
|
||||||
if op is Operators.EQ:
|
if op is Operators.EQ:
|
||||||
result += f"@{field_name}:[{value} {value}]"
|
result += f"@{field_name}:[{value} {value}]"
|
||||||
elif op is Operators.NE:
|
elif op is Operators.NE:
|
||||||
# TODO: Is this enough or do we also need a clause for all values
|
|
||||||
# ([-inf +inf]) from which we then subtract the undesirable value?
|
|
||||||
result += f"-(@{field_name}:[{value} {value}])"
|
result += f"-(@{field_name}:[{value} {value}])"
|
||||||
elif op is Operators.GT:
|
elif op is Operators.GT:
|
||||||
result += f"@{field_name}:[({value} +inf]"
|
result += f"@{field_name}:[({value} +inf]"
|
||||||
|
@ -410,7 +431,7 @@ class FindQuery:
|
||||||
# The value contains the TAG field separator. We can work
|
# The value contains the TAG field separator. We can work
|
||||||
# around this by breaking apart the values and unioning them
|
# around this by breaking apart the values and unioning them
|
||||||
# with multiple field:{} queries.
|
# with multiple field:{} queries.
|
||||||
values = filter(None, value.split(separator_char))
|
values: filter = filter(None, value.split(separator_char))
|
||||||
for value in values:
|
for value in values:
|
||||||
value = escaper.escape(value)
|
value = escaper.escape(value)
|
||||||
result += f"@{field_name}:{{{value}}}"
|
result += f"@{field_name}:{{{value}}}"
|
||||||
|
@ -448,7 +469,29 @@ class FindQuery:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
|
def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
|
||||||
"""Resolve an expression to a string RediSearch query."""
|
"""
|
||||||
|
Resolve an arbitrarily deep expression into a single RediSearch query string.
|
||||||
|
|
||||||
|
This method is complex. Note the following:
|
||||||
|
|
||||||
|
1. This method makes a recursive call to itself when it finds that
|
||||||
|
either the left or right operand contains another expression.
|
||||||
|
|
||||||
|
2. An expression might be in a "negated" form, which means that the user
|
||||||
|
gave us an expression like ~(Member.age == 30), or in other words,
|
||||||
|
"Members whose age is NOT 30." Thus, a negated expression is one in
|
||||||
|
which the meaning of an expression is inverted. If we find a negated
|
||||||
|
expression, we need to add the appropriate "NOT" syntax but can
|
||||||
|
otherwise use the resolved RediSearch query for the expression as-is.
|
||||||
|
|
||||||
|
3. The final resolution of an expression should be a left operand that's
|
||||||
|
a ModelField, an operator, and a right operand that's NOT a ModelField.
|
||||||
|
With an IN or NOT_IN operator, the right operand can be a sequence
|
||||||
|
type, but otherwise, sequence types are converted to strings.
|
||||||
|
|
||||||
|
TODO: When the operator is not IN or NOT_IN, detect a sequence type (other
|
||||||
|
than strings, which are allowed) and raise an exception.
|
||||||
|
"""
|
||||||
field_type = None
|
field_type = None
|
||||||
field_name = None
|
field_name = None
|
||||||
field_info = None
|
field_info = None
|
||||||
|
@ -462,7 +505,7 @@ class FindQuery:
|
||||||
if expression.op is Operators.ALL:
|
if expression.op is Operators.ALL:
|
||||||
if encompassing_expression_is_negated:
|
if encompassing_expression_is_negated:
|
||||||
# TODO: Is there a use case for this, perhaps for dynamic
|
# TODO: Is there a use case for this, perhaps for dynamic
|
||||||
# scoring purposes?
|
# scoring purposes with full-text search?
|
||||||
raise QueryNotSupportedError("You cannot negate a query for all results.")
|
raise QueryNotSupportedError("You cannot negate a query for all results.")
|
||||||
return "*"
|
return "*"
|
||||||
|
|
||||||
|
@ -500,7 +543,13 @@ class FindQuery:
|
||||||
|
|
||||||
result += f"({cls.resolve_redisearch_query(right)})"
|
result += f"({cls.resolve_redisearch_query(right)})"
|
||||||
else:
|
else:
|
||||||
if isinstance(right, ModelField):
|
if not field_name:
|
||||||
|
raise QuerySyntaxError("Could not resolve field name. See docs: TODO")
|
||||||
|
elif not field_type:
|
||||||
|
raise QuerySyntaxError("Could not resolve field type. See docs: TODO")
|
||||||
|
elif not field_info:
|
||||||
|
raise QuerySyntaxError("Could not resolve field info. See docs: TODO")
|
||||||
|
elif isinstance(right, ModelField):
|
||||||
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
|
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
|
||||||
else:
|
else:
|
||||||
result += cls.resolve_value(field_name, field_type, field_info,
|
result += cls.resolve_value(field_name, field_type, field_info,
|
||||||
|
@ -540,11 +589,7 @@ class FindQuery:
|
||||||
while True:
|
while True:
|
||||||
# Make a query for each pass of the loop, with a new offset equal to the
|
# Make a query for each pass of the loop, with a new offset equal to the
|
||||||
# current offset plus `page_size`, until we stop getting results back.
|
# current offset plus `page_size`, until we stop getting results back.
|
||||||
query = FindQuery(expressions=query.expressions,
|
query = query.copy(offset=query.offset + query.page_size)
|
||||||
model=query.model,
|
|
||||||
offset=query.offset + query.page_size,
|
|
||||||
page_size=query.page_size,
|
|
||||||
limit=query.limit)
|
|
||||||
_results = query.execute(exhaust_results=False)
|
_results = query.execute(exhaust_results=False)
|
||||||
if not _results:
|
if not _results:
|
||||||
break
|
break
|
||||||
|
@ -552,8 +597,7 @@ class FindQuery:
|
||||||
return self._model_cache
|
return self._model_cache
|
||||||
|
|
||||||
def first(self):
|
def first(self):
|
||||||
query = FindQuery(expressions=self.expressions, model=self.model,
|
query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
|
||||||
offset=0, limit=1, sort_fields=self.sort_fields)
|
|
||||||
results = query.execute()
|
results = query.execute()
|
||||||
if not results:
|
if not results:
|
||||||
raise NotFoundError()
|
raise NotFoundError()
|
||||||
|
@ -561,26 +605,14 @@ class FindQuery:
|
||||||
|
|
||||||
def all(self, batch_size=10):
|
def all(self, batch_size=10):
|
||||||
if batch_size != self.page_size:
|
if batch_size != self.page_size:
|
||||||
# TODO: There's probably a copy-with-change mechanism in Pydantic,
|
query = self.copy(page_size=batch_size, limit=batch_size)
|
||||||
# or can we use one from dataclasses?
|
|
||||||
query = FindQuery(expressions=self.expressions,
|
|
||||||
model=self.model,
|
|
||||||
offset=self.offset,
|
|
||||||
page_size=batch_size,
|
|
||||||
limit=batch_size,
|
|
||||||
sort_fields=self.sort_fields)
|
|
||||||
return query.execute()
|
return query.execute()
|
||||||
return self.execute()
|
return self.execute()
|
||||||
|
|
||||||
def sort_by(self, *fields: str):
|
def sort_by(self, *fields: str):
|
||||||
if not fields:
|
if not fields:
|
||||||
return self
|
return self
|
||||||
return FindQuery(expressions=self.expressions,
|
return self.copy(sort_fields=list(fields))
|
||||||
model=self.model,
|
|
||||||
offset=self.offset,
|
|
||||||
page_size=self.page_size,
|
|
||||||
limit=self.limit,
|
|
||||||
sort_fields=list(fields))
|
|
||||||
|
|
||||||
def update(self, **kwargs):
|
def update(self, **kwargs):
|
||||||
"""Update all matching records in this query."""
|
"""Update all matching records in this query."""
|
||||||
|
@ -621,11 +653,7 @@ class FindQuery:
|
||||||
if self._model_cache and len(self._model_cache) >= item:
|
if self._model_cache and len(self._model_cache) >= item:
|
||||||
return self._model_cache[item]
|
return self._model_cache[item]
|
||||||
|
|
||||||
query = FindQuery(expressions=self.expressions,
|
query = self.copy(offset=item, limit=1)
|
||||||
model=self.model,
|
|
||||||
offset=item,
|
|
||||||
sort_fields=self.sort_fields,
|
|
||||||
limit=1)
|
|
||||||
|
|
||||||
return query.execute()[0]
|
return query.execute()[0]
|
||||||
|
|
||||||
|
@ -845,7 +873,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
pk: Optional[str] = Field(default=None, primary_key=True)
|
pk: Optional[str] = Field(default=None, primary_key=True)
|
||||||
|
|
||||||
Meta = DefaultMeta
|
Meta = DefaultMeta
|
||||||
# TODO: Missing _meta here is causing IDE warnings.
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
@ -899,7 +926,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
return cls._meta.database
|
return cls._meta.database
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery: # TODO: How to type annotate this?
|
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
|
||||||
return FindQuery(expressions=expressions, model=cls)
|
return FindQuery(expressions=expressions, model=cls)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -360,9 +360,15 @@ def test_numeric_queries(members):
|
||||||
actual = Member.find(Member.age >= 100).all()
|
actual = Member.find(Member.age >= 100).all()
|
||||||
assert actual == [member3]
|
assert actual == [member3]
|
||||||
|
|
||||||
|
actual = Member.find(Member.age != 34).all()
|
||||||
|
assert actual == [member1, member3]
|
||||||
|
|
||||||
actual = Member.find(~(Member.age == 100)).all()
|
actual = Member.find(~(Member.age == 100)).all()
|
||||||
assert actual == [member1, member2]
|
assert actual == [member1, member2]
|
||||||
|
|
||||||
|
actual = Member.find(Member.age > 30, Member.age < 40).all()
|
||||||
|
assert actual == [member1, member2]
|
||||||
|
|
||||||
|
|
||||||
def test_sorting(members):
|
def test_sorting(members):
|
||||||
member1, member2, member3 = members
|
member1, member2, member3 = members
|
||||||
|
|
|
@ -12,7 +12,7 @@ from redis_developer.orm import (
|
||||||
JsonModel,
|
JsonModel,
|
||||||
Field,
|
Field,
|
||||||
)
|
)
|
||||||
from redis_developer.orm.model import RedisModelError, QueryNotSupportedError, NotFoundError, embedded
|
from redis_developer.orm.model import QueryNotSupportedError, NotFoundError
|
||||||
|
|
||||||
r = redis.Redis()
|
r = redis.Redis()
|
||||||
today = datetime.date.today()
|
today = datetime.date.today()
|
||||||
|
@ -240,6 +240,24 @@ def test_access_result_by_index_not_cached(members):
|
||||||
assert query[2] == member3
|
assert query[2] == member3
|
||||||
|
|
||||||
|
|
||||||
|
def test_in_query(members):
|
||||||
|
member1, member2, member3 = members
|
||||||
|
actual = Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).all()
|
||||||
|
assert actual == [member1, member2, member3]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("Not implemented yet")
|
||||||
|
def test_update_query(members):
|
||||||
|
member1, member2, member3 = members
|
||||||
|
Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).update(
|
||||||
|
first_name="Bobby"
|
||||||
|
)
|
||||||
|
actual = Member.find(
|
||||||
|
Member.pk << [member1.pk, member2.pk, member3.pk]).sort_by('age').all()
|
||||||
|
assert actual == [member1, member2, member3]
|
||||||
|
assert all([m.name == "Bobby" for m in actual])
|
||||||
|
|
||||||
|
|
||||||
def test_exact_match_queries(members):
|
def test_exact_match_queries(members):
|
||||||
member1, member2, member3 = members
|
member1, member2, member3 = members
|
||||||
|
|
||||||
|
@ -458,6 +476,12 @@ def test_numeric_queries(members):
|
||||||
actual = Member.find(~(Member.age == 100)).all()
|
actual = Member.find(~(Member.age == 100)).all()
|
||||||
assert actual == [member1, member2]
|
assert actual == [member1, member2]
|
||||||
|
|
||||||
|
actual = Member.find(Member.age > 30, Member.age < 40).all()
|
||||||
|
assert actual == [member1, member2]
|
||||||
|
|
||||||
|
actual = Member.find(Member.age != 34).all()
|
||||||
|
assert actual == [member1, member3]
|
||||||
|
|
||||||
|
|
||||||
def test_sorting(members):
|
def test_sorting(members):
|
||||||
member1, member2, member3 = members
|
member1, member2, member3 = members
|
||||||
|
|
Loading…
Reference in a new issue