Add support for IN queries
This commit is contained in:
		
							parent
							
								
									bb08fb9eb5
								
							
						
					
					
						commit
						389a6ea878
					
				
					 3 changed files with 111 additions and 54 deletions
				
			
		| 
						 | 
					@ -4,7 +4,7 @@ import decimal
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import operator
 | 
					import operator
 | 
				
			||||||
from copy import deepcopy
 | 
					from copy import deepcopy, copy
 | 
				
			||||||
from enum import Enum
 | 
					from enum import Enum
 | 
				
			||||||
from functools import reduce
 | 
					from functools import reduce
 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
| 
						 | 
					@ -31,7 +31,7 @@ from pydantic import BaseModel, validator
 | 
				
			||||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
 | 
					from pydantic.fields import FieldInfo as PydanticFieldInfo
 | 
				
			||||||
from pydantic.fields import ModelField, Undefined, UndefinedType
 | 
					from pydantic.fields import ModelField, Undefined, UndefinedType
 | 
				
			||||||
from pydantic.main import ModelMetaclass
 | 
					from pydantic.main import ModelMetaclass
 | 
				
			||||||
from pydantic.typing import NoArgAnyCallable, resolve_annotations
 | 
					from pydantic.typing import NoArgAnyCallable
 | 
				
			||||||
from pydantic.utils import Representation
 | 
					from pydantic.utils import Representation
 | 
				
			||||||
from ulid import ULID
 | 
					from ulid import ULID
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -172,8 +172,8 @@ class NegatedExpression:
 | 
				
			||||||
@dataclasses.dataclass
 | 
					@dataclasses.dataclass
 | 
				
			||||||
class Expression:
 | 
					class Expression:
 | 
				
			||||||
    op: Operators
 | 
					    op: Operators
 | 
				
			||||||
    left: ExpressionOrModelField
 | 
					    left: Optional[ExpressionOrModelField]
 | 
				
			||||||
    right: ExpressionOrModelField
 | 
					    right: Optional[ExpressionOrModelField]
 | 
				
			||||||
    parents: List[Tuple[str, 'RedisModel']]
 | 
					    parents: List[Tuple[str, 'RedisModel']]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __invert__(self):
 | 
					    def __invert__(self):
 | 
				
			||||||
