From b7c9165bbd546c62d6f158e7731e9ea9a84cb464 Mon Sep 17 00:00:00 2001
From: Andrew Brookins
Date: Mon, 20 Sep 2021 16:06:04 -0700
Subject: [PATCH] Improve handling of TAG field queries

---
Review notes (this section is not part of the commit message):
 * expand_tag_value: the isinstance() arguments were reversed, so every call
   raised a bare TypeError; they are now isinstance(value, str). The second
   fragment of the error message is now an f-string so {value} is
   interpolated instead of being printed literally.
 * resolve_redisearch_query: a leftover ipdb breakpoint was removed. The
   hunk headers and diffstat below were recounted to match, and the
   post-image blob id on the model.py "index" line is stale as a result.
 * NOTE(review): leading whitespace was reconstructed from a flattened copy
   of this patch. Verify context-line indentation against the target files,
   or apply with --ignore-whitespace.

 redis_developer/orm/model.py | 53 ++++++++++++++++++++++++++++++------
 tests/test_hash_model.py     | 39 +++++++++++++++++++++++---
 2 files changed, 79 insertions(+), 13 deletions(-)

diff --git a/redis_developer/orm/model.py b/redis_developer/orm/model.py
index 4646a88..9c3151c 100644
--- a/redis_developer/orm/model.py
+++ b/redis_developer/orm/model.py
@@ -70,6 +70,12 @@ class NegatedExpression:
     def __invert__(self):
         return self.expression
 
+    def __and__(self, other):
+        return Expression(left=self, op=Operators.AND, right=other)
+
+    def __or__(self, other):
+        return Expression(left=self, op=Operators.OR, right=other)
+
 
 @dataclass
 class Expression:
@@ -121,16 +127,31 @@ class FindQuery:
         self.pagination = self.resolve_redisearch_pagination()
 
     def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
-        if getattr(field.field_info, 'primary_key', None):
+        if getattr(field.field_info, 'primary_key', None) is True:
             return RediSearchFieldTypes.TAG
+        elif getattr(field.field_info, 'full_text_search', None) is True:
+            return RediSearchFieldTypes.TEXT
+
         field_type = field.outer_type_
 
         # TODO: GEO
-        # TODO: TAG (other than PK)
         if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
             return RediSearchFieldTypes.NUMERIC
         else:
-            return RediSearchFieldTypes.TEXT
+            # TAG fields are the default field type.
+            return RediSearchFieldTypes.TAG
+
+    @staticmethod
+    def expand_tag_value(value):
+        err = RedisModelError(f"Using the IN operator requires passing an iterable of "
+                              f"possible values. You passed: {value}")
+        if isinstance(value, str):
+            raise err
+        try:
+            expanded_value = "|".join(value)
+        except TypeError:
+            raise err
+        return expanded_value
 
     def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
                       op: Operators, value: Any) -> str:
@@ -151,7 +172,8 @@ class FindQuery:
             if op is Operators.EQ:
                 result += f"@{field_name}:[{value} {value}]"
             elif op is Operators.NE:
-                # TODO: Is this enough or do we also need a clause for all values ([-inf +inf])?
+                # TODO: Is this enough or do we also need a clause for all values
+                # ([-inf +inf]) from which we then subtract the undesirable value?
                 result += f"~(@{field_name}:[{value} {value}])"
             elif op is Operators.GT:
                 result += f"@{field_name}:[({value} +inf]"
@@ -161,6 +183,17 @@ class FindQuery:
                 result += f"@{field_name}:[{value} +inf]"
             elif op is Operators.LE:
                 result += f"@{field_name}:[-inf {value}]"
+        elif field_type is RediSearchFieldTypes.TAG:
+            if op is Operators.EQ:
+                result += f"@{field_name}:{{{value}}}"
+            elif op is Operators.NE:
+                result += f"~(@{field_name}:{{{value}}})"
+            elif op is Operators.IN:
+                expanded_value = self.expand_tag_value(value)
+                result += f"(@{field_name}:{{{expanded_value}}})"
+            elif op is Operators.NOT_IN:
+                expanded_value = self.expand_tag_value(value)
+                result += f"~(@{field_name}:{{{expanded_value}}})"
 
         return result
 
@@ -176,19 +209,21 @@ class FindQuery:
         """Resolve an expression to a string RediSearch query."""
         field_type = None
         field_name = None
-        is_negated = False
+        encompassing_expression_is_negated = False
         result = ""
 
         if isinstance(expression, NegatedExpression):
-            is_negated = True
+            encompassing_expression_is_negated = True
             expression = expression.expression
-        if isinstance(expression.left, Expression):
+
+        if isinstance(expression.left, Expression) or \
+                isinstance(expression.left, NegatedExpression):
             result += f"({self.resolve_redisearch_query(expression.left)})"
         elif isinstance(expression.left, ModelField):
             field_type = self.resolve_field_type(expression.left)
             field_name = expression.left.name
         else:
-            raise QueryNotSupportedError(f"A query expression should start with either a field"
+            raise QueryNotSupportedError(f"A query expression should start with either a field "
                                          f"or an expression enclosed in parenthesis. See docs: "
                                          f"TODO")
 
@@ -217,7 +252,7 @@ class FindQuery:
             else:
                 result += self.resolve_value(field_name, field_type, expression.op, right)
 
-        if is_negated:
+        if encompassing_expression_is_negated:
             result = f"-({result})"
 
         return result
diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py
index e4c557f..92544aa 100644
--- a/tests/test_hash_model.py
+++ b/tests/test_hash_model.py
@@ -29,11 +29,11 @@ class Order(BaseHashModel):
 
 
 class Member(BaseHashModel):
-    first_name: str
-    last_name: str
+    first_name: str = Field(index=True)
+    last_name: str = Field(index=True)
     email: str = Field(index=True)
     join_date: datetime.date
-    age: int
+    age: int = Field(index=True)
 
     class Meta:
         model_key_prefix = "member"
@@ -176,7 +176,7 @@ def test_exact_match_queries(members):
     member1, member2, member3 = members
 
     actual = Member.find(Member.last_name == "Brookins")
-    assert actual == sorted([member1, member2])
+    assert sorted(actual) == [member1, member2]
 
     actual = Member.find(
         (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
@@ -198,6 +198,37 @@ def test_exact_match_queries(members):
     assert actual == member2
 
 
+def test_recursive_query_resolution(members):
+    member1, member2, member3 = members
+
+    actual = Member.find((Member.last_name == "Brookins") | (
+        Member.age == 100
+    ) & (Member.last_name == "Smith"))
+    assert sorted(actual) == [member1, member2, member3]
+
+
+def test_tag_queries_boolean_logic(members):
+    member1, member2, member3 = members
+
+    actual = Member.find(
+        (Member.first_name == "Andrew") &
+        (Member.last_name == "Brookins") | (Member.last_name == "Smith"))
+    assert sorted(actual) == [member1, member3]
+
+
+def test_tag_queries_negation(members):
+    member1, member2, member3 = members
+
+    actual = Member.find(
+        ~(Member.first_name == "Andrew") &
+        (Member.last_name == "Brookins") | (Member.last_name == "Smith"))
+    assert sorted(actual) == [member2, member3]
+
+    actual = Member.find(
+        (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins"))
+    assert sorted(actual) == [member3]
+
+
 def test_numeric_queries(members):
     member1, member2, member3 = members
 