Improve handling of TAG field queries
This commit is contained in:
parent
a788cbedbb
commit
b7c9165bbd
2 changed files with 80 additions and 13 deletions
|
@ -70,6 +70,12 @@ class NegatedExpression:
|
||||||
def __invert__(self):
|
def __invert__(self):
|
||||||
return self.expression
|
return self.expression
|
||||||
|
|
||||||
|
def __and__(self, other):
|
||||||
|
return Expression(left=self, op=Operators.AND, right=other)
|
||||||
|
|
||||||
|
def __or__(self, other):
|
||||||
|
return Expression(left=self, op=Operators.OR, right=other)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Expression:
|
class Expression:
|
||||||
|
@ -121,16 +127,31 @@ class FindQuery:
|
||||||
self.pagination = self.resolve_redisearch_pagination()
|
self.pagination = self.resolve_redisearch_pagination()
|
||||||
|
|
||||||
def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
|
def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
|
||||||
if getattr(field.field_info, 'primary_key', None):
|
if getattr(field.field_info, 'primary_key', None) is True:
|
||||||
return RediSearchFieldTypes.TAG
|
return RediSearchFieldTypes.TAG
|
||||||
|
elif getattr(field.field_info, 'full_text_search', None) is True:
|
||||||
|
return RediSearchFieldTypes.TEXT
|
||||||
|
|
||||||
field_type = field.outer_type_
|
field_type = field.outer_type_
|
||||||
|
|
||||||
# TODO: GEO
|
# TODO: GEO
|
||||||
# TODO: TAG (other than PK)
|
|
||||||
if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
|
if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
|
||||||
return RediSearchFieldTypes.NUMERIC
|
return RediSearchFieldTypes.NUMERIC
|
||||||
else:
|
else:
|
||||||
return RediSearchFieldTypes.TEXT
|
# TAG fields are the default field type.
|
||||||
|
return RediSearchFieldTypes.TAG
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def expand_tag_value(value):
|
||||||
|
err = RedisModelError(f"Using the IN operator requires passing an iterable of "
|
||||||
|
"possible values. You passed: {value}")
|
||||||
|
if isinstance(str, value):
|
||||||
|
raise err
|
||||||
|
try:
|
||||||
|
expanded_value = "|".join(value)
|
||||||
|
except TypeError:
|
||||||
|
raise err
|
||||||
|
return expanded_value
|
||||||
|
|
||||||
def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
|
def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
|
||||||
op: Operators, value: Any) -> str:
|
op: Operators, value: Any) -> str:
|
||||||
|
@ -151,7 +172,8 @@ class FindQuery:
|
||||||
if op is Operators.EQ:
|
if op is Operators.EQ:
|
||||||
result += f"@{field_name}:[{value} {value}]"
|
result += f"@{field_name}:[{value} {value}]"
|
||||||
elif op is Operators.NE:
|
elif op is Operators.NE:
|
||||||
# TODO: Is this enough or do we also need a clause for all values ([-inf +inf])?
|
# TODO: Is this enough or do we also need a clause for all values
|
||||||
|
# ([-inf +inf]) from which we then subtract the undesirable value?
|
||||||
result += f"~(@{field_name}:[{value} {value}])"
|
result += f"~(@{field_name}:[{value} {value}])"
|
||||||
elif op is Operators.GT:
|
elif op is Operators.GT:
|
||||||
result += f"@{field_name}:[({value} +inf]"
|
result += f"@{field_name}:[({value} +inf]"
|
||||||
|
@ -161,6 +183,17 @@ class FindQuery:
|
||||||
result += f"@{field_name}:[{value} +inf]"
|
result += f"@{field_name}:[{value} +inf]"
|
||||||
elif op is Operators.LE:
|
elif op is Operators.LE:
|
||||||
result += f"@{field_name}:[-inf {value}]"
|
result += f"@{field_name}:[-inf {value}]"
|
||||||
|
elif field_type is RediSearchFieldTypes.TAG:
|
||||||
|
if op is Operators.EQ:
|
||||||
|
result += f"@{field_name}:{{{value}}}"
|
||||||
|
elif op is Operators.NE:
|
||||||
|
result += f"~(@{field_name}:{{{value}}})"
|
||||||
|
elif op is Operators.IN:
|
||||||
|
expanded_value = self.expand_tag_value(value)
|
||||||
|
result += f"(@{field_name}:{{{expanded_value}}})"
|
||||||
|
elif op is Operators.NOT_IN:
|
||||||
|
expanded_value = self.expand_tag_value(value)
|
||||||
|
result += f"~(@{field_name}:{{{expanded_value}}})"
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -176,19 +209,22 @@ class FindQuery:
|
||||||
"""Resolve an expression to a string RediSearch query."""
|
"""Resolve an expression to a string RediSearch query."""
|
||||||
field_type = None
|
field_type = None
|
||||||
field_name = None
|
field_name = None
|
||||||
is_negated = False
|
encompassing_expression_is_negated = False
|
||||||
result = ""
|
result = ""
|
||||||
|
|
||||||
if isinstance(expression, NegatedExpression):
|
if isinstance(expression, NegatedExpression):
|
||||||
is_negated = True
|
encompassing_expression_is_negated = True
|
||||||
expression = expression.expression
|
expression = expression.expression
|
||||||
if isinstance(expression.left, Expression):
|
|
||||||
|
if isinstance(expression.left, Expression) or \
|
||||||
|
isinstance(expression.left, NegatedExpression):
|
||||||
result += f"({self.resolve_redisearch_query(expression.left)})"
|
result += f"({self.resolve_redisearch_query(expression.left)})"
|
||||||
elif isinstance(expression.left, ModelField):
|
elif isinstance(expression.left, ModelField):
|
||||||
field_type = self.resolve_field_type(expression.left)
|
field_type = self.resolve_field_type(expression.left)
|
||||||
field_name = expression.left.name
|
field_name = expression.left.name
|
||||||
else:
|
else:
|
||||||
raise QueryNotSupportedError(f"A query expression should start with either a field"
|
import ipdb; ipdb.set_trace()
|
||||||
|
raise QueryNotSupportedError(f"A query expression should start with either a field "
|
||||||
f"or an expression enclosed in parenthesis. See docs: "
|
f"or an expression enclosed in parenthesis. See docs: "
|
||||||
f"TODO")
|
f"TODO")
|
||||||
|
|
||||||
|
@ -217,7 +253,7 @@ class FindQuery:
|
||||||
else:
|
else:
|
||||||
result += self.resolve_value(field_name, field_type, expression.op, right)
|
result += self.resolve_value(field_name, field_type, expression.op, right)
|
||||||
|
|
||||||
if is_negated:
|
if encompassing_expression_is_negated:
|
||||||
result = f"-({result})"
|
result = f"-({result})"
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -29,11 +29,11 @@ class Order(BaseHashModel):
|
||||||
|
|
||||||
|
|
||||||
class Member(BaseHashModel):
|
class Member(BaseHashModel):
|
||||||
first_name: str
|
first_name: str = Field(index=True)
|
||||||
last_name: str
|
last_name: str = Field(index=True)
|
||||||
email: str = Field(index=True)
|
email: str = Field(index=True)
|
||||||
join_date: datetime.date
|
join_date: datetime.date
|
||||||
age: int
|
age: int = Field(index=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model_key_prefix = "member"
|
model_key_prefix = "member"
|
||||||
|
@ -176,7 +176,7 @@ def test_exact_match_queries(members):
|
||||||
member1, member2, member3 = members
|
member1, member2, member3 = members
|
||||||
|
|
||||||
actual = Member.find(Member.last_name == "Brookins")
|
actual = Member.find(Member.last_name == "Brookins")
|
||||||
assert actual == sorted([member1, member2])
|
assert sorted(actual) == [member1, member2]
|
||||||
|
|
||||||
actual = Member.find(
|
actual = Member.find(
|
||||||
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
|
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
|
||||||
|
@ -198,6 +198,37 @@ def test_exact_match_queries(members):
|
||||||
assert actual == member2
|
assert actual == member2
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_query_resolution(members):
|
||||||
|
member1, member2, member3 = members
|
||||||
|
|
||||||
|
actual = Member.find((Member.last_name == "Brookins") | (
|
||||||
|
Member.age == 100
|
||||||
|
) & (Member.last_name == "Smith"))
|
||||||
|
assert sorted(actual) == [member1, member2, member3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tag_queries_boolean_logic(members):
|
||||||
|
member1, member2, member3 = members
|
||||||
|
|
||||||
|
actual = Member.find(
|
||||||
|
(Member.first_name == "Andrew") &
|
||||||
|
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
|
||||||
|
assert sorted(actual) == [member1, member3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tag_queries_negation(members):
|
||||||
|
member1, member2, member3 = members
|
||||||
|
|
||||||
|
actual = Member.find(
|
||||||
|
~(Member.first_name == "Andrew") &
|
||||||
|
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
|
||||||
|
assert sorted(actual) == [member2, member3]
|
||||||
|
|
||||||
|
actual = Member.find(
|
||||||
|
(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins"))
|
||||||
|
assert sorted(actual) == [member3]
|
||||||
|
|
||||||
|
|
||||||
def test_numeric_queries(members):
|
def test_numeric_queries(members):
|
||||||
member1, member2, member3 = members
|
member1, member2, member3 = members
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue