WIP on pagination, sorting all()/first() methods
This commit is contained in:
		
							parent
							
								
									3ba45b7c87
								
							
						
					
					
						commit
						01cab5352b
					
				
					 4 changed files with 473 additions and 294 deletions
				
			
		| 
						 | 
				
			
			@ -2,8 +2,8 @@ import abc
 | 
			
		|||
import dataclasses
 | 
			
		||||
import decimal
 | 
			
		||||
import operator
 | 
			
		||||
from copy import copy, deepcopy
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
import re
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from functools import reduce
 | 
			
		||||
from typing import (
 | 
			
		||||
| 
						 | 
				
			
			@ -20,13 +20,15 @@ from typing import (
 | 
			
		|||
    Sequence,
 | 
			
		||||
    no_type_check,
 | 
			
		||||
    Protocol,
 | 
			
		||||
    List, Type
 | 
			
		||||
    List,
 | 
			
		||||
    Type,
 | 
			
		||||
    Pattern
 | 
			
		||||
)
 | 
			
		||||
import uuid
 | 
			
		||||
 | 
			
		||||
import redis
 | 
			
		||||
from pydantic import BaseModel, validator
 | 
			
		||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
 | 
			
		||||
from pydantic.fields import FieldInfo as PydanticFieldInfo, PrivateAttr, Field
 | 
			
		||||
from pydantic.fields import ModelField, Undefined, UndefinedType
 | 
			
		||||
from pydantic.main import ModelMetaclass
 | 
			
		||||
from pydantic.typing import NoArgAnyCallable
 | 
			
		||||
| 
						 | 
				
			
			@ -34,18 +36,47 @@ from pydantic.utils import Representation
 | 
			
		|||
 | 
			
		||||
from .encoders import jsonable_encoder
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
model_registry = {}
 | 
			
		||||
 | 
			
		||||
_T = TypeVar("_T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TokenEscaper:
    """
    Backslash-escape punctuation found in an input string.
    """

    # Characters that RediSearch requires us to escape during queries.
    # Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"

    def __init__(self, escape_chars_re: Optional[Pattern] = None):
        # Use the caller's pattern when one is given; otherwise compile the
        # default character class above.
        self.escaped_chars_re = escape_chars_re or re.compile(self.DEFAULT_ESCAPED_CHARS)

    def escape(self, string):
        """Return `string` with a backslash placed before every matched character."""
        return self.escaped_chars_re.sub(lambda found: "\\" + found.group(0), string)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
escaper = TokenEscaper()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RedisModelError(Exception):
    """Raised when a model or query is used incorrectly (e.g. a non-sequence passed to IN)."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NotFoundError(Exception):
    """
    A query found no results.

    TODO: embed in Model class?
    """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Operators(Enum):
 | 
			
		||||
| 
						 | 
				
			
			@ -61,9 +92,10 @@ class Operators(Enum):
 | 
			
		|||
    IN = 10
 | 
			
		||||
    NOT_IN = 11
 | 
			
		||||
    LIKE = 12
 | 
			
		||||
    ALL = 13
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class NegatedExpression:
 | 
			
		||||
    expression: 'Expression'
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -77,7 +109,7 @@ class NegatedExpression:
 | 
			
		|||
        return Expression(left=self, op=Operators.OR, right=other)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class Expression:
 | 
			
		||||
    op: Operators
 | 
			
		||||
    left: Any
 | 
			
		||||
| 
						 | 
				
			
			@ -96,186 +128,6 @@ class Expression:
 | 
			
		|||
ExpressionOrNegated = Union[Expression, NegatedExpression]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QueryNotSupportedError(Exception):
 | 
			
		||||
    """The attempted query is not supported."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RediSearchFieldTypes(Enum):
 | 
			
		||||
    TEXT = 'TEXT'
 | 
			
		||||
    TAG = 'TAG'
 | 
			
		||||
    NUMERIC = 'NUMERIC'
 | 
			
		||||
    GEO = 'GEO'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: How to handle Geo fields?
 | 
			
		||||
NUMERIC_TYPES = (float, int, decimal.Decimal)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class FindQuery:
 | 
			
		||||
    expressions: Sequence[Expression]
 | 
			
		||||
    expression: Expression = dataclasses.field(init=False)
 | 
			
		||||
    query: str = dataclasses.field(init=False)
 | 
			
		||||
    pagination: List[str] = dataclasses.field(init=False)
 | 
			
		||||
    model: Type['RedisModel']
 | 
			
		||||
    limit: Optional[int] = None
 | 
			
		||||
    offset: Optional[int] = None
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self.expression = reduce(operator.and_, self.expressions)
 | 
			
		||||
        self.query = self.resolve_redisearch_query(self.expression)
 | 
			
		||||
        self.pagination = self.resolve_redisearch_pagination()
 | 
			
		||||
 | 
			
		||||
    def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
 | 
			
		||||
        if getattr(field.field_info, 'primary_key', None) is True:
 | 
			
		||||
            return RediSearchFieldTypes.TAG
 | 
			
		||||
        elif getattr(field.field_info, 'full_text_search', None) is True:
 | 
			
		||||
            return RediSearchFieldTypes.TEXT
 | 
			
		||||
 | 
			
		||||
        field_type = field.outer_type_
 | 
			
		||||
 | 
			
		||||
        # TODO: GEO
 | 
			
		||||
        if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
 | 
			
		||||
            return RediSearchFieldTypes.NUMERIC
 | 
			
		||||
        else:
 | 
			
		||||
            # TAG fields are the default field type.
 | 
			
		||||
            return RediSearchFieldTypes.TAG
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def expand_tag_value(value):
 | 
			
		||||
        err = RedisModelError(f"Using the IN operator requires passing an iterable of "
 | 
			
		||||
                              "possible values. You passed: {value}")
 | 
			
		||||
        if isinstance(str, value):
 | 
			
		||||
            raise err
 | 
			
		||||
        try:
 | 
			
		||||
            expanded_value = "|".join(value)
 | 
			
		||||
        except TypeError:
 | 
			
		||||
            raise err
 | 
			
		||||
        return expanded_value
 | 
			
		||||
 | 
			
		||||
    def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
 | 
			
		||||
                      op: Operators, value: Any) -> str:
 | 
			
		||||
        result = ""
 | 
			
		||||
        if field_type is RediSearchFieldTypes.TEXT:
 | 
			
		||||
            result = f"@{field_name}:"
 | 
			
		||||
            if op is Operators.EQ:
 | 
			
		||||
                result += f'"{value}"'
 | 
			
		||||
            elif op is Operators.NE:
 | 
			
		||||
                result = f'-({result}"{value}")'
 | 
			
		||||
            elif op is Operators.LIKE:
 | 
			
		||||
                result += value
 | 
			
		||||
            else:
 | 
			
		||||
                # TODO: Handling TAG, TEXT switch-offs, etc.
 | 
			
		||||
                raise QueryNotSupportedError("Only equals (=) and not-equals (!=) comparisons are "
 | 
			
		||||
                                             "currently supported for TEXT fields. See docs: TODO")
 | 
			
		||||
        elif field_type is RediSearchFieldTypes.NUMERIC:
 | 
			
		||||
            if op is Operators.EQ:
 | 
			
		||||
                result += f"@{field_name}:[{value} {value}]"
 | 
			
		||||
            elif op is Operators.NE:
 | 
			
		||||
                # TODO: Is this enough or do we also need a clause for all values
 | 
			
		||||
                #  ([-inf +inf]) from which we then subtract the undesirable value?
 | 
			
		||||
                result += f"-(@{field_name}:[{value} {value}])"
 | 
			
		||||
            elif op is Operators.GT:
 | 
			
		||||
                result += f"@{field_name}:[({value} +inf]"
 | 
			
		||||
            elif op is Operators.LT:
 | 
			
		||||
                result += f"@{field_name}:[-inf ({value}]"
 | 
			
		||||
            elif op is Operators.GE:
 | 
			
		||||
                result += f"@{field_name}:[{value} +inf]"
 | 
			
		||||
            elif op is Operators.LE:
 | 
			
		||||
                result += f"@{field_name}:[-inf {value}]"
 | 
			
		||||
        elif field_type is RediSearchFieldTypes.TAG:
 | 
			
		||||
            if op is Operators.EQ:
 | 
			
		||||
                result += f"@{field_name}:{{{value}}}"
 | 
			
		||||
            elif op is Operators.NE:
 | 
			
		||||
                result += f"-(@{field_name}:{{{value}}})"
 | 
			
		||||
            elif op is Operators.IN:
 | 
			
		||||
                expanded_value = self.expand_tag_value(value)
 | 
			
		||||
                result += f"(@{field_name}:{{{expanded_value}}})"
 | 
			
		||||
            elif op is Operators.NOT_IN:
 | 
			
		||||
                expanded_value = self.expand_tag_value(value)
 | 
			
		||||
                result += f"-(@{field_name}:{{{expanded_value}}})"
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def resolve_redisearch_pagination(self):
 | 
			
		||||
        """Resolve pagination options for a query."""
 | 
			
		||||
        if not self.limit and not self.offset:
 | 
			
		||||
            return []
 | 
			
		||||
        offset = self.offset or 0
 | 
			
		||||
        limit = self.limit or 10
 | 
			
		||||
        return ["LIMIT", offset, limit]
 | 
			
		||||
 | 
			
		||||
    def resolve_redisearch_query(self, expression: ExpressionOrNegated):
 | 
			
		||||
        """Resolve an expression to a string RediSearch query."""
 | 
			
		||||
        field_type = None
 | 
			
		||||
        field_name = None
 | 
			
		||||
        encompassing_expression_is_negated = False
 | 
			
		||||
        result = ""
 | 
			
		||||
 | 
			
		||||
        if isinstance(expression, NegatedExpression):
 | 
			
		||||
            encompassing_expression_is_negated = True
 | 
			
		||||
            expression = expression.expression
 | 
			
		||||
 | 
			
		||||
        if isinstance(expression.left, Expression) or \
 | 
			
		||||
                isinstance(expression.left, NegatedExpression):
 | 
			
		||||
            result += f"({self.resolve_redisearch_query(expression.left)})"
 | 
			
		||||
        elif isinstance(expression.left, ModelField):
 | 
			
		||||
            field_type = self.resolve_field_type(expression.left)
 | 
			
		||||
            field_name = expression.left.name
 | 
			
		||||
        else:
 | 
			
		||||
            import ipdb; ipdb.set_trace()
 | 
			
		||||
            raise QueryNotSupportedError(f"A query expression should start with either a field "
 | 
			
		||||
                                         f"or an expression enclosed in parenthesis. See docs: "
 | 
			
		||||
                                         f"TODO")
 | 
			
		||||
 | 
			
		||||
        right = expression.right
 | 
			
		||||
        right_is_negated = isinstance(right, NegatedExpression)
 | 
			
		||||
 | 
			
		||||
        if isinstance(right, Expression) or right_is_negated:
 | 
			
		||||
            if expression.op == Operators.AND:
 | 
			
		||||
                result += " "
 | 
			
		||||
            elif expression.op == Operators.OR:
 | 
			
		||||
                result += "| "
 | 
			
		||||
            else:
 | 
			
		||||
                raise QueryNotSupportedError("You can only combine two query expressions with"
 | 
			
		||||
                                             "AND (&) or OR (|). See docs: TODO")
 | 
			
		||||
 | 
			
		||||
            if right_is_negated:
 | 
			
		||||
                result += "-"
 | 
			
		||||
                # We're handling the RediSearch operator in this call ("-"), so resolve the
 | 
			
		||||
                # inner expression instead of the NegatedExpression.
 | 
			
		||||
                right = right.expression
 | 
			
		||||
 | 
			
		||||
            result += f"({self.resolve_redisearch_query(right)})"
 | 
			
		||||
        else:
 | 
			
		||||
            if isinstance(right, ModelField):
 | 
			
		||||
                raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
 | 
			
		||||
            else:
 | 
			
		||||
                result += self.resolve_value(field_name, field_type, expression.op, right)
 | 
			
		||||
 | 
			
		||||
        if encompassing_expression_is_negated:
 | 
			
		||||
            result = f"-({result})"
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def find(self):
 | 
			
		||||
        args = ["ft.search", self.model.Meta.index_name, self.query]
 | 
			
		||||
        # TODO: Do we need self.pagination if we're just appending to query anyway?
 | 
			
		||||
        if self.pagination:
 | 
			
		||||
            args.extend(self.pagination)
 | 
			
		||||
        return self.model.db().execute_command(*args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PrimaryKeyCreator(Protocol):
 | 
			
		||||
    def create_pk(self, *args, **kwargs) -> str:
 | 
			
		||||
        """Create a new primary key"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Uuid4PrimaryKey:
 | 
			
		||||
    def create_pk(self) -> str:
 | 
			
		||||
        return str(uuid.uuid4())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ExpressionProxy:
 | 
			
		||||
    def __init__(self, field: ModelField):
 | 
			
		||||
        self.field = field
 | 
			
		||||
| 
						 | 
				
			
			@ -299,6 +151,352 @@ class ExpressionProxy:
 | 
			
		|||
        return Expression(left=self.field, op=Operators.GE, right=other)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QueryNotSupportedError(Exception):
 | 
			
		||||
    """The attempted query is not supported."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RediSearchFieldTypes(Enum):
 | 
			
		||||
    TEXT = 'TEXT'
 | 
			
		||||
    TAG = 'TAG'
 | 
			
		||||
    NUMERIC = 'NUMERIC'
 | 
			
		||||
    GEO = 'GEO'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: How to handle Geo fields?
 | 
			
		||||
NUMERIC_TYPES = (float, int, decimal.Decimal)
 | 
			
		||||
DEFAULT_PAGE_SIZE = 10
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FindQuery(BaseModel):
 | 
			
		||||
    expressions: Sequence[ExpressionOrNegated]
 | 
			
		||||
    model: Type['RedisModel']
 | 
			
		||||
    offset: int = 0
 | 
			
		||||
    limit: int = DEFAULT_PAGE_SIZE
 | 
			
		||||
    page_size: int = DEFAULT_PAGE_SIZE
 | 
			
		||||
    sort_fields: Optional[List[str]] = Field(default_factory=list)
 | 
			
		||||
 | 
			
		||||
    _expression: Expression = PrivateAttr(default=None)
 | 
			
		||||
    _query: str = PrivateAttr(default=None)
 | 
			
		||||
    _pagination: List[str] = PrivateAttr(default_factory=list)
 | 
			
		||||
    _model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list)
 | 
			
		||||
 | 
			
		||||
    class Config:
 | 
			
		||||
        arbitrary_types_allowed = True
 | 
			
		||||
 | 
			
		||||
    @property
    def pagination(self):
        """Return the LIMIT arguments for this query, resolving and caching them on first use."""
        if not self._pagination:
            self._pagination = self.resolve_redisearch_pagination()
        return self._pagination
 | 
			
		||||
 | 
			
		||||
    @property
    def expression(self):
        """
        Return the single expression this query resolves, building it on first use.

        All entries of `expressions` are AND-ed together; when there are none,
        an ALL expression (match everything) is used instead.
        """
        if not self._expression:
            self._expression = (
                reduce(operator.and_, self.expressions)
                if self.expressions
                else Expression(left=None, right=None, op=Operators.ALL)
            )
        return self._expression
 | 
			
		||||
 | 
			
		||||
    @property
    def query(self):
        """Return the RediSearch query string for this query's combined expression."""
        combined = self.expression
        return self.resolve_redisearch_query(combined)
 | 
			
		||||
 | 
			
		||||
    @validator("sort_fields")
    def validate_sort_fields(cls, v, values):
        """
        Check that every requested sort field exists on the model and is sortable.

        `v` is the list of sort field names, each optionally prefixed with "-".
        `values` holds the fields validated so far; the target model is read
        from it. Returns `v` unchanged when every field passes.

        Raises QueryNotSupportedError if a field does not exist on the model
        or is not declared sortable.
        """
        # pydantic leaves fields that failed validation out of `values`, so a
        # bad `model` would otherwise surface here as an unrelated KeyError.
        model = values.get('model')
        if model is None:
            return v
        for sort_field in v:
            # A leading "-" only marks direction; it is not part of the name.
            field_name = sort_field.lstrip("-")
            if field_name not in model.__fields__:
                raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
                                             f"does not exist on the model {model}")
            field_proxy = getattr(model, field_name)
            if not getattr(field_proxy.field.field_info, 'sortable', False):
                # Report the model here; `cls` is FindQuery, not the model.
                raise QueryNotSupportedError(f"You tried sort by {field_name}, but {model} does "
                                             "not define that field as sortable. See docs: XXX")
        return v
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
        """
        Map a model field to the RediSearch field type used to query it.

        Primary keys are TAG fields, fields flagged `full_text_search` are
        TEXT, fields whose type is a subclass of one of NUMERIC_TYPES are
        NUMERIC, and everything else falls back to TAG.
        """
        if getattr(field.field_info, 'primary_key', None) is True:
            return RediSearchFieldTypes.TAG
        elif getattr(field.field_info, 'full_text_search', None) is True:
            return RediSearchFieldTypes.TEXT

        field_type = field.outer_type_

        # TODO: GEO
        # issubclass() raises TypeError when given something that is not a
        # class (e.g. a typing generic such as List[str]), so check first and
        # let such types fall through to the TAG default.
        if isinstance(field_type, type) and issubclass(field_type, NUMERIC_TYPES):
            return RediSearchFieldTypes.NUMERIC
        else:
            # TAG fields are the default field type.
            # TODO: A ListField or ArrayField that supports multiple values
            #  and contains logic.
            return RediSearchFieldTypes.TAG
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def expand_tag_value(value):
        """
        Join a sequence of tag values into a "a|b|c" alternation string.

        Each element is passed through the module-level `escaper` before
        joining.

        Raises RedisModelError if `value` is a single string or joining it
        fails with a TypeError (e.g. it is not iterable).
        """
        err = RedisModelError("Using the IN operator requires passing a sequence of "
                              f"possible values. You passed: {value}")
        # A bare string is iterable, but joining its characters is never what
        # the caller meant, so reject it explicitly.
        if isinstance(value, str):
            raise err
        try:
            expanded_value = "|".join([escaper.escape(v) for v in value])
        except TypeError:
            raise err
        return expanded_value
 | 
			
		||||
 | 
			
		||||
    @classmethod
    def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
                      op: Operators, value: Any) -> str:
        """
        Render one `field <op> value` comparison as a RediSearch query fragment.

        TEXT fields support EQ, NE and LIKE and raise QueryNotSupportedError
        for any other operator. NUMERIC and TAG fields return an empty string
        for operators they do not handle, as does an unrecognised field type.
        """
        if field_type is RediSearchFieldTypes.TEXT:
            prefix = f"@{field_name}:"
            if op is Operators.EQ:
                return f'{prefix}"{value}"'
            if op is Operators.NE:
                return f'-({prefix}"{value}")'
            if op is Operators.LIKE:
                return prefix + value
            raise QueryNotSupportedError("Only equals (=), not-equals (!=), and like() "
                                         "comparisons are supported for TEXT fields. See "
                                         "docs: TODO.")

        if field_type is RediSearchFieldTypes.NUMERIC:
            if op is Operators.EQ:
                return f"@{field_name}:[{value} {value}]"
            if op is Operators.NE:
                # TODO: Is this enough or do we also need a clause for all values
                #  ([-inf +inf]) from which we then subtract the undesirable value?
                return f"-(@{field_name}:[{value} {value}])"
            if op is Operators.GT:
                return f"@{field_name}:[({value} +inf]"
            if op is Operators.LT:
                return f"@{field_name}:[-inf ({value}]"
            if op is Operators.GE:
                return f"@{field_name}:[{value} +inf]"
            if op is Operators.LE:
                return f"@{field_name}:[-inf {value}]"
            return ""

        if field_type is RediSearchFieldTypes.TAG:
            if op is Operators.EQ:
                escaped = escaper.escape(value)
                return f"@{field_name}:{{{escaped}}}"
            if op is Operators.NE:
                escaped = escaper.escape(value)
                return f"-(@{field_name}:{{{escaped}}})"
            if op is Operators.IN:
                alternatives = cls.expand_tag_value(value)
                return f"(@{field_name}:{{{alternatives}}})"
            if op is Operators.NOT_IN:
                alternatives = cls.expand_tag_value(value)
                return f"-(@{field_name}:{{{alternatives}}})"
            return ""

        return ""
 | 
			
		||||
 | 
			
		||||
    def resolve_redisearch_pagination(self):
        """Return the LIMIT arguments (keyword, offset, limit) for this query."""
        start, size = self.offset, self.limit
        return ["LIMIT", start, size]
 | 
			
		||||
 | 
			
		||||
    def resolve_redisearch_sort_fields(self):
        """
        Return the SORTBY arguments for this query, or None when nothing is sorted.

        A leading "-" on a field name selects descending order; the prefix is
        stripped from the name that is emitted.
        """
        if not self.sort_fields:
            return None
        # NOTE(review): FT.SEARCH documents SORTBY with a single attribute --
        # confirm that several name/direction pairs are accepted.
        sort_args = ["SORTBY"]
        for name in self.sort_fields:
            order = "desc" if name.startswith('-') else 'asc'
            sort_args += [name.lstrip('-'), order]
        return sort_args
 | 
			
		||||
 | 
			
		||||
    @classmethod
    def resolve_redisearch_query(cls, expression: ExpressionOrNegated):
        """
        Resolve an expression to a string RediSearch query.

        The expression tree is walked recursively: nested expressions are
        wrapped in parentheses, AND is rendered as a space, OR as "| ", and a
        NegatedExpression as a "-" prefix. A leaf comparison (field on the
        left, plain value on the right) is rendered by `resolve_value`. An
        ALL expression resolves to "*".

        Raises QueryNotSupportedError when an ALL expression is negated, when
        the left side is neither a field nor an expression, when two
        expressions are combined with anything but AND/OR, or when the right
        side is a field.
        """
        field_type = None
        field_name = None
        encompassing_expression_is_negated = False
        result = ""

        if isinstance(expression, NegatedExpression):
            encompassing_expression_is_negated = True
            expression = expression.expression

        if expression.op is Operators.ALL:
            if encompassing_expression_is_negated:
                # TODO: Is there a use case for this, perhaps for dynamic
                # scoring purposes?
                raise QueryNotSupportedError("You cannot negate a query for all results.")
            return "*"

        if isinstance(expression.left, (Expression, NegatedExpression)):
            result += f"({cls.resolve_redisearch_query(expression.left)})"
        elif isinstance(expression.left, ModelField):
            field_type = cls.resolve_field_type(expression.left)
            field_name = expression.left.name
        else:
            raise QueryNotSupportedError("A query expression should start with either a field "
                                         "or an expression enclosed in parenthesis. See docs: "
                                         "TODO")

        right = expression.right

        if isinstance(right, (Expression, NegatedExpression)):
            if expression.op == Operators.AND:
                result += " "
            elif expression.op == Operators.OR:
                result += "| "
            else:
                raise QueryNotSupportedError("You can only combine two query expressions with "
                                             "AND (&) or OR (|). See docs: TODO")

            if isinstance(right, NegatedExpression):
                result += "-"
                # We're handling the RediSearch operator in this call ("-"), so resolve the
                # inner expression instead of the NegatedExpression.
                right = right.expression

            result += f"({cls.resolve_redisearch_query(right)})"
        else:
            if isinstance(right, ModelField):
                raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
            else:
                # TODO: Optionals causing IDE errors here
                result += cls.resolve_value(field_name, field_type, expression.op, right)

        if encompassing_expression_is_negated:
            result = f"-({result})"

        return result
 | 
			
		||||
 | 
			
		||||
    def execute(self, exhaust_results=True):
        """
        Run the ft.search command for this query and return the matching models.

        Results are accumulated in `_model_cache`, which is cleared first when
        the query starts at offset 0. With `exhaust_results` false, only this
        one page is fetched. Otherwise, if the reported total exceeds the
        number of results in the first page, follow-up queries are issued
        (advancing the offset by `page_size` each time) until a page comes
        back empty, and everything collected is returned.
        """
        args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
        if self.sort_fields:
            args += self.resolve_redisearch_sort_fields()

        # Reset the cache if we're executing from offset 0.
        if self.offset == 0:
            self._model_cache.clear()

        # If the offset is greater than 0, we're paginating through a result set,
        # so append the new results to results already in the cache.
        raw_result = self.model.db().execute_command(*args)
        # The first element of the reply is used as the total match count.
        count = raw_result[0]
        results = self.model.from_redis(raw_result)
        self._model_cache += results

        if not exhaust_results:
            return self._model_cache

        # The query returned all results, so we have no more work to do.
        if count <= len(results):
            return self._model_cache

        # Transparently (to the user) make subsequent requests to paginate
        # through the results and finally return them all.
        # NOTE(review): the offset advances by `page_size` while each request
        # asks for `limit` records -- confirm callers keep the two equal.
        query = self
        while True:
            # Make a query for each pass of the loop, with a new offset equal to the
            # current offset plus `page_size`, until we stop getting results back.
            query = FindQuery(expressions=query.expressions,
                              model=query.model,
                              offset=query.offset + query.page_size,
                              page_size=query.page_size,
                              limit=query.limit,
                              # Carry the sort order over so every page is
                              # cut from the same ordering as the first one.
                              sort_fields=query.sort_fields)
            _results = query.execute(exhaust_results=False)
            if not _results:
                break
            self._model_cache += _results
        return self._model_cache
 | 
			
		||||
 | 
			
		||||
    def first(self):
        """
        Return the first record matching this query, honouring its sort fields.

        Raises NotFoundError when the query matches nothing.
        """
        query = FindQuery(expressions=self.expressions, model=self.model,
                          offset=0, limit=1, sort_fields=self.sort_fields)
        # Only one record is wanted, so fetch a single page instead of letting
        # execute() walk the whole result set.
        results = query.execute(exhaust_results=False)
        if not results:
            raise NotFoundError()
        return results[0]
 | 
			
		||||
 | 
			
		||||
    def all(self, batch_size=10):
        """
        Execute the query and return its results.

        When `batch_size` differs from this query's `page_size`, a copy of the
        query using `batch_size` as both page size and limit is executed
        instead of this one.
        """
        if batch_size == self.page_size:
            return self.execute()

        # TODO: There's probably a copy-with-change mechanism in Pydantic,
        #  or can we use one from dataclasses?
        resized = FindQuery(expressions=self.expressions,
                            model=self.model,
                            offset=self.offset,
                            page_size=batch_size,
                            limit=batch_size,
                            sort_fields=self.sort_fields)
        return resized.execute()
 | 
			
		||||
 | 
			
		||||
    def sort_by(self, *fields):
        """Return a copy of this query sorted by `fields`; with no fields, return this query itself."""
        if fields:
            return FindQuery(sort_fields=list(fields),
                             expressions=self.expressions,
                             model=self.model,
                             offset=self.offset,
                             page_size=self.page_size,
                             limit=self.limit)
        return self
 | 
			
		||||
 | 
			
		||||
    def update(self, **kwargs):
 | 
			
		||||
        """Update all matching records in this query."""
 | 
			
		||||
        # TODO
 | 
			
		||||
 | 
			
		||||
    def delete(self, **field_values):
        """
        Delete all matching records in this query.

        NOTE(review): no deletion happens yet. The body only checks that each
        keyword names an attribute of the model and then returns this query.
        The error text mentions "update", so this check may have been meant
        for `update()` -- confirm.

        Raises RedisModelError if a keyword does not name a model attribute.
        """
        # Iterate the keyword names; iterating the dict as pairs would try to
        # unpack each key string.
        for field_name in field_values:
            if not hasattr(self.model, field_name):
                raise RedisModelError(f"Can't update field {field_name} because "
                                      f"the field does not exist on the model {self.model}")

        return self
 | 
			
		||||
 | 
			
		||||
    def __iter__(self):
        """Yield cached results when there are any; otherwise execute the query and yield those."""
        yield from (self._model_cache or self.execute())
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, item: int):
        """
        Given this code:
            Model.find()[1000]

        We should return only the 1000th result.

            1. If the result is loaded in the query cache for this query,
               we can return it directly from the cache.

            2. If the query cache does not have enough elements to return
               that result, then we should clone the current query and
               give it a new offset and limit: offset=n, limit=1.

        Raises IndexError when the cloned query returns no record.
        """
        # Index `item` is only cached when the cache holds more than `item`
        # entries; `>=` would index one past the end.
        if self._model_cache and len(self._model_cache) > item:
            return self._model_cache[item]

        query = FindQuery(expressions=self.expressions,
                          model=self.model,
                          offset=item,
                          sort_fields=self.sort_fields,
                          limit=1)

        # A single record is wanted, so fetch one page only.
        return query.execute(exhaust_results=False)[0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PrimaryKeyCreator(Protocol):
    """Structural type for objects that can produce primary keys."""

    def create_pk(self, *args, **kwargs) -> str:
        """Create a new primary key"""
        ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Uuid4PrimaryKey:
    """Primary key creator that produces random (version 4) UUID strings."""

    def create_pk(self, *args, **kwargs) -> str:
        """Return a new UUID4 as a string; all arguments are ignored."""
        new_id = uuid.uuid4()
        return str(new_id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __dataclass_transform__(
 | 
			
		||||
    *,
 | 
			
		||||
    eq_default: bool = True,
 | 
			
		||||
| 
						 | 
				
			
			@ -395,21 +593,22 @@ def Field(
 | 
			
		|||
    return field_info
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Decorated exactly once, via the ``dataclasses`` module: the bare
# ``dataclass`` name is no longer imported by this file.
@dataclasses.dataclass
class PrimaryKey:
    """Pairs a primary key's name with its pydantic ``ModelField``."""

    # Name of the primary-key field.
    name: str
    # The field object for the primary key.
    field: ModelField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DefaultMeta:
    """
    Default ``Meta`` settings for a model.

    Every setting defaults to ``None`` except ``abstract``, which
    defaults to ``False``. Each attribute is declared exactly once, and
    every attribute whose default is ``None`` is annotated ``Optional``.
    """
    # TODO: Should this really be optional here?
    global_key_prefix: Optional[str] = None
    model_key_prefix: Optional[str] = None
    primary_key_pattern: Optional[str] = None
    database: Optional[redis.Redis] = None
    primary_key: Optional[PrimaryKey] = None
    primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
    index_name: Optional[str] = None
    abstract: Optional[bool] = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelMeta(ModelMetaclass):
 | 
			
		||||
| 
						 | 
				
			
			@ -473,6 +672,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
    pk: Optional[str] = Field(default=None, primary_key=True)
 | 
			
		||||
 | 
			
		||||
    Meta = DefaultMeta
 | 
			
		||||
    # TODO: Missing _meta here is causing IDE warnings.
 | 
			
		||||
 | 
			
		||||
    class Config:
 | 
			
		||||
        orm_mode = True
 | 
			
		||||
| 
						 | 
				
			
			@ -484,6 +684,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
        __pydantic_self__.validate_primary_key()
 | 
			
		||||
 | 
			
		||||
    def __lt__(self, other):
 | 
			
		||||
        """Default sort: compare all shared model fields."""
 | 
			
		||||
        my_keys = set(self.__fields__.keys())
 | 
			
		||||
        other_keys = set(other.__fields__.keys())
 | 
			
		||||
        shared_keys = list(my_keys & other_keys)
 | 
			
		||||
| 
						 | 
				
			
			@ -528,8 +729,13 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
    def db(cls):
 | 
			
		||||
        return cls._meta.database
 | 
			
		||||
 | 
			
		||||
    @classmethod
    def find(cls, *expressions: Expression):
        """
        Construct and return a ``FindQuery`` for this model class built
        from the given expressions.
        """
        query = FindQuery(expressions=expressions, model=cls)
        return query
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_redis(cls, res: Any):
 | 
			
		||||
        # TODO: Parsing logic borrowed from redisearch-py. Evaluate.
 | 
			
		||||
        import six
 | 
			
		||||
        from six.moves import xrange, zip as izip
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -537,20 +743,20 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
            if isinstance(s, six.string_types):
 | 
			
		||||
                return s
 | 
			
		||||
            elif isinstance(s, six.binary_type):
 | 
			
		||||
                return s.decode('utf-8','ignore')
 | 
			
		||||
                return s.decode('utf-8', 'ignore')
 | 
			
		||||
            else:
 | 
			
		||||
                return s  # Not a string we care about
 | 
			
		||||
 | 
			
		||||
        docs = []
 | 
			
		||||
        step = 2  # Because the result has content
 | 
			
		||||
        offset = 1
 | 
			
		||||
        offset = 1  # The first item is the count of total matches.
 | 
			
		||||
 | 
			
		||||
        for i in xrange(1, len(res), step):
 | 
			
		||||
            fields_offset = offset
 | 
			
		||||
 | 
			
		||||
            fields = dict(
 | 
			
		||||
                dict(izip(map(to_string, res[i + fields_offset][::2]),
 | 
			
		||||
                            map(to_string, res[i + fields_offset][1::2])))
 | 
			
		||||
                          map(to_string, res[i + fields_offset][1::2])))
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
| 
						 | 
				
			
			@ -562,17 +768,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
            docs.append(doc)
 | 
			
		||||
        return docs
 | 
			
		||||
 | 
			
		||||
    @classmethod
    def find(cls, *expressions: Expression):
        """
        Run a ``FindQuery`` for the given expressions and return the raw
        result converted by ``cls.from_redis``.
        """
        return cls.from_redis(
            FindQuery(expressions=expressions, model=cls).find()
        )
 | 
			
		||||
 | 
			
		||||
    @classmethod
    def find_one(cls, *expressions: Expression):
        """
        Run a ``FindQuery`` limited to one result (limit=1, offset=0) and
        return the first item that ``cls.from_redis`` produces.
        """
        single = FindQuery(expressions=expressions, model=cls, limit=1, offset=0)
        docs = cls.from_redis(single.find())
        return docs[0]
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def add(cls, models: Sequence['RedisModel']) -> Sequence['RedisModel']:
 | 
			
		||||
| 
						 | 
				
			
			@ -580,6 +775,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
 | 
			
		||||
    @classmethod
    def update(cls, **field_values):
        """
        Update this model instance.

        Currently a stub: ``field_values`` is not used and the class
        itself is returned unchanged.
        """
        return cls
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,55 +0,0 @@
 | 
			
		|||
from redis_developer.orm.model import Expression
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QueryIterator:
    """
    A lazy iterator that yields results from a RediSearch query.

    The query-building and iteration methods below are placeholders
    that currently do nothing and return None.

    Examples:

        results = Model.filter(email == "a@example.com")

        # Consume all results.
        for r in results:
            print(r)

        # Consume an item at an index.
        print(results[100])

        # Consume a slice.
        print(results[0:100])

        # Alternative notation to consume all items.
        print(results[0:-1])

        # Specify the batch size:
        results = Model.filter(email == "a@example.com", batch_size=1000)
        ...
    """
    def __init__(self, client, query, batch_size=100):
        """Store the client, the query and the batch size (default 100)."""
        self.client, self.query, self.batch_size = client, query, batch_size

    def __iter__(self):
        """Placeholder; returns None."""
        return None

    def __getattr__(self, item):
        """Support getting a single value or a slice."""
        # NOTE(review): the docstring describes indexing, so __getitem__
        # may have been intended. As written, any attribute lookup that
        # is not otherwise resolved yields None.
        return None

    # TODO: Query mixin?

    def filter(self, *expressions: Expression):
        """Placeholder; returns None."""
        return None

    def exclude(self, *expressions: Expression):
        """Placeholder; returns None."""
        return None

    def and_(self, *expressions: Expression):
        """Placeholder; returns None."""
        return None

    def or_(self, *expressions: Expression):
        """Placeholder; returns None."""
        return None

    def not_(self, *expressions: Expression):
        """Placeholder; returns None."""
        return None
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue