Improve handling of TAG field queries

This commit is contained in:
Andrew Brookins 2021-09-20 16:06:04 -07:00
parent a788cbedbb
commit b7c9165bbd
2 changed files with 80 additions and 13 deletions

View file

@ -70,6 +70,12 @@ class NegatedExpression:
def __invert__(self): def __invert__(self):
return self.expression return self.expression
def __and__(self, other):
return Expression(left=self, op=Operators.AND, right=other)
def __or__(self, other):
return Expression(left=self, op=Operators.OR, right=other)
@dataclass @dataclass
class Expression: class Expression:
@ -121,16 +127,31 @@ class FindQuery:
self.pagination = self.resolve_redisearch_pagination() self.pagination = self.resolve_redisearch_pagination()
def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes: def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
if getattr(field.field_info, 'primary_key', None): if getattr(field.field_info, 'primary_key', None) is True:
return RediSearchFieldTypes.TAG return RediSearchFieldTypes.TAG
elif getattr(field.field_info, 'full_text_search', None) is True:
return RediSearchFieldTypes.TEXT
field_type = field.outer_type_ field_type = field.outer_type_
# TODO: GEO # TODO: GEO
# TODO: TAG (other than PK)
if any(issubclass(field_type, t) for t in NUMERIC_TYPES): if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
return RediSearchFieldTypes.NUMERIC return RediSearchFieldTypes.NUMERIC
else: else:
return RediSearchFieldTypes.TEXT # TAG fields are the default field type.
return RediSearchFieldTypes.TAG
@staticmethod
def expand_tag_value(value):
err = RedisModelError(f"Using the IN operator requires passing an iterable of "
"possible values. You passed: {value}")
if isinstance(str, value):
raise err
try:
expanded_value = "|".join(value)
except TypeError:
raise err
return expanded_value
def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes, def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
op: Operators, value: Any) -> str: op: Operators, value: Any) -> str:
@ -151,7 +172,8 @@ class FindQuery:
if op is Operators.EQ: if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]" result += f"@{field_name}:[{value} {value}]"
elif op is Operators.NE: elif op is Operators.NE:
# TODO: Is this enough or do we also need a clause for all values ([-inf +inf])? # TODO: Is this enough or do we also need a clause for all values
# ([-inf +inf]) from which we then subtract the undesirable value?
result += f"~(@{field_name}:[{value} {value}])" result += f"~(@{field_name}:[{value} {value}])"
elif op is Operators.GT: elif op is Operators.GT:
result += f"@{field_name}:[({value} +inf]" result += f"@{field_name}:[({value} +inf]"
@ -161,6 +183,17 @@ class FindQuery:
result += f"@{field_name}:[{value} +inf]" result += f"@{field_name}:[{value} +inf]"
elif op is Operators.LE: elif op is Operators.LE:
result += f"@{field_name}:[-inf {value}]" result += f"@{field_name}:[-inf {value}]"
elif field_type is RediSearchFieldTypes.TAG:
if op is Operators.EQ:
result += f"@{field_name}:{{{value}}}"
elif op is Operators.NE:
result += f"~(@{field_name}:{{{value}}})"
elif op is Operators.IN:
expanded_value = self.expand_tag_value(value)
result += f"(@{field_name}:{{{expanded_value}}})"
elif op is Operators.NOT_IN:
expanded_value = self.expand_tag_value(value)
result += f"~(@{field_name}:{{{expanded_value}}})"
return result return result
@ -176,18 +209,21 @@ class FindQuery:
"""Resolve an expression to a string RediSearch query.""" """Resolve an expression to a string RediSearch query."""
field_type = None field_type = None
field_name = None field_name = None
is_negated = False encompassing_expression_is_negated = False
result = "" result = ""
if isinstance(expression, NegatedExpression): if isinstance(expression, NegatedExpression):
is_negated = True encompassing_expression_is_negated = True
expression = expression.expression expression = expression.expression
if isinstance(expression.left, Expression):
if isinstance(expression.left, Expression) or \
isinstance(expression.left, NegatedExpression):
result += f"({self.resolve_redisearch_query(expression.left)})" result += f"({self.resolve_redisearch_query(expression.left)})"
elif isinstance(expression.left, ModelField): elif isinstance(expression.left, ModelField):
field_type = self.resolve_field_type(expression.left) field_type = self.resolve_field_type(expression.left)
field_name = expression.left.name field_name = expression.left.name
else: else:
import ipdb; ipdb.set_trace()
raise QueryNotSupportedError(f"A query expression should start with either a field " raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: " f"or an expression enclosed in parenthesis. See docs: "
f"TODO") f"TODO")
@ -217,7 +253,7 @@ class FindQuery:
else: else:
result += self.resolve_value(field_name, field_type, expression.op, right) result += self.resolve_value(field_name, field_type, expression.op, right)
if is_negated: if encompassing_expression_is_negated:
result = f"-({result})" result = f"-({result})"
return result return result

View file

@ -29,11 +29,11 @@ class Order(BaseHashModel):
class Member(BaseHashModel): class Member(BaseHashModel):
first_name: str first_name: str = Field(index=True)
last_name: str last_name: str = Field(index=True)
email: str = Field(index=True) email: str = Field(index=True)
join_date: datetime.date join_date: datetime.date
age: int age: int = Field(index=True)
class Meta: class Meta:
model_key_prefix = "member" model_key_prefix = "member"
@ -176,7 +176,7 @@ def test_exact_match_queries(members):
member1, member2, member3 = members member1, member2, member3 = members
actual = Member.find(Member.last_name == "Brookins") actual = Member.find(Member.last_name == "Brookins")
assert actual == sorted([member1, member2]) assert sorted(actual) == [member1, member2]
actual = Member.find( actual = Member.find(
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")) (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
@ -198,6 +198,37 @@ def test_exact_match_queries(members):
assert actual == member2 assert actual == member2
def test_recursive_query_resolution(members):
member1, member2, member3 = members
actual = Member.find((Member.last_name == "Brookins") | (
Member.age == 100
) & (Member.last_name == "Smith"))
assert sorted(actual) == [member1, member2, member3]
def test_tag_queries_boolean_logic(members):
member1, member2, member3 = members
actual = Member.find(
(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
assert sorted(actual) == [member1, member3]
def test_tag_queries_negation(members):
member1, member2, member3 = members
actual = Member.find(
~(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
assert sorted(actual) == [member2, member3]
actual = Member.find(
(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins"))
assert sorted(actual) == [member3]
def test_numeric_queries(members): def test_numeric_queries(members):
member1, member2, member3 = members member1, member2, member3 = members