import abc
import dataclasses
import decimal
import operator
import re
import uuid
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import (
    AbstractSet,
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Pattern,
    Protocol,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    no_type_check,
)

import redis
from pydantic import BaseModel, validator
from pydantic.fields import Field, PrivateAttr
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass
from pydantic.typing import NoArgAnyCallable
from pydantic.utils import Representation

from .encoders import jsonable_encoder

# Concrete (non-abstract) model classes, keyed by "<module>.<ClassName>".
model_registry = {}

_T = TypeVar("_T")


class TokenEscaper:
    """
    Escape punctuation within an input string.
    """

    # Characters that RediSearch requires us to escape during queries.
    # Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"

    def __init__(self, escape_chars_re: Optional[Pattern] = None):
        # Use the caller's pattern when one is supplied, else the default character class.
        self.escaped_chars_re = escape_chars_re or re.compile(self.DEFAULT_ESCAPED_CHARS)

    def escape(self, string):
        """Return ``string`` with every character matched by the pattern backslash-prefixed."""
        return self.escaped_chars_re.sub(lambda match: "\\" + match.group(0), string)


# Shared default escaper used when building query strings.
escaper = TokenEscaper()


class RedisModelError(Exception):
    pass


class NotFoundError(Exception):
    """
    A query found no results.

    TODO: embed in Model class?
    """


class Operators(Enum):
    """Operators that can appear in a query Expression."""
    EQ = 1
    NE = 2
    LT = 3
    LE = 4
    GT = 5
    GE = 6
    OR = 7
    AND = 8
    NOT = 9
    IN = 10
    NOT_IN = 11
    LIKE = 12
    ALL = 13


@dataclasses.dataclass
class NegatedExpression:
    """An Expression wrapped in logical NOT; inverting it again unwraps it."""
    expression: 'Expression'

    def __invert__(self):
        return self.expression

    def __and__(self, other):
        return Expression(op=Operators.AND, left=self, right=other)

    def __or__(self, other):
        return Expression(op=Operators.OR, left=self, right=other)


@dataclasses.dataclass
class Expression:
    """A binary query node: ``left <op> right``; ``&``, ``|`` and ``~`` build larger trees."""
    op: Operators
    left: Any
    right: Any

    def __invert__(self):
        return NegatedExpression(self)

    def __and__(self, other):
        return Expression(op=Operators.AND, left=self, right=other)

    def __or__(self, other):
        return Expression(op=Operators.OR, left=self, right=other)


ExpressionOrNegated = Union[Expression, NegatedExpression]


class ExpressionProxy:
    """Stand-in for a model field that turns comparisons into Expression objects.

    Comparing the proxy (e.g. ``proxy == 1``) returns an Expression whose
    ``left`` is the wrapped ModelField and whose ``right`` is the other operand.
    """

    def __init__(self, field: ModelField):
        self.field = field

    def __eq__(self, other: Any) -> Expression:
        return Expression(op=Operators.EQ, left=self.field, right=other)

    def __ne__(self, other: Any) -> Expression:
        return Expression(op=Operators.NE, left=self.field, right=other)

    def __lt__(self, other: Any) -> Expression:
        return Expression(op=Operators.LT, left=self.field, right=other)

    def __le__(self, other: Any) -> Expression:
        return Expression(op=Operators.LE, left=self.field, right=other)

    def __gt__(self, other: Any) -> Expression:
        return Expression(op=Operators.GT, left=self.field, right=other)

    def __ge__(self, other: Any) -> Expression:
        return Expression(op=Operators.GE, left=self.field, right=other)


class QueryNotSupportedError(Exception):
    """The attempted query is not supported."""


class RediSearchFieldTypes(Enum):
    """Field types used in RediSearch queries and schemas."""
    TEXT = 'TEXT'
    TAG = 'TAG'
    NUMERIC = 'NUMERIC'
    GEO = 'GEO'


# TODO: How to handle Geo fields?
NUMERIC_TYPES = (float, int, decimal.Decimal)
DEFAULT_PAGE_SIZE = 10


