Improve handling of TAG field queries
This commit is contained in:
		
							parent
							
								
									a788cbedbb
								
							
						
					
					
						commit
						b7c9165bbd
					
				
					 2 changed files with 80 additions and 13 deletions
				
			
		| 
						 | 
					@ -70,6 +70,12 @@ class NegatedExpression:
 | 
				
			||||||
    def __invert__(self):
 | 
					    def __invert__(self):
 | 
				
			||||||
        return self.expression
 | 
					        return self.expression
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __and__(self, other):
 | 
				
			||||||
 | 
					        return Expression(left=self, op=Operators.AND, right=other)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __or__(self, other):
 | 
				
			||||||
 | 
					        return Expression(left=self, op=Operators.OR, right=other)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
class Expression:
 | 
					class Expression:
 | 
				
			||||||
| 
						 | 
					@ -121,16 +127,31 @@ class FindQuery:
 | 
				
			||||||
        self.pagination = self.resolve_redisearch_pagination()
 | 
					        self.pagination = self.resolve_redisearch_pagination()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
 | 
					    def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
 | 
				
			||||||
        if getattr(field.field_info, 'primary_key', None):
 | 
					        if getattr(field.field_info, 'primary_key', None) is True:
 | 
				
			||||||
            return RediSearchFieldTypes.TAG
 | 
					            return RediSearchFieldTypes.TAG
 | 
				
			||||||
 | 
					        elif getattr(field.field_info, 'full_text_search', None) is True:
 | 
				
			||||||
 | 
					            return RediSearchFieldTypes.TEXT
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        field_type = field.outer_type_
 | 
					        field_type = field.outer_type_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # TODO: GEO
 | 
					        # TODO: GEO
 | 
				
			||||||
        # TODO: TAG (other than PK)
 | 
					 | 
				
			||||||
        if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
 | 
					        if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
 | 
				
			||||||
            return RediSearchFieldTypes.NUMERIC
 | 
					            return RediSearchFieldTypes.NUMERIC
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return RediSearchFieldTypes.TEXT
 | 
					            # TAG fields are the default field type.
 | 
				
			||||||
 | 
					            return RediSearchFieldTypes.TAG
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def expand_tag_value(value):
 | 
				
			||||||
 | 
					        err = RedisModelError(f"Using the IN operator requires passing an iterable of "
 | 
				
			||||||
 | 
					                              "possible values. You passed: {value}")
 | 
				
			||||||
 | 
					        if isinstance(str, value):
 | 
				
			||||||
 | 
					            raise err
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            expanded_value = "|".join(value)
 | 
				
			||||||
 | 
					        except TypeError:
 | 
				
			||||||
 | 
					            raise err
 | 
				
			||||||
 | 
					        return expanded_value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
 | 
					    def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
 | 
				
			||||||
                      op: Operators, value: Any) -> str:
 | 
					                      op: Operators, value: Any) -> str:
 | 
				
			||||||
| 
						 | 
					@ -151,7 +172,8 @@ class FindQuery:
 | 
				
			||||||
            if op is Operators.EQ:
 | 
					            if op is Operators.EQ:
 | 
				
			||||||
                result += f"@{field_name}:[{value} {value}]"
 | 
					                result += f"@{field_name}:[{value} {value}]"
 | 
				
			||||||
            elif op is Operators.NE:
 | 
					            elif op is Operators.NE:
 | 
				
			||||||
                # TODO: Is this enough or do we also need a clause for all values ([-inf +inf])?
 | 
					                # TODO: Is this enough or do we also need a clause for all values
 | 
				
			||||||
 | 
					                #  ([-inf +inf]) from which we then subtract the undesirable value?
 | 
				
			||||||
                result += f"~(@{field_name}:[{value} {value}])"
 | 
					                result += f"~(@{field_name}:[{value} {value}])"
 | 
				
			||||||
            elif op is Operators.GT:
 | 
					            elif op is Operators.GT:
 | 
				
			||||||
                result += f"@{field_name}:[({value} +inf]"
 | 
					                result += f"@{field_name}:[({value} +inf]"
 | 
				
			||||||
| 
						 | 
					@ -161,6 +183,17 @@ class FindQuery:
 | 
				
			||||||
                result += f"@{field_name}:[{value} +inf]"
 | 
					                result += f"@{field_name}:[{value} +inf]"
 | 
				
			||||||
            elif op is Operators.LE:
 | 
					            elif op is Operators.LE:
 | 
				
			||||||
                result += f"@{field_name}:[-inf {value}]"
 | 
					                result += f"@{field_name}:[-inf {value}]"
 | 
				
			||||||
 | 
					        elif field_type is RediSearchFieldTypes.TAG:
 | 
				
			||||||
 | 
					            if op is Operators.EQ:
 | 
				
			||||||
 | 
					                result += f"@{field_name}:{{{value}}}"
 | 
				
			||||||
 | 
					            elif op is Operators.NE:
 | 
				
			||||||
 | 
					                result += f"~(@{field_name}:{{{value}}})"
 | 
				
			||||||
 | 
					            elif op is Operators.IN:
 | 
				
			||||||
 | 
					                expanded_value = self.expand_tag_value(value)
 | 
				
			||||||
 | 
					                result += f"(@{field_name}:{{{expanded_value}}})"
 | 
				
			||||||
 | 
					            elif op is Operators.NOT_IN:
 | 
				
			||||||
 | 
					                expanded_value = self.expand_tag_value(value)
 | 
				
			||||||
 | 
					                result += f"~(@{field_name}:{{{expanded_value}}})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return result
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -176,19 +209,22 @@ class FindQuery:
 | 
				
			||||||
        """Resolve an expression to a string RediSearch query."""
 | 
					        """Resolve an expression to a string RediSearch query."""
 | 
				
			||||||
        field_type = None
 | 
					        field_type = None
 | 
				
			||||||
        field_name = None
 | 
					        field_name = None
 | 
				
			||||||
        is_negated = False
 | 
					        encompassing_expression_is_negated = False
 | 
				
			||||||
        result = ""
 | 
					        result = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if isinstance(expression, NegatedExpression):
 | 
					        if isinstance(expression, NegatedExpression):
 | 
				
			||||||
            is_negated = True
 | 
					            encompassing_expression_is_negated = True
 | 
				
			||||||
            expression = expression.expression
 | 
					            expression = expression.expression
 | 
				
			||||||
        if isinstance(expression.left, Expression):
 | 
					
 | 
				
			||||||
 | 
					        if isinstance(expression.left, Expression) or \
 | 
				
			||||||
 | 
					                isinstance(expression.left, NegatedExpression):
 | 
				
			||||||
            result += f"({self.resolve_redisearch_query(expression.left)})"
 | 
					            result += f"({self.resolve_redisearch_query(expression.left)})"
 | 
				
			||||||
        elif isinstance(expression.left, ModelField):
 | 
					        elif isinstance(expression.left, ModelField):
 | 
				
			||||||
            field_type = self.resolve_field_type(expression.left)
 | 
					            field_type = self.resolve_field_type(expression.left)
 | 
				
			||||||
            field_name = expression.left.name
 | 
					            field_name = expression.left.name
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise QueryNotSupportedError(f"A query expression should start with either a field"
 | 
					            import ipdb; ipdb.set_trace()
 | 
				
			||||||
 | 
					            raise QueryNotSupportedError(f"A query expression should start with either a field "
 | 
				
			||||||
                                         f"or an expression enclosed in parenthesis. See docs: "
 | 
					                                         f"or an expression enclosed in parenthesis. See docs: "
 | 
				
			||||||
                                         f"TODO")
 | 
					                                         f"TODO")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -217,7 +253,7 @@ class FindQuery:
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                result += self.resolve_value(field_name, field_type, expression.op, right)
 | 
					                result += self.resolve_value(field_name, field_type, expression.op, right)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if is_negated:
 | 
					        if encompassing_expression_is_negated:
 | 
				
			||||||
            result = f"-({result})"
 | 
					            result = f"-({result})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return result
 | 
					        return result
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -29,11 +29,11 @@ class Order(BaseHashModel):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Member(BaseHashModel):
 | 
					class Member(BaseHashModel):
 | 
				
			||||||
    first_name: str
 | 
					    first_name: str = Field(index=True)
 | 
				
			||||||
    last_name: str
 | 
					    last_name: str = Field(index=True)
 | 
				
			||||||
    email: str = Field(index=True)
 | 
					    email: str = Field(index=True)
 | 
				
			||||||
    join_date: datetime.date
 | 
					    join_date: datetime.date
 | 
				
			||||||
    age: int
 | 
					    age: int = Field(index=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Meta:
 | 
					    class Meta:
 | 
				
			||||||
        model_key_prefix = "member"
 | 
					        model_key_prefix = "member"
 | 
				
			||||||
| 
						 | 
					@ -176,7 +176,7 @@ def test_exact_match_queries(members):
 | 
				
			||||||
    member1, member2, member3 = members
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    actual = Member.find(Member.last_name == "Brookins")
 | 
					    actual = Member.find(Member.last_name == "Brookins")
 | 
				
			||||||
    assert actual == sorted([member1, member2])
 | 
					    assert sorted(actual) == [member1, member2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    actual = Member.find(
 | 
					    actual = Member.find(
 | 
				
			||||||
        (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
 | 
					        (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
 | 
				
			||||||
| 
						 | 
					@ -198,6 +198,37 @@ def test_exact_match_queries(members):
 | 
				
			||||||
    assert actual == member2
 | 
					    assert actual == member2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_recursive_query_resolution(members):
 | 
				
			||||||
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find((Member.last_name == "Brookins") | (
 | 
				
			||||||
 | 
					        Member.age == 100
 | 
				
			||||||
 | 
					    ) & (Member.last_name == "Smith"))
 | 
				
			||||||
 | 
					    assert sorted(actual) == [member1, member2, member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_tag_queries_boolean_logic(members):
 | 
				
			||||||
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find(
 | 
				
			||||||
 | 
					        (Member.first_name == "Andrew") &
 | 
				
			||||||
 | 
					        (Member.last_name == "Brookins") | (Member.last_name == "Smith"))
 | 
				
			||||||
 | 
					    assert sorted(actual) == [member1, member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_tag_queries_negation(members):
 | 
				
			||||||
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find(
 | 
				
			||||||
 | 
					        ~(Member.first_name == "Andrew") &
 | 
				
			||||||
 | 
					        (Member.last_name == "Brookins") | (Member.last_name == "Smith"))
 | 
				
			||||||
 | 
					    assert sorted(actual) == [member2, member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find(
 | 
				
			||||||
 | 
					        (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins"))
 | 
				
			||||||
 | 
					    assert sorted(actual) == [member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_numeric_queries(members):
 | 
					def test_numeric_queries(members):
 | 
				
			||||||
    member1, member2, member3 = members
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue