Improve handling of TAG field queries

This commit is contained in:
Andrew Brookins 2021-09-20 16:06:04 -07:00
parent a788cbedbb
commit b7c9165bbd
2 changed files with 80 additions and 13 deletions

View file

@ -70,6 +70,12 @@ class NegatedExpression:
def __invert__(self):
return self.expression
def __and__(self, other):
return Expression(left=self, op=Operators.AND, right=other)
def __or__(self, other):
return Expression(left=self, op=Operators.OR, right=other)
@dataclass
class Expression:
@ -121,16 +127,31 @@ class FindQuery:
self.pagination = self.resolve_redisearch_pagination()
def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
if getattr(field.field_info, 'primary_key', None):
if getattr(field.field_info, 'primary_key', None) is True:
return RediSearchFieldTypes.TAG
elif getattr(field.field_info, 'full_text_search', None) is True:
return RediSearchFieldTypes.TEXT
field_type = field.outer_type_
# TODO: GEO
# TODO: TAG (other than PK)
if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
return RediSearchFieldTypes.NUMERIC
else:
return RediSearchFieldTypes.TEXT
# TAG fields are the default field type.
return RediSearchFieldTypes.TAG
@staticmethod
def expand_tag_value(value):
err = RedisModelError(f"Using the IN operator requires passing an iterable of "
"possible values. You passed: {value}")
if isinstance(str, value):
raise err
try:
expanded_value = "|".join(value)
except TypeError:
raise err
return expanded_value
def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
op: Operators, value: Any) -> str:
@ -151,7 +172,8 @@ class FindQuery:
if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]"
elif op is Operators.NE:
# TODO: Is this enough or do we also need a clause for all values ([-inf +inf])?
# TODO: Is this enough or do we also need a clause for all values
# ([-inf +inf]) from which we then subtract the undesirable value?
result += f"~(@{field_name}:[{value} {value}])"
elif op is Operators.GT:
result += f"@{field_name}:[({value} +inf]"
@ -161,6 +183,17 @@ class FindQuery:
result += f"@{field_name}:[{value} +inf]"
elif op is Operators.LE:
result += f"@{field_name}:[-inf {value}]"
elif field_type is RediSearchFieldTypes.TAG:
if op is Operators.EQ:
result += f"@{field_name}:{{{value}}}"
elif op is Operators.NE:
result += f"~(@{field_name}:{{{value}}})"
elif op is Operators.IN:
expanded_value = self.expand_tag_value(value)
result += f"(@{field_name}:{{{expanded_value}}})"
elif op is Operators.NOT_IN:
expanded_value = self.expand_tag_value(value)
result += f"~(@{field_name}:{{{expanded_value}}})"
return result
@ -176,19 +209,22 @@ class FindQuery:
"""Resolve an expression to a string RediSearch query."""
field_type = None
field_name = None
is_negated = False
encompassing_expression_is_negated = False
result = ""
if isinstance(expression, NegatedExpression):
is_negated = True
encompassing_expression_is_negated = True
expression = expression.expression
if isinstance(expression.left, Expression):
if isinstance(expression.left, Expression) or \
isinstance(expression.left, NegatedExpression):
result += f"({self.resolve_redisearch_query(expression.left)})"
elif isinstance(expression.left, ModelField):
field_type = self.resolve_field_type(expression.left)
field_name = expression.left.name
else:
raise QueryNotSupportedError(f"A query expression should start with either a field"
import ipdb; ipdb.set_trace()
raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: "
f"TODO")
@ -217,7 +253,7 @@ class FindQuery:
else:
result += self.resolve_value(field_name, field_type, expression.op, right)
if is_negated:
if encompassing_expression_is_negated:
result = f"-({result})"
return result

View file

@ -29,11 +29,11 @@ class Order(BaseHashModel):
class Member(BaseHashModel):
first_name: str
last_name: str
first_name: str = Field(index=True)
last_name: str = Field(index=True)
email: str = Field(index=True)
join_date: datetime.date
age: int
age: int = Field(index=True)
class Meta:
model_key_prefix = "member"
@ -176,7 +176,7 @@ def test_exact_match_queries(members):
member1, member2, member3 = members
actual = Member.find(Member.last_name == "Brookins")
assert actual == sorted([member1, member2])
assert sorted(actual) == [member1, member2]
actual = Member.find(
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
@ -198,6 +198,37 @@ def test_exact_match_queries(members):
assert actual == member2
def test_recursive_query_resolution(members):
member1, member2, member3 = members
actual = Member.find((Member.last_name == "Brookins") | (
Member.age == 100
) & (Member.last_name == "Smith"))
assert sorted(actual) == [member1, member2, member3]
def test_tag_queries_boolean_logic(members):
member1, member2, member3 = members
actual = Member.find(
(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
assert sorted(actual) == [member1, member3]
def test_tag_queries_negation(members):
member1, member2, member3 = members
actual = Member.find(
~(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
assert sorted(actual) == [member2, member3]
actual = Member.find(
(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins"))
assert sorted(actual) == [member3]
def test_numeric_queries(members):
member1, member2, member3 = members