Handle TAG queries that include the separator
This commit is contained in:
parent b46408ccd2
commit 8f32b359f0

7 changed files with 591 additions and 128 deletions
@@ -1,9 +1,9 @@
import abc
import dataclasses
import decimal
import json
import logging
import operator
import re
from copy import deepcopy
from enum import Enum
from functools import reduce
@@ -22,10 +22,9 @@ from typing import (
    no_type_check,
    Protocol,
    List,
    Type,
    Pattern, get_origin, get_args
    get_origin,
    get_args, Type
)
import uuid

import redis
from pydantic import BaseModel, validator
@@ -34,46 +33,43 @@ from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass
from pydantic.typing import NoArgAnyCallable, resolve_annotations
from pydantic.utils import Representation
from ulid import ULID

from .encoders import jsonable_encoder
from .render_tree import render_tree
from .token_escaper import TokenEscaper

model_registry = {}
_T = TypeVar("_T")
log = logging.getLogger(__name__)


class TokenEscaper:
    """
    Escape punctuation within an input string.
    """

    # Characters that RediSearch requires us to escape during queries.
    # Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"

    def __init__(self, escape_chars_re: Optional[Pattern] = None):
        if escape_chars_re:
            self.escaped_chars_re = escape_chars_re
        else:
            self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)

    def escape(self, string):
        def escape_symbol(match):
            value = match.group(0)
            return f"\\{value}"

        return self.escaped_chars_re.sub(escape_symbol, string)


escaper = TokenEscaper()

# For basic exact-match field types like an indexed string, we create a TAG
# field in the RediSearch index. TAG is designed for multi-value fields
# separated by a "separator" character. We're using the field for single values
# (multi-value TAGs will be exposed as a separate field type), and we use the
# pipe character (|) as the separator. There is no way to escape this character
# in hash fields or JSON objects, so if someone indexes a value that includes
# the pipe, we'll warn but allow, and then warn again if they try to query for
# values that contain this separator.
SINGLE_VALUE_TAG_FIELD_SEPARATOR = "|"

# This is the default field separator in RediSearch. We need it to determine if
# someone has accidentally passed in the field separator with string value of a
# multi-value field lookup, like a IN or NOT_IN.
DEFAULT_REDISEARCH_FIELD_SEPARATOR = ","


class RedisModelError(Exception):
    pass
    """Raised when a problem exists in the definition of a RedisModel."""


class QuerySyntaxError(Exception):
    """Raised when a query is constructed improperly."""


class NotFoundError(Exception):
    """A query found no results."""
    """Raised when a query found no results."""


class Operators(Enum):
@@ -91,9 +87,45 @@ class Operators(Enum):
    LIKE = 12
    ALL = 13

    def __str__(self):
        return str(self.name)


ExpressionOrModelField = Union['Expression', 'NegatedExpression', ModelField]


class ExpressionProtocol(Protocol):
    op: Operators
    left: ExpressionOrModelField
    right: ExpressionOrModelField

    def __invert__(self) -> 'Expression':
        pass

    def __and__(self, other: ExpressionOrModelField):
        pass

    def __or__(self, other: ExpressionOrModelField):
        pass

    @property
    def name(self) -> str:
        raise NotImplementedError

    @property
    def tree(self) -> str:
        raise NotImplementedError


@dataclasses.dataclass
class NegatedExpression:
    """A negated Expression object.

    For now, this is a separate dataclass from Expression that acts as a facade
    around an Expression, indicating to model code (specifically, code
    responsible for querying) to negate the logic in the wrapped Expression. A
    better design is probably possible, maybe at least an ExpressionProtocol?
    """
    expression: 'Expression'

    def __invert__(self):
@@ -105,22 +137,53 @@ class NegatedExpression:
    def __or__(self, other):
        return Expression(left=self, op=Operators.OR, right=other)

    @property
    def left(self):
        return self.expression.left

    @property
    def right(self):
        return self.expression.right

    @property
    def op(self):
        return self.expression.op

    @property
    def name(self):
        if self.expression.op is Operators.EQ:
            return f"NOT {self.expression.name}"
        else:
            return f"{self.expression.name} NOT"

    @property
    def tree(self):
        return render_tree(self)


@dataclasses.dataclass
class Expression:
    op: Operators
    left: Any
    right: Any
    left: ExpressionOrModelField
    right: ExpressionOrModelField

    def __invert__(self):
        return NegatedExpression(self)

    def __and__(self, other):
    def __and__(self, other: ExpressionOrModelField):
        return Expression(left=self, op=Operators.AND, right=other)

    def __or__(self, other):
    def __or__(self, other: ExpressionOrModelField):
        return Expression(left=self, op=Operators.OR, right=other)

    @property
    def name(self):
        return str(self.op)

    @property
    def tree(self):
        return render_tree(self)


ExpressionOrNegated = Union[Expression, NegatedExpression]
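
# Illustrative sketch (not part of this commit): how the Expression dataclasses
# above compose. Assumes Expression, NegatedExpression, and Operators from this
# module are in scope; the field names and values are made-up examples.
expr = Expression(op=Operators.AND,
                  left=Expression(op=Operators.GT, left="age", right=30),
                  right=Expression(op=Operators.EQ, left="city", right="Portland"))
negated = ~expr            # wraps expr in a NegatedExpression
combined = expr | negated  # a new Expression with op=Operators.OR at the root
print(combined.name)       # "OR"
print(combined.tree)       # renders the nested expressions via render_tree()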
@@ -129,22 +192,22 @@ class ExpressionProxy:
    def __init__(self, field: ModelField):
        self.field = field

    def __eq__(self, other: Any) -> Expression:
    def __eq__(self, other: Any) -> Expression:  # type: ignore[override]
        return Expression(left=self.field, op=Operators.EQ, right=other)

    def __ne__(self, other: Any) -> Expression:
    def __ne__(self, other: Any) -> Expression:  # type: ignore[override]
        return Expression(left=self.field, op=Operators.NE, right=other)

    def __lt__(self, other: Any) -> Expression:
    def __lt__(self, other: Any) -> Expression:  # type: ignore[override]
        return Expression(left=self.field, op=Operators.LT, right=other)

    def __le__(self, other: Any) -> Expression:
    def __le__(self, other: Any) -> Expression:  # type: ignore[override]
        return Expression(left=self.field, op=Operators.LE, right=other)

    def __gt__(self, other: Any) -> Expression:
    def __gt__(self, other: Any) -> Expression:  # type: ignore[override]
        return Expression(left=self.field, op=Operators.GT, right=other)

    def __ge__(self, other: Any) -> Expression:
    def __ge__(self, other: Any) -> Expression:  # type: ignore[override]
        return Expression(left=self.field, op=Operators.GE, right=other)
@@ -184,9 +247,9 @@ class FindQuery:
            self.sort_fields = []

        self._expression = None
        self._query = None
        self._pagination = []
        self._model_cache = []
        self._query: Optional[str] = None
        self._pagination: list[str] = []
        self._model_cache: list[RedisModel] = []

    @property
    def pagination(self):
@@ -236,24 +299,24 @@ class FindQuery:
        else:
            # TAG fields are the default field type.
            # TODO: A ListField or ArrayField that supports multiple values
            #  and contains logic.
            #  and contains logic should allow IN and NOT_IN queries.
            return RediSearchFieldTypes.TAG

    @staticmethod
    def expand_tag_value(value):
        err = RedisModelError(f"Using the IN operator requires passing a sequence of "
                              "possible values. You passed: {value}")
        if isinstance(str, value):
            raise err
            return value
        try:
            expanded_value = "|".join([escaper.escape(v) for v in value])
        except TypeError:
            raise err
            raise QuerySyntaxError("Values passed to an IN query must be iterables,"
                                   "like a list of strings. For more information, see:"
                                   "TODO: doc.")
        return expanded_value

    @classmethod
    def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
                      op: Operators, value: Any) -> str:
                      field_info: PydanticFieldInfo, op: Operators, value: Any) -> str:
        result = ""
        if field_type is RediSearchFieldTypes.TEXT:
            result = f"@{field_name}:"
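
# Illustrative sketch (not part of this commit) of what expand_tag_value()
# produces for an IN-style lookup after this change: each candidate value is
# escaped, then the results are joined with "|", which RediSearch treats as OR
# inside a TAG clause. The field name "skill" and the values are made up; the
# regex is the DEFAULT_ESCAPED_CHARS pattern shown above.
import re

escaped_chars = re.compile(r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]")
escape = lambda s: escaped_chars.sub(lambda m: "\\" + m.group(0), s)

values = ["python", "data science"]
expanded = "|".join(escape(v) for v in values)
print(f"(@skill:{{{expanded}}})")  # (@skill:{python|data\ science})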
@@ -282,17 +345,41 @@ class FindQuery:
                result += f"@{field_name}:[{value} +inf]"
            elif op is Operators.LE:
                result += f"@{field_name}:[-inf {value}]"
        # TODO: How will we know the difference between a multi-value use of a TAG
        #  field and our hidden use of TAG for exact-match queries?
        elif field_type is RediSearchFieldTypes.TAG:
            if op is Operators.EQ:
                value = escaper.escape(value)
                result += f"@{field_name}:{{{value}}}"
                separator_char = getattr(field_info, 'separator',
                                         SINGLE_VALUE_TAG_FIELD_SEPARATOR)
                if value == separator_char:
                    # The value is ONLY the TAG field separator character --
                    # this is not going to work.
                    log.warning("Your query against the field %s is for a single character, %s, "
                                "that is used internally by redis-developer-python. We must ignore "
                                "this portion of the query. Please review your query to find "
                                "an alternative query that uses a string containing more than "
                                "just the character %s.", field_name, separator_char, separator_char)
                    return ""
                if separator_char in value:
                    # The value contains the TAG field separator. We can work
                    # around this by breaking apart the values and unioning them
                    # with multiple field:{} queries.
                    values = filter(None, value.split(separator_char))
                    for value in values:
                        value = escaper.escape(value)
                        result += f"@{field_name}:{{{value}}}"
                else:
                    value = escaper.escape(value)
                    result += f"@{field_name}:{{{value}}}"
            elif op is Operators.NE:
                value = escaper.escape(value)
                result += f"-(@{field_name}:{{{value}}})"
            elif op is Operators.IN:
                # TODO: Implement IN, test this...
                expanded_value = cls.expand_tag_value(value)
                result += f"(@{field_name}:{{{expanded_value}}})"
            elif op is Operators.NOT_IN:
                # TODO: Implement NOT_IN, test this...
                expanded_value = cls.expand_tag_value(value)
                result += f"-(@{field_name}:{{{expanded_value}}})"
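
# Illustrative sketch (not part of this commit): the workaround above splits a
# value that contains the single-value TAG separator "|" into several TAG
# clauses instead of sending the separator to RediSearch. The field name "name"
# and the value are made-up examples; the real code also escapes each piece
# with TokenEscaper before formatting it.
value = "python|rules"
query = "".join(f"@name:{{{part}}}"
                for part in filter(None, value.split("|")))
print(query)  # @name:{python}@name:{rules}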
@@ -314,10 +401,11 @@ class FindQuery:
            return ["SORTBY", *fields]

    @classmethod
    def resolve_redisearch_query(cls, expression: ExpressionOrNegated):
    def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
        """Resolve an expression to a string RediSearch query."""
        field_type = None
        field_name = None
        field_info = None
        encompassing_expression_is_negated = False
        result = ""
@@ -328,7 +416,7 @@ class FindQuery:
        if expression.op is Operators.ALL:
            if encompassing_expression_is_negated:
                # TODO: Is there a use case for this, perhaps for dynamic
                # scoring purposes?
                #  scoring purposes?
                raise QueryNotSupportedError("You cannot negate a query for all results.")
            return "*"
@@ -338,6 +426,7 @@ class FindQuery:
        elif isinstance(expression.left, ModelField):
            field_type = cls.resolve_field_type(expression.left)
            field_name = expression.left.name
            field_info = expression.left.field_info
        else:
            raise QueryNotSupportedError(f"A query expression should start with either a field "
                                         f"or an expression enclosed in parenthesis. See docs: "
@@ -365,8 +454,7 @@ class FindQuery:
            if isinstance(right, ModelField):
                raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
            else:
                # TODO: Optionals causing IDE errors here
                result += cls.resolve_value(field_name, field_type, expression.op, right)
                result += cls.resolve_value(field_name, field_type, field_info, expression.op, right)

        if encompassing_expression_is_negated:
            result = f"-({result})"
@@ -416,7 +504,10 @@ class FindQuery:
    def first(self):
        query = FindQuery(expressions=self.expressions, model=self.model,
                          offset=0, limit=1, sort_fields=self.sort_fields)
        return query.execute()[0]
        results = query.execute()
        if not results:
            raise NotFoundError()
        return results[0]

    def all(self, batch_size=10):
        if batch_size != self.page_size:
@@ -494,9 +585,13 @@ class PrimaryKeyCreator(Protocol):
        """Create a new primary key"""


class Uuid4PrimaryKey:
    def create_pk(self, *args, **kwargs) -> str:
        return str(uuid.uuid4())
class UlidPrimaryKey:
    """A client-side generated primary key that follows the ULID spec.
        https://github.com/ulid/javascript#specification
    """
    @staticmethod
    def create_pk(*args, **kwargs) -> str:
        return str(ULID())


def __dataclass_transform__(
@@ -601,8 +696,24 @@ class PrimaryKey:
    field: ModelField


class MetaProtocol(Protocol):
    global_key_prefix: str
    model_key_prefix: str
    primary_key_pattern: str
    database: redis.Redis
    primary_key: PrimaryKey
    primary_key_creator_cls: Type[PrimaryKeyCreator]
    index_name: str
    abstract: bool


@dataclasses.dataclass
class DefaultMeta:
    # TODO: Should this really be optional here?
    """A default placeholder Meta object.

    TODO: Revisit whether this is really necessary, and whether making
     these all optional here is the right choice.
    """
    global_key_prefix: Optional[str] = None
    model_key_prefix: Optional[str] = None
    primary_key_pattern: Optional[str] = None
@@ -614,6 +725,8 @@ class DefaultMeta:


class ModelMeta(ModelMetaclass):
    _meta: MetaProtocol

    def __new__(cls, name, bases, attrs, **kwargs):  # noqa C901
        meta = attrs.pop('Meta', None)
        new_class = super().__new__(cls, name, bases, attrs, **kwargs)
@@ -656,11 +769,16 @@ class ModelMeta(ModelMetaclass):
                                               redis.Redis(decode_responses=True))
        if not getattr(new_class._meta, 'primary_key_creator_cls', None):
            new_class._meta.primary_key_creator_cls = getattr(base_meta, "primary_key_creator_cls",
                                                              Uuid4PrimaryKey)
                                                              UlidPrimaryKey)
        if not getattr(new_class._meta, 'index_name', None):
            new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \
                                         f"{new_class._meta.model_key_prefix}:index"

        # Not an abstract model class
        if abc.ABC not in bases:
            key = f"{new_class.__module__}.{new_class.__qualname__}"
            model_registry[key] = new_class

        return new_class
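
# Illustrative sketch (not part of this commit): the default primary key
# creator is now UlidPrimaryKey, so new models get ULIDs instead of UUID4s.
# ULIDs are 26 characters and sort lexicographically by creation time.
from ulid import ULID

print(str(ULID()))  # e.g. "01FH3ZX6D3N0S2BQ4S8CV3E9JQ" (your value will differ)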
@@ -680,12 +798,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
        __pydantic_self__.validate_primary_key()

    def __lt__(self, other):
        """Default sort: compare all shared model fields."""
        my_keys = set(self.__fields__.keys())
        other_keys = set(other.__fields__.keys())
        shared_keys = list(my_keys & other_keys)
        lt = [getattr(self, k) < getattr(other, k) for k in shared_keys]
        return len(lt) > len(shared_keys) / 2
        """Default sort: compare primary key of models."""
        return self.pk < other.pk

    @validator("pk", always=True)
    def validate_pk(cls, v):
@@ -726,7 +840,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
        return cls._meta.database

    @classmethod
    def find(cls, *expressions: Union[Any, Expression]):  # TODO: How to type annotate this?
    def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:  # TODO: How to type annotate this?
        return FindQuery(expressions=expressions, model=cls)

    @classmethod
@@ -760,7 +874,17 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
            except KeyError:
                pass

            doc = cls(**fields)
            try:
                fields['json'] = fields['$']
                del fields['$']
            except KeyError:
                pass

            if 'json' in fields:
                json_fields = json.loads(fields['json'])
                doc = cls(**json_fields)
            else:
                doc = cls(**fields)
            docs.append(doc)
        return docs
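
# Illustrative sketch (not part of this commit) of the new branch above: when a
# search result carries the raw JSON document under the "$" key, it is renamed
# to "json", parsed, and used to build the model. The document shown here is a
# made-up example.
import json

fields = {"$": '{"pk": "01FH3ZX6D3N0S2BQ4S8CV3E9JQ", "name": "Andrew"}'}
fields["json"] = fields.pop("$")
print(json.loads(fields["json"]))
# {'pk': '01FH3ZX6D3N0S2BQ4S8CV3E9JQ', 'name': 'Andrew'}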
@@ -847,7 +971,7 @@ class HashModel(RedisModel, abc.ABC):
            _type = field.outer_type_
            if getattr(field.field_info, 'primary_key', None):
                if issubclass(_type, str):
                    redisearch_field = f"{name} TAG"
                    redisearch_field = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
                else:
                    redisearch_field = cls.schema_for_type(name, _type, field.field_info)
                schema_parts.append(redisearch_field)
@@ -872,7 +996,7 @@ class HashModel(RedisModel, abc.ABC):
        return schema_parts

    @classmethod
    def schema_for_type(cls, name, typ: Type, field_info: FieldInfo):
    def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
        if get_origin(typ) == list:
            embedded_cls = get_args(typ)
            if not embedded_cls:
@@ -885,9 +1009,10 @@ class HashModel(RedisModel, abc.ABC):
            return f"{name} NUMERIC"
        elif issubclass(typ, str):
            if getattr(field_info, 'full_text_search', False) is True:
                return f"{name} TAG {name}_fts TEXT"
                return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " \
                       f"{name}_fts TEXT"
            else:
                return f"{name} TAG"
                return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
        elif issubclass(typ, RedisModel):
            sub_fields = []
            for embedded_name, field in typ.__fields__.items():
@@ -895,7 +1020,7 @@ class HashModel(RedisModel, abc.ABC):
                                                      field.field_info))
            return " ".join(sub_fields)
        else:
            return f"{name} TAG"
            return f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"


class JsonModel(RedisModel, abc.ABC):
@@ -927,7 +1052,7 @@ class JsonModel(RedisModel, abc.ABC):
            _type = field.outer_type_
            if getattr(field.field_info, 'primary_key', None):
                if issubclass(_type, str):
                    redisearch_field = f"{json_path}.{name} AS {name} TAG"
                    redisearch_field = f"{json_path}.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
                else:
                    redisearch_field = cls.schema_for_type(f"{json_path}.{name}", name, "", _type, field.field_info)
                schema_parts.append(redisearch_field)
@@ -957,8 +1082,8 @@ class JsonModel(RedisModel, abc.ABC):
    #  find it in the JSON document, AND the name of the field as it should
    #  be in the redisearch schema (address_address_line_1). Maybe both "name"
    #  and "name_prefix"?
    def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Type,
                        field_info: FieldInfo) -> str:
    def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Any,
                        field_info: PydanticFieldInfo) -> str:
        index_field_name = f"{name_prefix}{name}"
        should_index = getattr(field_info, 'index', False)
@@ -986,10 +1111,10 @@ class JsonModel(RedisModel, abc.ABC):
                return f"{json_path} AS {index_field_name} NUMERIC"
            elif issubclass(typ, str):
                if getattr(field_info, 'full_text_search', False) is True:
                    return f"{json_path} AS {index_field_name} TAG " \
                    return f"{json_path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " \
                           f"{json_path} AS {index_field_name}_fts TEXT"
                else:
                    return f"{json_path} AS {index_field_name} TAG"
                    return f"{json_path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
            else:
                return f"{json_path} AS {index_field_name} TAG"
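
# Illustrative sketch (not part of this commit): the TAG entries that
# schema_for_type() now emits declare the single-value separator explicitly,
# for both hash fields and JSON paths. "email" and the "$.email" path are
# made-up examples following the f-strings above.
SINGLE_VALUE_TAG_FIELD_SEPARATOR = "|"

print(f"email TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}")
# email TAG SEPARATOR |
print(f"$.email AS email TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}")
# $.email AS email TAG SEPARATOR |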
redis_developer/orm/render_tree.py (new file, 59 lines)
@@ -0,0 +1,59 @@
"""
This code adapted from the library "pptree," Copyright (c) 2017 Clément Michard
and released under the MIT license: https://github.com/clemtoy/pptree
"""
import io


def render_tree(current_node, nameattr='name', left_child='left',
                           right_child='right', indent='', last='updown',
                           buffer=None):
    """Print a tree-like structure, `current_node`.

    This is a mostly-direct-copy of the print_tree() function from the ppbtree
    module of the pptree library, but instead of printing to standard out, we
    write to a StringIO buffer, then use that buffer to accumulate written lines
    during recursive calls to render_tree().
    """
    if buffer is None:
        buffer = io.StringIO()
    if hasattr(current_node, nameattr):
        name = lambda node: getattr(node, nameattr)
    else:
        name = lambda node: str(node)

    up = getattr(current_node, left_child, None)
    down = getattr(current_node, right_child, None)

    if up is not None:
        next_last = 'up'
        next_indent = '{0}{1}{2}'.format(indent, ' ' if 'up' in last else '|', ' ' * len(str(name(current_node))))
        render_tree(up, nameattr, left_child, right_child, next_indent, next_last, buffer)

    if last == 'up':
        start_shape = '┌'
    elif last == 'down':
        start_shape = '└'
    elif last == 'updown':
        start_shape = ' '
    else:
        start_shape = '├'

    if up is not None and down is not None:
        end_shape = '┤'
    elif up:
        end_shape = '┘'
    elif down:
        end_shape = '┐'
    else:
        end_shape = ''

    print('{0}{1}{2}{3}'.format(indent, start_shape, name(current_node), end_shape),
          file=buffer)

    if down is not None:
        next_last = 'down'
        next_indent = '{0}{1}{2}'.format(indent, ' ' if 'down' in last else '|', ' ' * len(str(name(current_node))))
        render_tree(down, nameattr, left_child, right_child, next_indent, next_last, buffer)

    return f"\n{buffer.getvalue()}"
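
# Illustrative usage sketch (not part of this commit): any object graph with
# "name", "left", and "right" attributes can be rendered, so SimpleNamespace
# stands in for Expression nodes here. The node labels are made up.
from types import SimpleNamespace as Node

from redis_developer.orm.render_tree import render_tree

tree = Node(name="AND",
            left=Node(name="age > 30", left=None, right=None),
            right=Node(name="OR",
                       left=Node(name="city == Portland", left=None, right=None),
                       right=Node(name="city == Boston", left=None, right=None)))
print(render_tree(tree))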
redis_developer/orm/token_escaper.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import re
from typing import Optional, Pattern


class TokenEscaper:
    """
    Escape punctuation within an input string.
    """
    # Characters that RediSearch requires us to escape during queries.
    # Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"

    def __init__(self, escape_chars_re: Optional[Pattern] = None):
        if escape_chars_re:
            self.escaped_chars_re = escape_chars_re
        else:
            self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)

    def escape(self, string: str) -> str:
        def escape_symbol(match):
            value = match.group(0)
            return f"\\{value}"

        return self.escaped_chars_re.sub(escape_symbol, string)
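
# Illustrative usage sketch (not part of this commit): punctuation that
# RediSearch would treat as a token separator is backslash-escaped, so the
# value survives a TAG query intact. The email address is a made-up example.
from redis_developer.orm.token_escaper import TokenEscaper

print(TokenEscaper().escape("andrew@example.com"))  # andrew\@example\.com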