def _safe_issubclass(typ: Any, classinfo: Any) -> bool:
    """Return True if ``typ`` is a subclass of ``classinfo``, without raising.

    Field annotations such as ``List[int]`` or ``Any`` are not classes, so a
    bare ``issubclass()`` raises ``TypeError`` on them. A parameterized
    generic is first resolved to its runtime origin (``List[int]`` -> ``list``);
    anything that is still not a class is reported as "not a subclass".
    """
    origin = getattr(typ, '__origin__', None)
    if origin is not None:
        typ = origin
    if not isinstance(typ, type):
        return False
    try:
        return issubclass(typ, classinfo)
    except TypeError:
        return False


class FindQuery(BaseModel):
    """A RediSearch query over one model class.

    Holds the filter expressions plus pagination/sorting options. The
    RediSearch query string and FT.SEARCH arguments are built on demand, and
    results fetched by ``execute()`` are kept in a per-query cache.
    """
    expressions: Sequence[ExpressionOrNegated]
    model: Type['RedisModel']
    offset: int = 0
    limit: int = DEFAULT_PAGE_SIZE
    page_size: int = DEFAULT_PAGE_SIZE
    sort_fields: Optional[List[str]] = Field(default_factory=list)

    _expression: Expression = PrivateAttr(default=None)
    _query: str = PrivateAttr(default=None)
    _pagination: List[str] = PrivateAttr(default_factory=list)
    _model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list)

    class Config:
        arbitrary_types_allowed = True

    @property
    def pagination(self):
        """The FT.SEARCH LIMIT arguments, computed once and then reused."""
        if self._pagination:
            return self._pagination
        self._pagination = self.resolve_redisearch_pagination()
        return self._pagination

    @property
    def expression(self):
        """All ``expressions`` AND-ed together, or an ALL expression when there are none."""
        if self._expression:
            return self._expression
        if self.expressions:
            self._expression = reduce(operator.and_, self.expressions)
        else:
            self._expression = Expression(left=None, right=None, op=Operators.ALL)
        return self._expression

    @property
    def query(self):
        """The RediSearch query string for ``expression``."""
        return self.resolve_redisearch_query(self.expression)

    @validator("sort_fields")
    def validate_sort_fields(cls, v, values):
        """Check that every sort field exists on the model and is sortable.

        A leading "-" (descending order) is stripped before the lookup.

        Raises:
            QueryNotSupportedError: if a field is missing from the model or
                was not declared with ``sortable=True``.
        """
        model = values.get('model')
        if model is None:
            # `model` failed its own validation; let pydantic report that error
            # instead of crashing here with a KeyError.
            return v
        for sort_field in v:
            field_name = sort_field.lstrip("-")
            if field_name not in model.__fields__:
                raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
                                             f"does not exist on the model {model}")
            field_proxy = getattr(model, field_name)
            if not getattr(field_proxy.field.field_info, 'sortable', False):
                raise QueryNotSupportedError(f"You tried sort by {field_name}, but {model} does "
                                             "not define that field as sortable. See docs: XXX")
        return v

    @staticmethod
    def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
        """Return the RediSearch field type used when querying ``field``.

        Primary keys are TAG and ``full_text_search=True`` fields are TEXT.
        Otherwise numeric Python types map to NUMERIC and everything else
        (including non-class annotations such as ``List[int]``) maps to TAG.
        """
        if getattr(field.field_info, 'primary_key', None) is True:
            return RediSearchFieldTypes.TAG
        elif getattr(field.field_info, 'full_text_search', None) is True:
            return RediSearchFieldTypes.TEXT

        field_type = field.outer_type_

        # TODO: GEO
        if any(_safe_issubclass(field_type, t) for t in NUMERIC_TYPES):
            return RediSearchFieldTypes.NUMERIC
        # TAG fields are the default field type.
        # TODO: A ListField or ArrayField that supports multiple values
        #  and contains logic.
        return RediSearchFieldTypes.TAG

    @staticmethod
    def expand_tag_value(value):
        """Join a sequence of values into an escaped TAG alternation, e.g. "a|b|c".

        Raises:
            RedisModelError: if ``value`` is a string, is not iterable, or
                contains items the escaper cannot process.
        """
        err = RedisModelError(f"Using the IN operator requires passing a sequence of "
                              f"possible values. You passed: {value}")
        # A str is iterable, but expanding it would OR together its characters.
        if isinstance(value, str):
            raise err
        try:
            expanded_value = "|".join([escaper.escape(v) for v in value])
        except TypeError:
            raise err
        return expanded_value

    @classmethod
    def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
                      op: Operators, value: Any) -> str:
        """Render one ``field <op> value`` comparison in RediSearch syntax.

        Field-type/operator combinations without a branch below produce "".

        Raises:
            QueryNotSupportedError: for a TEXT field compared with anything
                other than EQ, NE, or LIKE.
        """
        result = ""
        if field_type is RediSearchFieldTypes.TEXT:
            result = f"@{field_name}:"
            if op is Operators.EQ:
                result += f'"{value}"'
            elif op is Operators.NE:
                result = f'-({result}"{value}")'
            elif op is Operators.LIKE:
                result += value
            else:
                raise QueryNotSupportedError("Only equals (=), not-equals (!=), and like() "
                                             "comparisons are supported for TEXT fields. See "
                                             "docs: TODO.")
        elif field_type is RediSearchFieldTypes.NUMERIC:
            if op is Operators.EQ:
                result += f"@{field_name}:[{value} {value}]"
            elif op is Operators.NE:
                # TODO: Is this enough or do we also need a clause for all values
                #  ([-inf +inf]) from which we then subtract the undesirable value?
                result += f"-(@{field_name}:[{value} {value}])"
            elif op is Operators.GT:
                result += f"@{field_name}:[({value} +inf]"
            elif op is Operators.LT:
                result += f"@{field_name}:[-inf ({value}]"
            elif op is Operators.GE:
                result += f"@{field_name}:[{value} +inf]"
            elif op is Operators.LE:
                result += f"@{field_name}:[-inf {value}]"
        elif field_type is RediSearchFieldTypes.TAG:
            if op is Operators.EQ:
                value = escaper.escape(value)
                result += f"@{field_name}:{{{value}}}"
            elif op is Operators.NE:
                value = escaper.escape(value)
                result += f"-(@{field_name}:{{{value}}})"
            elif op is Operators.IN:
                expanded_value = cls.expand_tag_value(value)
                result += f"(@{field_name}:{{{expanded_value}}})"
            elif op is Operators.NOT_IN:
                expanded_value = cls.expand_tag_value(value)
                result += f"-(@{field_name}:{{{expanded_value}}})"

        return result

    def resolve_redisearch_pagination(self):
        """Resolve pagination options for a query."""
        return ["LIMIT", self.offset, self.limit]

    def resolve_redisearch_sort_fields(self):
        """Resolve sort options for a query.

        Returns ``["SORTBY", name, "asc"|"desc", ...]``, or None when no sort
        fields are set. NOTE(review): FT.SEARCH may accept only a single
        SORTBY field; confirm before relying on multiple sort fields.
        """
        if not self.sort_fields:
            return
        fields = []
        for f in self.sort_fields:
            direction = "desc" if f.startswith('-') else 'asc'
            fields.extend([f.lstrip('-'), direction])
        return ["SORTBY", *fields]

    @classmethod
    def resolve_redisearch_query(cls, expression: ExpressionOrNegated):
        """Resolve an expression to a string RediSearch query.

        Raises:
            QueryNotSupportedError: for a negated match-all query, a left
                operand that is neither a field nor an expression, a
                combination of sub-expressions other than AND/OR, or a
                field-to-field comparison.
        """
        field_type = None
        field_name = None
        encompassing_expression_is_negated = False
        result = ""

        if isinstance(expression, NegatedExpression):
            encompassing_expression_is_negated = True
            expression = expression.expression

        if expression.op is Operators.ALL:
            if encompassing_expression_is_negated:
                # TODO: Is there a use case for this, perhaps for dynamic
                #  scoring purposes?
                raise QueryNotSupportedError("You cannot negate a query for all results.")
            return "*"

        if isinstance(expression.left, (Expression, NegatedExpression)):
            result += f"({cls.resolve_redisearch_query(expression.left)})"
        elif isinstance(expression.left, ModelField):
            field_type = cls.resolve_field_type(expression.left)
            field_name = expression.left.name
        else:
            raise QueryNotSupportedError(f"A query expression should start with either a field "
                                         f"or an expression enclosed in parenthesis. See docs: "
                                         f"TODO")

        right = expression.right

        if isinstance(right, (Expression, NegatedExpression)):
            if expression.op == Operators.AND:
                result += " "
            elif expression.op == Operators.OR:
                result += "| "
            else:
                raise QueryNotSupportedError("You can only combine two query expressions with "
                                             "AND (&) or OR (|). See docs: TODO")

            if isinstance(right, NegatedExpression):
                result += "-"
                # We're handling the RediSearch operator in this call ("-"), so resolve the
                # inner expression instead of the NegatedExpression.
                right = right.expression

            result += f"({cls.resolve_redisearch_query(right)})"
        else:
            if isinstance(right, ModelField):
                raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
            else:
                # TODO: Optionals causing IDE errors here
                result += cls.resolve_value(field_name, field_type, expression.op, right)

        if encompassing_expression_is_negated:
            result = f"-({result})"

        return result

    def execute(self, exhaust_results=True):
        """Run FT.SEARCH and return the cached model instances.

        Args:
            exhaust_results: when True and the reply reports more matches than
                were returned, keep issuing follow-up queries (offset advanced by
                ``page_size``) until one comes back empty, and return everything
                gathered. When False, return after the first request.
        """
        args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
        if self.sort_fields:
            args += self.resolve_redisearch_sort_fields()

        # Reset the cache if we're executing from offset 0.
        if self.offset == 0:
            self._model_cache.clear()

        # If the offset is greater than 0, we're paginating through a result set,
        # so append the new results to results already in the cache.
        raw_result = self.model.db().execute_command(*args)
        count = raw_result[0]
        results = self.model.from_redis(raw_result)
        self._model_cache += results

        if not exhaust_results:
            return self._model_cache

        # The query returned all results, so we have no more work to do.
        if count <= len(results):
            return self._model_cache

        # Transparently (to the user) make subsequent requests to paginate
        # through the results and finally return them all.
        query = self
        while True:
            # Make a query for each pass of the loop, with a new offset equal to the
            # current offset plus `page_size`, until we stop getting results back.
            # sort_fields is carried over so later pages keep the first page's order.
            query = FindQuery(expressions=query.expressions,
                              model=query.model,
                              offset=query.offset + query.page_size,
                              page_size=query.page_size,
                              limit=query.limit,
                              sort_fields=query.sort_fields)
            _results = query.execute(exhaust_results=False)
            if not _results:
                break
            self._model_cache += _results
        return self._model_cache

    def first(self):
        """Return the first matching result, honouring ``sort_fields``.

        Raises:
            IndexError: if the query matches nothing.
        """
        query = FindQuery(expressions=self.expressions, model=self.model,
                          offset=0, limit=1, sort_fields=self.sort_fields)
        # Only one row is needed, so don't page through the rest of the results.
        return query.execute(exhaust_results=False)[0]

    def all(self, batch_size=10):
        """Return every matching result, fetching ``batch_size`` rows per request."""
        if batch_size != self.page_size:
            # TODO: There's probably a copy-with-change mechanism in Pydantic,
            #  or can we use one from dataclasses?
            query = FindQuery(expressions=self.expressions,
                              model=self.model,
                              offset=self.offset,
                              page_size=batch_size,
                              limit=batch_size,
                              sort_fields=self.sort_fields)
            return query.execute()
        return self.execute()

    def sort_by(self, *fields):
        """Return a copy of this query sorted by ``fields`` ("-name" sorts descending)."""
        if not fields:
            return self
        return FindQuery(expressions=self.expressions,
                         model=self.model,
                         offset=self.offset,
                         page_size=self.page_size,
                         limit=self.limit,
                         sort_fields=list(fields))

    def update(self, **kwargs):
        """Update all matching records in this query."""
        # TODO

    def delete(self, **field_values):
        """Delete all matching records in this query.

        NOTE(review): not implemented yet. It currently only checks that each
        keyword names an attribute of the model, then returns this query.

        Raises:
            RedisModelError: if a keyword is not an attribute of the model.
        """
        # Iterate the keyword names; the values are not used yet.
        for field_name in field_values:
            valid_attr = hasattr(self.model, field_name)
            if not valid_attr:
                raise RedisModelError(f"Can't update field {field_name} because "
                                      f"the field does not exist on the model {self.model}")

        return self

    def __iter__(self):
        """Yield cached results if any, otherwise execute the query and yield its results."""
        if self._model_cache:
            for m in self._model_cache:
                yield m
        else:
            for m in self.execute():
                yield m

    def __getitem__(self, item: int):
        """
        Given this code:
            Model.find()[1000]

        We should return only the 1000th result.

            1. If the result is loaded in the query cache for this query,
               we can return it directly from the cache.

            2. If the query cache does not have enough elements to return
               that result, then we should clone the current query and
               give it a new offset and limit: offset=n, limit=1.
        """
        # Index `item` is only present when the cache holds more than `item` entries.
        if self._model_cache and len(self._model_cache) > item:
            return self._model_cache[item]

        query = FindQuery(expressions=self.expressions, model=self.model,
                          offset=item, sort_fields=self.sort_fields, limit=1)

        # A single row is wanted; don't page through the rest of the result set.
        return query.execute(exhaust_results=False)[0]


class PrimaryKeyCreator(Protocol):
    def create_pk(self, *args, **kwargs) -> str:
        """Create a new primary key"""


class Uuid4PrimaryKey:
    """Primary key creator that returns a random UUID4 as a string."""

    def create_pk(self, *args, **kwargs) -> str:
        return str(uuid.uuid4())


def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]:
    # Returns an identity decorator; the parameters are accepted and ignored.
    return lambda a: a


class FieldInfo(PydanticFieldInfo):
    """Pydantic FieldInfo extended with the ORM's indexing options."""

    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
        # Pop the ORM-specific options before handing the rest to pydantic.
        primary_key = kwargs.pop("primary_key", False)
        sortable = kwargs.pop("sortable", Undefined)
        index = kwargs.pop("index", Undefined)
        full_text_search = kwargs.pop("full_text_search", Undefined)
        super().__init__(default=default, **kwargs)
        self.primary_key = primary_key
        self.sortable = sortable
        self.index = index
        self.full_text_search = full_text_search


class RelationshipInfo(Representation):
    def __init__(
        self,
        *,
        back_populates: Optional[str] = None,
        link_model: Optional[Any] = None,
    ) -> None:
        self.back_populates = back_populates
        self.link_model = link_model


def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: str = None,
    title: str = None,
    description: str = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: bool = None,
    gt: float = None,
    ge: float = None,
    lt: float = None,
    le: float = None,
    multiple_of: float = None,
    min_items: int = None,
    max_items: int = None,
    min_length: int = None,
    max_length: int = None,
    allow_mutation: bool = True,
    regex: str = None,
    primary_key: bool = False,
    sortable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    full_text_search: Union[bool, UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
    """Declare a model field; like pydantic's Field plus the ORM options.

    ``primary_key``, ``sortable``, ``index`` and ``full_text_search`` are
    stored on the returned FieldInfo; ``schema_extra`` entries are passed
    through as extra keyword arguments.
    """
    current_schema_extra = schema_extra or {}
    field_info = FieldInfo(
        default,
        default_factory=default_factory,
        alias=alias,
        title=title,
        description=description,
        exclude=exclude,
        include=include,
        const=const,
        gt=gt,
        ge=ge,
        lt=lt,
        le=le,
        multiple_of=multiple_of,
        min_items=min_items,
        max_items=max_items,
        min_length=min_length,
        max_length=max_length,
        allow_mutation=allow_mutation,
        regex=regex,
        primary_key=primary_key,
        sortable=sortable,
        index=index,
        full_text_search=full_text_search,
        **current_schema_extra,
    )
    field_info._validate()
    return field_info


@dataclasses.dataclass
class PrimaryKey:
    name: str
    field: ModelField


class DefaultMeta:
    """Defaults for a model's ``Meta``; ModelMeta fills in the unset values."""
    # TODO: Should this really be optional here?
    global_key_prefix: Optional[str] = None
    model_key_prefix: Optional[str] = None
    primary_key_pattern: Optional[str] = None
    database: Optional[redis.Redis] = None
    primary_key: Optional[PrimaryKey] = None
    primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
    index_name: Optional[str] = None
    abstract: Optional[bool] = False


class ModelMeta(ModelMetaclass):
    """Metaclass that gives each model its own ``Meta``/``_meta`` and field query proxies."""

    def __new__(cls, name, bases, attrs, **kwargs):  # noqa C901
        meta = attrs.pop('Meta', None)
        new_class = super().__new__(cls, name, bases, attrs, **kwargs)

        meta = meta or getattr(new_class, 'Meta', None)
        base_meta = getattr(new_class, '_meta', None)

        if meta and meta != DefaultMeta and meta != base_meta:
            new_class.Meta = meta
            new_class._meta = meta
        elif base_meta:
            # deepcopy() returns class objects unchanged, so "copying" base_meta that
            # way made every model share (and overwrite) a single Meta class. A fresh
            # subclass inherits the base's settings while keeping this model's own.
            new_class._meta = type(f"{name}Meta", (base_meta,), {})
            new_class.Meta = new_class._meta
            # Unset inherited values we don't want to reuse (typically based on the model name).
            new_class._meta.abstract = False
            new_class._meta.model_key_prefix = None
            new_class._meta.index_name = None
        else:
            # Subclass rather than mutate the shared DefaultMeta.
            new_class._meta = type(f"{name}Meta", (DefaultMeta,), {})
            new_class.Meta = new_class._meta

        # Not an abstract model class
        if abc.ABC not in bases:
            key = f"{new_class.__module__}.{new_class.__name__}"
            model_registry[key] = new_class

        # Create proxies for each model field so that we can use the field
        # in queries, like Model.get(Model.field_name == 1)
        for field_name, field in new_class.__fields__.items():
            setattr(new_class, field_name, ExpressionProxy(field))
            # Check if this is our FieldInfo version with extended ORM metadata.
            if isinstance(field.field_info, FieldInfo):
                if field.field_info.primary_key:
                    new_class._meta.primary_key = PrimaryKey(name=field_name, field=field)

        if not getattr(new_class._meta, 'global_key_prefix', None):
            new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "")
        if not getattr(new_class._meta, 'model_key_prefix', None):
            # Don't look at the base class for this.
            new_class._meta.model_key_prefix = f"{new_class.__module__}.{new_class.__name__}"
        if not getattr(new_class._meta, 'primary_key_pattern', None):
            new_class._meta.primary_key_pattern = getattr(base_meta, "primary_key_pattern",
                                                          "{pk}")
        if not getattr(new_class._meta, 'database', None):
            # Only build a new client when no base class supplies a usable one.
            base_database = getattr(base_meta, "database", None)
            new_class._meta.database = base_database or redis.Redis(decode_responses=True)
        if not getattr(new_class._meta, 'primary_key_creator_cls', None):
            new_class._meta.primary_key_creator_cls = getattr(base_meta, "primary_key_creator_cls",
                                                              Uuid4PrimaryKey)
        if not getattr(new_class._meta, 'index_name', None):
            new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \
                                         f"{new_class._meta.model_key_prefix}:index"

        return new_class


class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
    pk: Optional[str] = Field(default=None, primary_key=True)

    Meta = DefaultMeta

    # TODO: Missing _meta here is causing IDE warnings.

    class Config:
        orm_mode = True
        arbitrary_types_allowed = True
        extra = 'allow'

    def __init__(__pydantic_self__, **data: Any) -> None:
        super().__init__(**data)
        __pydantic_self__.validate_primary_key()

    def __lt__(self, other):
        """Default sort: compare all shared model fields.

        ``self`` sorts first when it is less than ``other`` on a strict
        majority of the fields both models define.
        """
        my_keys = set(self.__fields__.keys())
        other_keys = set(other.__fields__.keys())
        shared_keys = list(my_keys & other_keys)
        lt = [getattr(self, k) < getattr(other, k) for k in shared_keys]
        # Count the True comparisons; len(lt) always equals len(shared_keys),
        # which made every comparison True.
        return sum(lt) > len(shared_keys) / 2

    @validator("pk", always=True)
    def validate_pk(cls, v):
        """Generate a primary key with the configured creator when none was given."""
        if not v:
            v = cls._meta.primary_key_creator_cls().create_pk()
        return v

    @classmethod
    def validate_primary_key(cls):
        """Check for a primary key. We need one (and only one)."""
        primary_keys = 0
        for name, field in cls.__fields__.items():
            if getattr(field.field_info, 'primary_key', None):
                primary_keys += 1
        if primary_keys == 0:
            raise RedisModelError("You must define a primary key for the model")
        elif primary_keys > 1:
            raise RedisModelError("You must define only one primary key for a model")

    @classmethod
    def make_key(cls, part: str):
        """Return "<global prefix>:<model prefix>:<part>", trimming colons from the prefixes."""
        # `or ''` guards against prefixes explicitly set to None.
        global_prefix = (getattr(cls._meta, 'global_key_prefix', '') or '').strip(":")
        model_prefix = (getattr(cls._meta, 'model_key_prefix', '') or '').strip(":")
        return f"{global_prefix}:{model_prefix}:{part}"

    @classmethod
    def make_primary_key(cls, pk: Any):
        """Return the Redis key for this model."""
        return cls.make_key(cls._meta.primary_key_pattern.format(pk=pk))

    def key(self):
        """Return the Redis key for this model."""
        pk = getattr(self, self._meta.primary_key.field.name)
        return self.make_primary_key(pk)

    @classmethod
    def db(cls):
        """Return the Redis client configured for this model."""
        return cls._meta.database

    @classmethod
    def find(cls, *expressions: Expression):
        """Return a FindQuery for this model filtered by ``expressions`` (AND-ed)."""
        return FindQuery(expressions=expressions, model=cls)

    @classmethod
    def from_redis(cls, res: Any):
        """Build model instances from a raw FT.SEARCH reply.

        The reply is parsed as ``[count, id, [name, value, ...], id, ...]``:
        item 0 is skipped, and each document's flat name/value list follows
        its id. Bytes are decoded as UTF-8 (invalid bytes ignored), and any
        "id" entry is dropped before the model is constructed.
        """
        # TODO: Parsing logic borrowed from redisearch-py. Evaluate.

        def to_string(s):
            if isinstance(s, str):
                return s
            elif isinstance(s, bytes):
                return s.decode('utf-8', 'ignore')
            else:
                return s  # Not a string we care about

        docs = []
        step = 2  # Because the result has content
        offset = 1  # The first item is the count of total matches.

        for i in range(1, len(res), step):
            raw_fields = res[i + offset]
            fields = dict(zip(map(to_string, raw_fields[::2]),
                              map(to_string, raw_fields[1::2])))
            fields.pop('id', None)
            docs.append(cls(**fields))
        return docs

    @classmethod
    def add(cls, models: Sequence['RedisModel']) -> Sequence['RedisModel']:
        """Save each model and return the list of ``save()`` results."""
        return [model.save() for model in models]

    @classmethod
    def update(cls, **field_values):
        """Update this model instance."""
        return cls

    @classmethod
    def values(cls):
        """Return raw values from Redis instead of model instances."""
        return cls

    def delete(self):
        """Delete this model's key from Redis and return the client's reply."""
        return self.db().delete(self.key())

    def save(self, *args, **kwargs) -> 'RedisModel':
        raise NotImplementedError

    @classmethod
    def schema(cls):
        raise NotImplementedError


class HashModel(RedisModel, abc.ABC):
    """A model stored as a Redis hash; fields must be flat (no models, sets, lists, mappings)."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        for name, field in cls.__fields__.items():
            # _safe_issubclass also resolves generics such as List[int], which a
            # bare issubclass() rejects with TypeError before this check can fire.
            if _safe_issubclass(field.outer_type_, RedisModel):
                raise RedisModelError(f"HashModels cannot have embedded model "
                                      f"fields. Field: {name}")

            for typ in (Set, Mapping, List):
                if _safe_issubclass(field.outer_type_, typ):
                    raise RedisModelError(f"HashModels cannot have set, list,"
                                          f" or mapping fields. Field: {name}")

    def save(self, *args, **kwargs) -> 'HashModel':
        """Write this model's fields to its hash with HSET.

        NOTE(review): returns the raw hset() reply, not the model the
        annotation suggests.
        """
        document = jsonable_encoder(self.dict())
        success = self.db().hset(self.key(), mapping=document)
        return success

    @classmethod
    def get(cls, pk: Any) -> 'HashModel':
        """Load the model with primary key ``pk``.

        Raises:
            NotFoundError: if the hash is missing or empty.
        """
        document = cls.db().hgetall(cls.make_primary_key(pk))
        if not document:
            raise NotFoundError
        return cls.parse_obj(document)

    @classmethod
    @no_type_check
    def _get_value(cls, *args, **kwargs) -> Any:
        """
        Always send None as an empty string.

        TODO: We do this because redis-py's hset() method requires non-null
        values. Is there a better way?
        """
        val = super()._get_value(*args, **kwargs)
        if val is None:
            return ""
        return val

    @classmethod
    def schema_for_type(cls, name, typ: Type, field_info: FieldInfo):
        """Return the RediSearch schema fragment for field ``name`` of type ``typ``."""
        if any(_safe_issubclass(typ, t) for t in NUMERIC_TYPES):
            return f"{name} NUMERIC"
        elif _safe_issubclass(typ, str):
            if getattr(field_info, 'full_text_search', False) is True:
                return f"{name} TAG {name}_fts TEXT"
            else:
                return f"{name} TAG"
        else:
            return f"{name} TAG"

    @classmethod
    def schema(cls):
        """Return the FT.CREATE "ON HASH PREFIX 1 <prefix> SCHEMA ..." definition.

        Only the primary key and fields declared ``index=True`` are included;
        "SORTABLE" follows each included field declared ``sortable=True``.
        """
        hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
        schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
        schema_parts = [schema_prefix]
        for name, field in cls.__fields__.items():
            _type = field.outer_type_
            if getattr(field.field_info, 'primary_key', None):
                if _safe_issubclass(_type, str):
                    redisearch_field = f"{name} TAG"
                else:
                    redisearch_field = cls.schema_for_type(name, _type, field.field_info)
            elif getattr(field.field_info, 'index', None) is True:
                redisearch_field = cls.schema_for_type(name, _type, field.field_info)
            else:
                # Not indexed. Emitting SORTABLE here would attach it to
                # whichever field was added before this one.
                continue
            schema_parts.append(redisearch_field)
            if getattr(field.field_info, 'sortable', False) is True:
                schema_parts.append("SORTABLE")
        return " ".join(schema_parts)


class JsonModel(RedisModel, abc.ABC):
    """A model stored as a RedisJSON document."""

    def save(self, *args, **kwargs) -> 'JsonModel':
        """Write this model as JSON at the root path of its key (JSON.SET).

        NOTE(review): returns the raw command reply, not the model.
        """
        success = self.db().execute_command('JSON.SET', self.key(), ".", self.json())
        return success

    @classmethod
    def get(cls, pk: Any) -> 'JsonModel':
        """Load the model with primary key ``pk``.

        Raises:
            NotFoundError: if JSON.GET returns nothing.
        """
        document = cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
        if not document:
            raise NotFoundError
        return cls.parse_raw(document)