|
|
|
@ -4,7 +4,7 @@ import decimal
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import operator
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
from copy import deepcopy, copy
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from functools import reduce
|
|
|
|
|
from typing import (
|
|
|
|
@ -31,7 +31,7 @@ from pydantic import BaseModel, validator
|
|
|
|
|
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
|
|
|
|
from pydantic.fields import ModelField, Undefined, UndefinedType
|
|
|
|
|
from pydantic.main import ModelMetaclass
|
|
|
|
|
from pydantic.typing import NoArgAnyCallable, resolve_annotations
|
|
|
|
|
from pydantic.typing import NoArgAnyCallable
|
|
|
|
|
from pydantic.utils import Representation
|
|
|
|
|
from ulid import ULID
|
|
|
|
|
|
|
|
|
@ -172,8 +172,8 @@ class NegatedExpression:
|
|
|
|
|
@dataclasses.dataclass
|
|
|
|
|
class Expression:
|
|
|
|
|
op: Operators
|
|
|
|
|
left: ExpressionOrModelField
|
|
|
|
|
right: ExpressionOrModelField
|
|
|
|
|
left: Optional[ExpressionOrModelField]
|
|
|
|
|
right: Optional[ExpressionOrModelField]
|
|
|
|
|
parents: List[Tuple[str, 'RedisModel']]
|
|
|
|
|
|
|
|
|
|
def __invert__(self):
|
|
|
|
@ -208,26 +208,28 @@ class ExpressionProxy:
|
|
|
|
|
def __ne__(self, other: Any) -> Expression:  # type: ignore[override]
    """Return an ``Operators.NE`` expression comparing this proxy's field to ``other``.

    The proxy's ``parents`` are passed through to the expression unchanged.
    """
    return Expression(
        left=self.field,
        op=Operators.NE,
        right=other,
        parents=self.parents,
    )
|
|
|
|
|
|
|
|
|
|
def __lt__(self, other: Any) -> Expression:
    """Return an ``Operators.LT`` expression comparing this proxy's field to ``other``.

    The proxy's ``parents`` are passed through to the expression unchanged.
    """
    return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
|
|
|
|
|
|
|
|
|
|
def __le__(self, other: Any) -> Expression:
    """Return an ``Operators.LE`` expression comparing this proxy's field to ``other``.

    The proxy's ``parents`` are passed through to the expression unchanged.
    """
    return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
|
|
|
|
|
|
|
|
|
|
def __gt__(self, other: Any) -> Expression:
    """Return an ``Operators.GT`` expression comparing this proxy's field to ``other``.

    The proxy's ``parents`` are passed through to the expression unchanged.
    """
    return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
|
|
|
|
|
|
|
|
|
|
def __ge__(self, other: Any) -> Expression:
    """Return an ``Operators.GE`` expression comparing this proxy's field to ``other``.

    The proxy's ``parents`` are passed through to the expression unchanged.
    """
    return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
|
|
|
|
|
|
|
|
|
|
def __mod__(self, other: Any) -> Expression:
    """Return an ``Operators.LIKE`` expression for ``field % other``.

    The proxy's ``parents`` are passed through to the expression unchanged.
    """
    return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents)
|
|
|
|
|
|
|
|
|
|
def __lshift__(self, other: Any) -> Expression:
    """Return an ``Operators.IN`` expression for ``field << other``.

    The proxy's ``parents`` are passed through to the expression unchanged.
    """
    return Expression(
        left=self.field,
        op=Operators.IN,
        right=other,
        parents=self.parents,
    )
|
|
|
|
|
|
|
|
|
|
def __getattr__(self, item):
|
|
|
|
|
if get_origin(self.field.outer_type_) == list:
|
|
|
|
|
embedded_cls = get_args(self.field.outer_type_)
|
|
|
|
|
if not embedded_cls:
|
|
|
|
|
# TODO: Is this even possible?
|
|
|
|
|
raise QuerySyntaxError("In order to query on a list field, you must define "
|
|
|
|
|
"the contents of the list with a type annotation, like: "
|
|
|
|
|
"orders: List[Order]. Docs: TODO")
|
|
|
|
@ -237,7 +239,7 @@ class ExpressionProxy:
|
|
|
|
|
attr = getattr(self.field.outer_type_, item)
|
|
|
|
|
if isinstance(attr, self.__class__):
|
|
|
|
|
new_parent = (self.field.name, self.field.outer_type_)
|
|
|
|
|
if not new_parent in attr.parents:
|
|
|
|
|
if new_parent not in attr.parents:
|
|
|
|
|
attr.parents.append(new_parent)
|
|
|
|
|
new_parents = list(set(self.parents) - set(attr.parents))
|
|
|
|
|
if new_parents:
|
|
|
|
@ -285,6 +287,21 @@ class FindQuery:
|
|
|
|
|
self._pagination: list[str] = []
|
|
|
|
|
self._model_cache: list[RedisModel] = []
|
|
|
|
|
|
|
|
|
|
def dict(self) -> dict[str, Any]:
    """Return this query's settings as a plain mapping.

    ``expressions`` and ``sort_fields`` are shallow-copied so the returned
    containers are distinct objects from the ones held on this instance;
    the remaining values are passed through as-is.
    """
    return {
        "model": self.model,
        "offset": self.offset,
        "page_size": self.page_size,
        "limit": self.limit,
        "expressions": copy(self.expressions),
        "sort_fields": copy(self.sort_fields),
    }
|
|
|
|
|
|
|
|
|
|
def copy(self, **kwargs):
    """Return a new ``FindQuery`` built from this query's settings.

    The settings come from ``self.dict()``; any keyword arguments given
    here override the corresponding entries before the new query is built.
    """
    # Later keys win, so caller-supplied overrides replace the originals.
    settings = {**self.dict(), **kwargs}
    return FindQuery(**settings)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def pagination(self):
|
|
|
|
|
if self._pagination:
|
|
|
|
@ -299,15 +316,21 @@ class FindQuery:
|
|
|
|
|
if self.expressions:
|
|
|
|
|
self._expression = reduce(operator.and_, self.expressions)
|
|
|
|
|
else:
|
|
|
|
|
# TODO: Is there a better way to support the "give me all records" query?
|
|
|
|
|
# Also -- if we do it this way, we need different type annotations.
|
|
|
|
|
self._expression = Expression(left=None, right=None, op=Operators.ALL,
|
|
|
|
|
parents=[])
|
|
|
|
|
self._expression = Expression(left=None, right=None, op=Operators.ALL, parents=[])
|
|
|
|
|
return self._expression
|
|
|
|
|
|
|
|
|
|
@property
def query(self):
    """
    Resolve and return the RediSearch query for this FindQuery.

    NOTE: We cache the resolved query string after generating it. This should be OK
    because all mutations of FindQuery through public APIs return a new FindQuery instance.
    """
    # A truthy cached value means the query string was already resolved.
    if self._query:
        return self._query
    self._query = self.resolve_redisearch_query(self.expression)
    return self._query