| 
						 | 
					@ -208,26 +208,28 @@ class ExpressionProxy:
 | 
				
			||||||
    def __ne__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
					    def __ne__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
				
			||||||
        return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)
 | 
					        return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __lt__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
					    def __lt__(self, other: Any) -> Expression:
 | 
				
			||||||
        return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
 | 
					        return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __le__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
					    def __le__(self, other: Any) -> Expression:
 | 
				
			||||||
        return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
 | 
					        return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __gt__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
					    def __gt__(self, other: Any) -> Expression:
 | 
				
			||||||
        return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
 | 
					        return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __ge__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
					    def __ge__(self, other: Any) -> Expression:
 | 
				
			||||||
        return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
 | 
					        return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __mod__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
					    def __mod__(self, other: Any) -> Expression:
 | 
				
			||||||
        return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents)
 | 
					        return Expression(left=self.field, op=Operators.LIKE, right=other, parents=self.parents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __lshift__(self, other: Any) -> Expression:
 | 
				
			||||||
 | 
					        return Expression(left=self.field, op=Operators.IN, right=other, parents=self.parents)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __getattr__(self, item):
 | 
					    def __getattr__(self, item):
 | 
				
			||||||
        if get_origin(self.field.outer_type_) == list:
 | 
					        if get_origin(self.field.outer_type_) == list:
 | 
				
			||||||
            embedded_cls = get_args(self.field.outer_type_)
 | 
					            embedded_cls = get_args(self.field.outer_type_)
 | 
				
			||||||
            if not embedded_cls:
 | 
					            if not embedded_cls:
 | 
				
			||||||
                # TODO: Is this even possible?
 | 
					 | 
				
			||||||
                raise QuerySyntaxError("In order to query on a list field, you must define "
 | 
					                raise QuerySyntaxError("In order to query on a list field, you must define "
 | 
				
			||||||
                                       "the contents of the list with a type annotation, like: "
 | 
					                                       "the contents of the list with a type annotation, like: "
 | 
				
			||||||
                                       "orders: List[Order]. Docs: TODO")
 | 
					                                       "orders: List[Order]. Docs: TODO")
 | 
				
			||||||
| 
						 | 
					@ -237,7 +239,7 @@ class ExpressionProxy:
 | 
				
			||||||
            attr = getattr(self.field.outer_type_, item)
 | 
					            attr = getattr(self.field.outer_type_, item)
 | 
				
			||||||
        if isinstance(attr, self.__class__):
 | 
					        if isinstance(attr, self.__class__):
 | 
				
			||||||
            new_parent = (self.field.name, self.field.outer_type_)
 | 
					            new_parent = (self.field.name, self.field.outer_type_)
 | 
				
			||||||
            if not new_parent in attr.parents:
 | 
					            if new_parent not in attr.parents:
 | 
				
			||||||
                attr.parents.append(new_parent)
 | 
					                attr.parents.append(new_parent)
 | 
				
			||||||
            new_parents = list(set(self.parents) - set(attr.parents))
 | 
					            new_parents = list(set(self.parents) - set(attr.parents))
 | 
				
			||||||
            if new_parents:
 | 
					            if new_parents:
 | 
				
			||||||
| 
						 | 
					@ -285,6 +287,21 @@ class FindQuery:
 | 
				
			||||||
        self._pagination: list[str] = []
 | 
					        self._pagination: list[str] = []
 | 
				
			||||||
        self._model_cache: list[RedisModel] = []
 | 
					        self._model_cache: list[RedisModel] = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def dict(self) -> dict[str, Any]:
 | 
				
			||||||
 | 
					        return dict(
 | 
				
			||||||
 | 
					            model=self.model,
 | 
				
			||||||
 | 
					            offset=self.offset,
 | 
				
			||||||
 | 
					            page_size=self.page_size,
 | 
				
			||||||
 | 
					            limit=self.limit,
 | 
				
			||||||
 | 
					            expressions=copy(self.expressions),
 | 
				
			||||||
 | 
					            sort_fields=copy(self.sort_fields)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def copy(self, **kwargs):
 | 
				
			||||||
 | 
					        original = self.dict()
 | 
				
			||||||
 | 
					        original.update(**kwargs)
 | 
				
			||||||
 | 
					        return FindQuery(**original)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def pagination(self):
 | 
					    def pagination(self):
 | 
				
			||||||
        if self._pagination:
 | 
					        if self._pagination:
 | 
				
			||||||
| 
						 | 
					@ -299,15 +316,21 @@ class FindQuery:
 | 
				
			||||||
        if self.expressions:
 | 
					        if self.expressions:
 | 
				
			||||||
            self._expression = reduce(operator.and_, self.expressions)
 | 
					            self._expression = reduce(operator.and_, self.expressions)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            # TODO: Is there a better way to support the "give me all records" query?
 | 
					            self._expression = Expression(left=None, right=None, op=Operators.ALL, parents=[])
 | 
				
			||||||
            # Also -- if we do it this way, we need different type annotations.
 | 
					 | 
				
			||||||
            self._expression = Expression(left=None, right=None, op=Operators.ALL,
 | 
					 | 
				
			||||||
                                          parents=[])
 | 
					 | 
				
			||||||
        return self._expression
 | 
					        return self._expression
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def query(self):
 | 
					    def query(self):
 | 
				
			||||||
        return self.resolve_redisearch_query(self.expression)
 | 
					        """
 | 
				
			||||||
 | 
					        Resolve and return the RediSearch query for this FindQuery.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        NOTE: We cache the resolved query string after generating it. This should be OK
 | 
				
			||||||
 | 
					        because all mutations of FindQuery through public APIs return a new FindQuery instance.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self._query:
 | 
				
			||||||
 | 
					            return self._query
 | 
				
			||||||
 | 
					        self._query = self.resolve_redisearch_query(self.expression)
 | 
				
			||||||
 | 
					        return self._query
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate_sort_fields(self, sort_fields):
 | 
					    def validate_sort_fields(self, sort_fields):
 | 
				
			||||||
        for sort_field in sort_fields:
 | 
					        for sort_field in sort_fields:
 | 
				
			||||||
| 
						 | 
					@ -322,10 +345,10 @@ class FindQuery:
 | 
				
			||||||
        return sort_fields
 | 
					        return sort_fields
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def resolve_field_type(field: ModelField, operator: Operators) -> RediSearchFieldTypes:
 | 
					    def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes:
 | 
				
			||||||
        if getattr(field.field_info, 'primary_key', None) is True:
 | 
					        if getattr(field.field_info, 'primary_key', None) is True:
 | 
				
			||||||
            return RediSearchFieldTypes.TAG
 | 
					            return RediSearchFieldTypes.TAG
 | 
				
			||||||
        elif operator is Operators.LIKE:
 | 
					        elif op is Operators.LIKE:
 | 
				
			||||||
            fts = getattr(field.field_info, 'full_text_search', None)
 | 
					            fts = getattr(field.field_info, 'full_text_search', None)
 | 
				
			||||||
            if fts is not True:  # Could be PydanticUndefined
 | 
					            if fts is not True:  # Could be PydanticUndefined
 | 
				
			||||||
                raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', "
 | 
					                raise QuerySyntaxError(f"You tried to do a full-text search on the field '{field.name}', "
 | 
				
			||||||
| 
						 | 
					@ -346,7 +369,7 @@ class FindQuery:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def expand_tag_value(value):
 | 
					    def expand_tag_value(value):
 | 
				
			||||||
        if isinstance(str, value):
 | 
					        if isinstance(value, str):
 | 
				
			||||||
            return value
 | 
					            return value
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            expanded_value = "|".join([escaper.escape(v) for v in value])
 | 
					            expanded_value = "|".join([escaper.escape(v) for v in value])
 | 
				
			||||||
| 
						 | 
					@ -380,8 +403,6 @@ class FindQuery:
 | 
				
			||||||
            if op is Operators.EQ:
 | 
					            if op is Operators.EQ:
 | 
				
			||||||
                result += f"@{field_name}:[{value} {value}]"
 | 
					                result += f"@{field_name}:[{value} {value}]"
 | 
				
			||||||
            elif op is Operators.NE:
 | 
					            elif op is Operators.NE:
 | 
				
			||||||
                # TODO: Is this enough or do we also need a clause for all values
 | 
					 | 
				
			||||||
                #  ([-inf +inf]) from which we then subtract the undesirable value?
 | 
					 | 
				
			||||||
                result += f"-(@{field_name}:[{value} {value}])"
 | 
					                result += f"-(@{field_name}:[{value} {value}])"
 | 
				
			||||||
            elif op is Operators.GT:
 | 
					            elif op is Operators.GT:
 | 
				
			||||||
                result += f"@{field_name}:[({value} +inf]"
 | 
					                result += f"@{field_name}:[({value} +inf]"
 | 
				
			||||||
| 
						 | 
					@ -410,7 +431,7 @@ class FindQuery:
 | 
				
			||||||
                    # The value contains the TAG field separator. We can work
 | 
					                    # The value contains the TAG field separator. We can work
 | 
				
			||||||
                    # around this by breaking apart the values and unioning them
 | 
					                    # around this by breaking apart the values and unioning them
 | 
				
			||||||
                    # with multiple field:{} queries.
 | 
					                    # with multiple field:{} queries.
 | 
				
			||||||
                    values = filter(None, value.split(separator_char))
 | 
					                    values: filter = filter(None, value.split(separator_char))
 | 
				
			||||||
                    for value in values:
 | 
					                    for value in values:
 | 
				
			||||||
                        value = escaper.escape(value)
 | 
					                        value = escaper.escape(value)
 | 
				
			||||||
                        result += f"@{field_name}:{{{value}}}"
 | 
					                        result += f"@{field_name}:{{{value}}}"
 | 
				
			||||||
| 
						 | 
					@ -448,7 +469,29 @@ class FindQuery:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
 | 
					    def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
 | 
				
			||||||
        """Resolve an expression to a string RediSearch query."""
 | 
					        """
 | 
				
			||||||
 | 
					        Resolve an arbitrarily deep expression into a single RediSearch query string.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        This method is complex. Note the following:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        1. This method makes a recursive call to itself when it finds that
 | 
				
			||||||
 | 
					           either the left or right operand contains another expression.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        2. An expression might be in a "negated" form, which means that the user
 | 
				
			||||||
 | 
					           gave us an expression like ~(Member.age == 30), or in other words,
 | 
				
			||||||
 | 
					           "Members whose age is NOT 30." Thus, a negated expression is one in
 | 
				
			||||||
 | 
					           which the meaning of an expression is inverted. If we find a negated
 | 
				
			||||||
 | 
					           expression, we need to add the appropriate "NOT" syntax but can
 | 
				
			||||||
 | 
					           otherwise use the resolved RediSearch query for the expression as-is.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        3. The final resolution of an expression should be a left operand that's
 | 
				
			||||||
 | 
					           a ModelField, an operator, and a right operand that's NOT a ModelField.
 | 
				
			||||||
 | 
					           With an IN or NOT_IN operator, the right operand can be a sequence
 | 
				
			||||||
 | 
					           type, but otherwise, sequence types are converted to strings.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        TODO: When the operator is not IN or NOT_IN, detect a sequence type (other
 | 
				
			||||||
 | 
					         than strings, which are allowed) and raise an exception.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        field_type = None
 | 
					        field_type = None
 | 
				
			||||||
        field_name = None
 | 
					        field_name = None
 | 
				
			||||||
        field_info = None
 | 
					        field_info = None
 | 
				
			||||||
| 
						 | 
					@ -462,7 +505,7 @@ class FindQuery:
 | 
				
			||||||
        if expression.op is Operators.ALL:
 | 
					        if expression.op is Operators.ALL:
 | 
				
			||||||
            if encompassing_expression_is_negated:
 | 
					            if encompassing_expression_is_negated:
 | 
				
			||||||
                # TODO: Is there a use case for this, perhaps for dynamic
 | 
					                # TODO: Is there a use case for this, perhaps for dynamic
 | 
				
			||||||
                #  scoring purposes?
 | 
					                #  scoring purposes with full-text search?
 | 
				
			||||||
                raise QueryNotSupportedError("You cannot negate a query for all results.")
 | 
					                raise QueryNotSupportedError("You cannot negate a query for all results.")
 | 
				
			||||||
            return "*"
 | 
					            return "*"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -500,7 +543,13 @@ class FindQuery:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            result += f"({cls.resolve_redisearch_query(right)})"
 | 
					            result += f"({cls.resolve_redisearch_query(right)})"
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if isinstance(right, ModelField):
 | 
					            if not field_name:
 | 
				
			||||||
 | 
					                raise QuerySyntaxError("Could not resolve field name. See docs: TODO")
 | 
				
			||||||
 | 
					            elif not field_type:
 | 
				
			||||||
 | 
					                raise QuerySyntaxError("Could not resolve field type. See docs: TODO")
 | 
				
			||||||
 | 
					            elif not field_info:
 | 
				
			||||||
 | 
					                raise QuerySyntaxError("Could not resolve field info. See docs: TODO")
 | 
				
			||||||
 | 
					            elif isinstance(right, ModelField):
 | 
				
			||||||
                raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
 | 
					                raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                result += cls.resolve_value(field_name, field_type, field_info,
 | 
					                result += cls.resolve_value(field_name, field_type, field_info,
 | 
				
			||||||
| 
						 | 
					@ -540,11 +589,7 @@ class FindQuery:
 | 
				
			||||||
        while True:
 | 
					        while True:
 | 
				
			||||||
            # Make a query for each pass of the loop, with a new offset equal to the
 | 
					            # Make a query for each pass of the loop, with a new offset equal to the
 | 
				
			||||||
            # current offset plus `page_size`, until we stop getting results back.
 | 
					            # current offset plus `page_size`, until we stop getting results back.
 | 
				
			||||||
            query = FindQuery(expressions=query.expressions,
 | 
					            query = query.copy(offset=query.offset + query.page_size)
 | 
				
			||||||
                              model=query.model,
 | 
					 | 
				
			||||||
                              offset=query.offset + query.page_size,
 | 
					 | 
				
			||||||
                              page_size=query.page_size,
 | 
					 | 
				
			||||||
                              limit=query.limit)
 | 
					 | 
				
			||||||
            _results = query.execute(exhaust_results=False)
 | 
					            _results = query.execute(exhaust_results=False)
 | 
				
			||||||
            if not _results:
 | 
					            if not _results:
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
| 
						 | 
					@ -552,8 +597,7 @@ class FindQuery:
 | 
				
			||||||
        return self._model_cache
 | 
					        return self._model_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def first(self):
 | 
					    def first(self):
 | 
				
			||||||
        query = FindQuery(expressions=self.expressions, model=self.model,
 | 
					        query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
 | 
				
			||||||
                          offset=0, limit=1, sort_fields=self.sort_fields)
 | 
					 | 
				
			||||||
        results = query.execute()
 | 
					        results = query.execute()
 | 
				
			||||||
        if not results:
 | 
					        if not results:
 | 
				
			||||||
            raise NotFoundError()
 | 
					            raise NotFoundError()
 | 
				
			||||||
| 
						 | 
					@ -561,26 +605,14 @@ class FindQuery:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def all(self, batch_size=10):
 | 
					    def all(self, batch_size=10):
 | 
				
			||||||
        if batch_size != self.page_size:
 | 
					        if batch_size != self.page_size:
 | 
				
			||||||
            # TODO: There's probably a copy-with-change mechanism in Pydantic,
 | 
					            query = self.copy(page_size=batch_size, limit=batch_size)
 | 
				
			||||||
            #  or can we use one from dataclasses?
 | 
					 | 
				
			||||||
            query = FindQuery(expressions=self.expressions,
 | 
					 | 
				
			||||||
                              model=self.model,
 | 
					 | 
				
			||||||
                              offset=self.offset,
 | 
					 | 
				
			||||||
                              page_size=batch_size,
 | 
					 | 
				
			||||||
                              limit=batch_size,
 | 
					 | 
				
			||||||
                              sort_fields=self.sort_fields)
 | 
					 | 
				
			||||||
            return query.execute()
 | 
					            return query.execute()
 | 
				
			||||||
        return self.execute()
 | 
					        return self.execute()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def sort_by(self, *fields: str):
 | 
					    def sort_by(self, *fields: str):
 | 
				
			||||||
        if not fields:
 | 
					        if not fields:
 | 
				
			||||||
            return self
 | 
					            return self
 | 
				
			||||||
        return FindQuery(expressions=self.expressions,
 | 
					        return self.copy(sort_fields=list(fields))
 | 
				
			||||||
                         model=self.model,
 | 
					 | 
				
			||||||
                         offset=self.offset,
 | 
					 | 
				
			||||||
                         page_size=self.page_size,
 | 
					 | 
				
			||||||
                         limit=self.limit,
 | 
					 | 
				
			||||||
                         sort_fields=list(fields))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, **kwargs):
 | 
					    def update(self, **kwargs):
 | 
				
			||||||
        """Update all matching records in this query."""
 | 
					        """Update all matching records in this query."""
 | 
				
			||||||
| 
						 | 
					@ -621,11 +653,7 @@ class FindQuery:
 | 
				
			||||||
        if self._model_cache and len(self._model_cache) >= item:
 | 
					        if self._model_cache and len(self._model_cache) >= item:
 | 
				
			||||||
            return self._model_cache[item]
 | 
					            return self._model_cache[item]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        query = FindQuery(expressions=self.expressions,
 | 
					        query = self.copy(offset=item, limit=1)
 | 
				
			||||||
                          model=self.model,
 | 
					 | 
				
			||||||
                          offset=item,
 | 
					 | 
				
			||||||
                          sort_fields=self.sort_fields,
 | 
					 | 
				
			||||||
                          limit=1)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return query.execute()[0]
 | 
					        return query.execute()[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -845,7 +873,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
    pk: Optional[str] = Field(default=None, primary_key=True)
 | 
					    pk: Optional[str] = Field(default=None, primary_key=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Meta = DefaultMeta
 | 
					    Meta = DefaultMeta
 | 
				
			||||||
    # TODO: Missing _meta here is causing IDE warnings.
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Config:
 | 
					    class Config:
 | 
				
			||||||
        orm_mode = True
 | 
					        orm_mode = True
 | 
				
			||||||
| 
						 | 
					@ -899,7 +926,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
        return cls._meta.database
 | 
					        return cls._meta.database
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:  # TODO: How to type annotate this?
 | 
					    def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
 | 
				
			||||||
        return FindQuery(expressions=expressions, model=cls)
 | 
					        return FindQuery(expressions=expressions, model=cls)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -360,9 +360,15 @@ def test_numeric_queries(members):
 | 
				
			||||||
    actual = Member.find(Member.age >= 100).all()
 | 
					    actual = Member.find(Member.age >= 100).all()
 | 
				
			||||||
    assert actual == [member3]
 | 
					    assert actual == [member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find(Member.age != 34).all()
 | 
				
			||||||
 | 
					    assert actual == [member1, member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    actual = Member.find(~(Member.age == 100)).all()
 | 
					    actual = Member.find(~(Member.age == 100)).all()
 | 
				
			||||||
    assert actual == [member1, member2]
 | 
					    assert actual == [member1, member2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find(Member.age > 30, Member.age < 40).all()
 | 
				
			||||||
 | 
					    assert actual == [member1, member2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_sorting(members):
 | 
					def test_sorting(members):
 | 
				
			||||||
    member1, member2, member3 = members
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -12,7 +12,7 @@ from redis_developer.orm import (
 | 
				
			||||||
    JsonModel,
 | 
					    JsonModel,
 | 
				
			||||||
    Field,
 | 
					    Field,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from redis_developer.orm.model import RedisModelError, QueryNotSupportedError, NotFoundError, embedded
 | 
					from redis_developer.orm.model import QueryNotSupportedError, NotFoundError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
r = redis.Redis()
 | 
					r = redis.Redis()
 | 
				
			||||||
today = datetime.date.today()
 | 
					today = datetime.date.today()
 | 
				
			||||||
| 
						 | 
					@ -240,6 +240,24 @@ def test_access_result_by_index_not_cached(members):
 | 
				
			||||||
    assert query[2] == member3
 | 
					    assert query[2] == member3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_in_query(members):
 | 
				
			||||||
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
 | 
					    actual = Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).all()
 | 
				
			||||||
 | 
					    assert actual == [member1, member2, member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.skip("Not implemented yet")
 | 
				
			||||||
 | 
					def test_update_query(members):
 | 
				
			||||||
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
 | 
					    Member.find(Member.pk << [member1.pk, member2.pk, member3.pk]).update(
 | 
				
			||||||
 | 
					        first_name="Bobby"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    actual = Member.find(
 | 
				
			||||||
 | 
					        Member.pk << [member1.pk, member2.pk, member3.pk]).sort_by('age').all()
 | 
				
			||||||
 | 
					    assert actual == [member1, member2, member3]
 | 
				
			||||||
 | 
					    assert all([m.name == "Bobby" for m in actual])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_exact_match_queries(members):
 | 
					def test_exact_match_queries(members):
 | 
				
			||||||
    member1, member2, member3 = members
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -458,6 +476,12 @@ def test_numeric_queries(members):
 | 
				
			||||||
    actual = Member.find(~(Member.age == 100)).all()
 | 
					    actual = Member.find(~(Member.age == 100)).all()
 | 
				
			||||||
    assert actual == [member1, member2]
 | 
					    assert actual == [member1, member2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find(Member.age > 30, Member.age < 40).all()
 | 
				
			||||||
 | 
					    assert actual == [member1, member2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find(Member.age != 34).all()
 | 
				
			||||||
 | 
					    assert actual == [member1, member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_sorting(members):
 | 
					def test_sorting(members):
 | 
				
			||||||
    member1, member2, member3 = members
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue