Improve handling of TAG field queries
This commit is contained in:
parent
a788cbedbb
commit
b7c9165bbd
2 changed files with 80 additions and 13 deletions
|
@ -70,6 +70,12 @@ class NegatedExpression:
|
|||
def __invert__(self):
|
||||
return self.expression
|
||||
|
||||
def __and__(self, other):
|
||||
return Expression(left=self, op=Operators.AND, right=other)
|
||||
|
||||
def __or__(self, other):
|
||||
return Expression(left=self, op=Operators.OR, right=other)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Expression:
|
||||
|
@ -121,16 +127,31 @@ class FindQuery:
|
|||
self.pagination = self.resolve_redisearch_pagination()
|
||||
|
||||
def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
|
||||
if getattr(field.field_info, 'primary_key', None):
|
||||
if getattr(field.field_info, 'primary_key', None) is True:
|
||||
return RediSearchFieldTypes.TAG
|
||||
elif getattr(field.field_info, 'full_text_search', None) is True:
|
||||
return RediSearchFieldTypes.TEXT
|
||||
|
||||
field_type = field.outer_type_
|
||||
|
||||
# TODO: GEO
|
||||
# TODO: TAG (other than PK)
|
||||
if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
|
||||
return RediSearchFieldTypes.NUMERIC
|
||||
else:
|
||||
return RediSearchFieldTypes.TEXT
|
||||
# TAG fields are the default field type.
|
||||
return RediSearchFieldTypes.TAG
|
||||
|
||||
@staticmethod
|
||||
def expand_tag_value(value):
|
||||
err = RedisModelError(f"Using the IN operator requires passing an iterable of "
|
||||
"possible values. You passed: {value}")
|
||||
if isinstance(str, value):
|
||||
raise err
|
||||
try:
|
||||
expanded_value = "|".join(value)
|
||||
except TypeError:
|
||||
raise err
|
||||
return expanded_value
|
||||
|
||||
def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
|
||||
op: Operators, value: Any) -> str:
|
||||
|
@ -151,7 +172,8 @@ class FindQuery:
|
|||
if op is Operators.EQ:
|
||||
result += f"@{field_name}:[{value} {value}]"
|
||||
elif op is Operators.NE:
|
||||
# TODO: Is this enough or do we also need a clause for all values ([-inf +inf])?
|
||||
# TODO: Is this enough or do we also need a clause for all values
|
||||
# ([-inf +inf]) from which we then subtract the undesirable value?
|
||||
result += f"~(@{field_name}:[{value} {value}])"
|
||||
elif op is Operators.GT:
|
||||
result += f"@{field_name}:[({value} +inf]"
|
||||
|
@ -161,6 +183,17 @@ class FindQuery:
|
|||
result += f"@{field_name}:[{value} +inf]"
|
||||
elif op is Operators.LE:
|
||||
result += f"@{field_name}:[-inf {value}]"
|
||||
elif field_type is RediSearchFieldTypes.TAG:
|
||||
if op is Operators.EQ:
|
||||
result += f"@{field_name}:{{{value}}}"
|
||||
elif op is Operators.NE:
|
||||
result += f"~(@{field_name}:{{{value}}})"
|
||||
elif op is Operators.IN:
|
||||
expanded_value = self.expand_tag_value(value)
|
||||
result += f"(@{field_name}:{{{expanded_value}}})"
|
||||
elif op is Operators.NOT_IN:
|
||||
expanded_value = self.expand_tag_value(value)
|
||||
result += f"~(@{field_name}:{{{expanded_value}}})"
|
||||
|
||||
return result
|
||||
|
||||
|
@ -176,19 +209,22 @@ class FindQuery:
|
|||
"""Resolve an expression to a string RediSearch query."""
|
||||
field_type = None
|
||||
field_name = None
|
||||
is_negated = False
|
||||
encompassing_expression_is_negated = False
|
||||
result = ""
|
||||
|
||||
if isinstance(expression, NegatedExpression):
|
||||
is_negated = True
|
||||
encompassing_expression_is_negated = True
|
||||
expression = expression.expression
|
||||
if isinstance(expression.left, Expression):
|
||||
|
||||
if isinstance(expression.left, Expression) or \
|
||||
isinstance(expression.left, NegatedExpression):
|
||||
result += f"({self.resolve_redisearch_query(expression.left)})"
|
||||
elif isinstance(expression.left, ModelField):
|
||||
field_type = self.resolve_field_type(expression.left)
|
||||
field_name = expression.left.name
|
||||
else:
|
||||
raise QueryNotSupportedError(f"A query expression should start with either a field"
|
||||
import ipdb; ipdb.set_trace()
|
||||
raise QueryNotSupportedError(f"A query expression should start with either a field "
|
||||
f"or an expression enclosed in parenthesis. See docs: "
|
||||
f"TODO")
|
||||
|
||||
|
@ -217,7 +253,7 @@ class FindQuery:
|
|||
else:
|
||||
result += self.resolve_value(field_name, field_type, expression.op, right)
|
||||
|
||||
if is_negated:
|
||||
if encompassing_expression_is_negated:
|
||||
result = f"-({result})"
|
||||
|
||||
return result
|
||||
|
|
|
@ -29,11 +29,11 @@ class Order(BaseHashModel):
|
|||
|
||||
|
||||
class Member(BaseHashModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
first_name: str = Field(index=True)
|
||||
last_name: str = Field(index=True)
|
||||
email: str = Field(index=True)
|
||||
join_date: datetime.date
|
||||
age: int
|
||||
age: int = Field(index=True)
|
||||
|
||||
class Meta:
|
||||
model_key_prefix = "member"
|
||||
|
@ -176,7 +176,7 @@ def test_exact_match_queries(members):
|
|||
member1, member2, member3 = members
|
||||
|
||||
actual = Member.find(Member.last_name == "Brookins")
|
||||
assert actual == sorted([member1, member2])
|
||||
assert sorted(actual) == [member1, member2]
|
||||
|
||||
actual = Member.find(
|
||||
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
|
||||
|
@ -198,6 +198,37 @@ def test_exact_match_queries(members):
|
|||
assert actual == member2
|
||||
|
||||
|
||||
def test_recursive_query_resolution(members):
|
||||
member1, member2, member3 = members
|
||||
|
||||
actual = Member.find((Member.last_name == "Brookins") | (
|
||||
Member.age == 100
|
||||
) & (Member.last_name == "Smith"))
|
||||
assert sorted(actual) == [member1, member2, member3]
|
||||
|
||||
|
||||
def test_tag_queries_boolean_logic(members):
|
||||
member1, member2, member3 = members
|
||||
|
||||
actual = Member.find(
|
||||
(Member.first_name == "Andrew") &
|
||||
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
|
||||
assert sorted(actual) == [member1, member3]
|
||||
|
||||
|
||||
def test_tag_queries_negation(members):
|
||||
member1, member2, member3 = members
|
||||
|
||||
actual = Member.find(
|
||||
~(Member.first_name == "Andrew") &
|
||||
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
|
||||
assert sorted(actual) == [member2, member3]
|
||||
|
||||
actual = Member.find(
|
||||
(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins"))
|
||||
assert sorted(actual) == [member3]
|
||||
|
||||
|
||||
def test_numeric_queries(members):
|
||||
member1, member2, member3 = members
|
||||
|
||||
|
|
Loading…
Reference in a new issue