|
|
|
|
|
|
|
|
|
|
def validate_sort_fields(self, sort_fields):
|
|
|
|
|
for sort_field in sort_fields:
|
|
|
|
@ -322,10 +345,10 @@ class FindQuery:
|
|
|
|
|
return sort_fields
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def resolve_field_type(field: ModelField, operator: Operators) -> RediSearchFieldTypes:
|
|
|
|
|
def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes:
|
|
|
|
|
if getattr(field.field_info, 'primary_key', None) is True:
|
|
|
|
|
return RediSearchFieldTypes.TAG
|
|
|
|
|
elif operator is Operators.LIKE:
|
|
|
|
|
elif op is Operators.LIKE:
|
|
|
|
|
fts = getattr(field.field_info, 'full_text_search', None)
|
|
|
|
|
if fts is not True: # Could be PydanticUndefined
|
|
|
|
|
raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', "
|
|
|
|
@ -346,7 +369,7 @@ class FindQuery:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def expand_tag_value(value):
|
|
|
|
|
if isinstance(str, value):
|
|
|
|
|
if isinstance(value, str):
|
|
|
|
|
return value
|
|
|
|
|
try:
|
|
|
|
|
expanded_value = "|".join([escaper.escape(v) for v in value])
|
|
|
|
@ -380,8 +403,6 @@ class FindQuery:
|
|
|
|
|
if op is Operators.EQ:
|
|
|
|
|
result += f"@{field_name}:[{value} {value}]"
|
|
|
|
|
elif op is Operators.NE:
|
|
|
|
|
# TODO: Is this enough or do we also need a clause for all values
|
|
|
|
|
# ([-inf +inf]) from which we then subtract the undesirable value?
|
|
|
|
|
result += f"-(@{field_name}:[{value} {value}])"
|
|
|
|
|
elif op is Operators.GT:
|
|
|
|
|
result += f"@{field_name}:[({value} +inf]"
|
|
|
|
@ -410,7 +431,7 @@ class FindQuery:
|
|
|
|
|
# The value contains the TAG field separator. We can work
|
|
|
|
|
# around this by breaking apart the values and unioning them
|
|
|
|
|
# with multiple field:{} queries.
|
|
|
|
|
values = filter(None, value.split(separator_char))
|
|
|
|
|
values: filter = filter(None, value.split(separator_char))
|
|
|
|
|
for value in values:
|
|
|
|
|
value = escaper.escape(value)
|
|
|
|
|
result += f"@{field_name}:{{{value}}}"
|
|
|
|
@ -448,7 +469,29 @@ class FindQuery:
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
|
|
|
|
|
"""Resolve an expression to a string RediSearch query."""
|
|
|
|
|
"""
|
|
|
|
|
Resolve an arbitrarily deep expression into a single RediSearch query string.
|
|
|
|
|
|
|
|
|
|
This method is complex. Note the following:
|
|
|
|
|
|
|
|
|
|
1. This method makes a recursive call to itself when it finds that
|
|
|
|
|
either the left or right operand contains another expression.
|
|
|
|
|
|
|
|
|
|
2. An expression might be in a "negated" form, which means that the user
|
|
|
|
|
gave us an expression like ~(Member.age == 30), or in other words,
|
|
|
|
|
"Members whose age is NOT 30." Thus, a negated expression is one in
|
|
|
|
|
which the meaning of an expression is inverted. If we find a negated
|
|
|
|
|
expression, we need to add the appropriate "NOT" syntax but can
|
|
|
|
|
otherwise use the resolved RediSearch query for the expression as-is.
|
|
|
|
|
|
|
|
|
|
3. The final resolution of an expression should be a left operand that's
|
|
|
|
|
a ModelField, an operator, and a right operand that's NOT a ModelField.
|
|
|
|
|
With an IN or NOT_IN operator, the right operand can be a sequence
|
|
|
|
|
type, but otherwise, sequence types are converted to strings.
|
|
|
|
|
|
|
|
|
|
TODO: When the operator is not IN or NOT_IN, detect a sequence type (other
|
|
|
|
|
than strings, which are allowed) and raise an exception.
|
|
|
|
|
"""
|
|
|
|
|
field_type = None
|
|
|
|
|
field_name = None
|
|
|
|
|
field_info = None
|
|
|
|
@ -462,7 +505,7 @@ class FindQuery:
|
|
|
|
|
if expression.op is Operators.ALL:
|
|
|
|
|
if encompassing_expression_is_negated:
|
|
|
|
|
# TODO: Is there a use case for this, perhaps for dynamic
|
|
|
|
|
# scoring purposes?
|
|
|
|
|
# scoring purposes with full-text search?
|
|
|
|
|
raise QueryNotSupportedError("You cannot negate a query for all results.")
|
|
|
|
|
return "*"
|
|
|
|
|
|
|
|
|
@ -500,7 +543,13 @@ class FindQuery:
|
|
|
|
|
|
|
|
|
|
result += f"({cls.resolve_redisearch_query(right)})"
|
|
|
|
|
else:
|
|
|
|
|
if isinstance(right, ModelField):
|
|
|
|
|
if not field_name:
|
|
|
|
|
raise QuerySyntaxError("Could not resolve field name. See docs: TODO")
|
|
|
|
|
elif not field_type:
|
|
|
|
|
raise QuerySyntaxError("Could not resolve field type. See docs: TODO")
|
|
|
|
|
elif not field_info:
|
|
|
|
|
raise QuerySyntaxError("Could not resolve field info. See docs: TODO")
|
|
|
|
|
elif isinstance(right, ModelField):
|
|
|
|
|
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
|
|
|
|
|
else:
|
|
|
|
|
result += cls.resolve_value(field_name, field_type, field_info,
|
|
|
|
@ -540,11 +589,7 @@ class FindQuery:
|
|
|
|
|
while True:
|
|
|
|
|
# Make a query for each pass of the loop, with a new offset equal to the
|
|
|
|
|
# current offset plus `page_size`, until we stop getting results back.
|
|
|
|
|
query = FindQuery(expressions=query.expressions,
|
|
|
|
|
model=query.model,
|
|
|
|
|
offset=query.offset + query.page_size,
|
|
|
|
|
page_size=query.page_size,
|
|
|
|
|
limit=query.limit)
|
|
|
|
|
query = query.copy(offset=query.offset + query.page_size)
|
|
|
|
|
_results = query.execute(exhaust_results=False)
|
|
|
|
|
if not _results:
|
|
|
|
|
break
|
|
|
|
@ -552,8 +597,7 @@ class FindQuery:
|
|
|
|
|
return self._model_cache
|
|
|
|
|
|
|
|
|
|
def first(self):
|
|
|
|
|
query = FindQuery(expressions=self.expressions, model=self.model,
|
|
|
|
|
offset=0, limit=1, sort_fields=self.sort_fields)
|
|
|
|
|
query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
|
|
|
|
|
results = query.execute()
|
|
|
|
|
if not results:
|
|
|
|
|
raise NotFoundError()
|
|
|
|
@ -561,26 +605,14 @@ class FindQuery:
|
|
|
|
|
|
|
|
|
|
def all(self, batch_size=10):
    """Execute this query and return whatever ``execute()`` returns.

    When ``batch_size`` differs from this query's ``page_size``, the query
    is run on a copy whose ``page_size`` and ``limit`` are both set to
    ``batch_size``; otherwise this instance is executed directly.
    """
    if batch_size != self.page_size:
        # copy() carries over every other setting unchanged.
        query = self.copy(page_size=batch_size, limit=batch_size)
        return query.execute()
    return self.execute()
|
|
|
|
|
|
|
|
|
|
def sort_by(self, *fields: str):
    """Return a query sorted by ``fields``.

    With no fields, this same instance is returned unchanged. Otherwise a
    copy is made via ``copy()`` with ``sort_fields`` replaced by the given
    field names as a list.
    """
    if not fields:
        return self
    return self.copy(sort_fields=list(fields))
|
|
|
|
|
|
|
|
|
|
def update(self, **kwargs):
|
|
|
|
|
"""Update all matching records in this query."""
|
|
|
|
@ -621,11 +653,7 @@ class FindQuery:
|
|
|
|
|
if self._model_cache and len(self._model_cache) >= item:
|
|
|
|
|
return self._model_cache[item]
|
|
|
|
|
|
|
|
|
|
query = FindQuery(expressions=self.expressions,
|
|
|
|
|
model=self.model,
|
|
|
|
|
offset=item,
|
|
|
|
|
sort_fields=self.sort_fields,
|
|
|
|
|
limit=1)
|
|
|
|
|
query = self.copy(offset=item, limit=1)
|
|
|
|
|
|
|
|
|
|
return query.execute()[0]
|
|
|
|
|
|
|
|
|
@ -845,7 +873,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|
|
|
|
pk: Optional[str] = Field(default=None, primary_key=True)
|
|
|
|
|
|
|
|
|
|
Meta = DefaultMeta
|
|
|
|
|
# TODO: Missing _meta here is causing IDE warnings.
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
|
orm_mode = True
|
|
|
|
@ -899,7 +926,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|
|
|
|
return cls._meta.database
|
|
|
|
|
|
|
|
|
|
@classmethod
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
    """Return a ``FindQuery`` for this model class over the given expressions.

    The positional ``expressions`` are handed to ``FindQuery`` as a tuple,
    with ``cls`` as the query's model.
    """
    return FindQuery(expressions=expressions, model=cls)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|