Avoid using Pydantic for FindQuery
This commit is contained in:
		
							parent
							
								
									01cab5352b
								
							
						
					
					
						commit
						ef58e854c1
					
				
					 3 changed files with 102 additions and 38 deletions
				
			
		| 
						 | 
				
			
			@ -73,7 +73,7 @@ class Migrator:
 | 
			
		|||
        for name, cls in model_registry.items():
 | 
			
		||||
            hash_key = schema_hash_key(cls.Meta.index_name)
 | 
			
		||||
            try:
 | 
			
		||||
                schema = cls.schema()
 | 
			
		||||
                schema = cls.redisearch_schema()
 | 
			
		||||
            except NotImplementedError:
 | 
			
		||||
                log.info("Skipping migrations for %s", name)
 | 
			
		||||
                continue
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,10 @@
 | 
			
		|||
import abc
 | 
			
		||||
import dataclasses
 | 
			
		||||
import decimal
 | 
			
		||||
import inspect
 | 
			
		||||
import operator
 | 
			
		||||
import re
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from copy import deepcopy, copy
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from functools import reduce
 | 
			
		||||
from typing import (
 | 
			
		||||
| 
						 | 
				
			
			@ -41,6 +42,25 @@ model_registry = {}
 | 
			
		|||
_T = TypeVar("_T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def subclass_exception(name, bases, module, attached_to):
 | 
			
		||||
    """
 | 
			
		||||
    Create exception subclass. Used by RedisModel below.
 | 
			
		||||
 | 
			
		||||
    The exception is created in a way that allows it to be pickled, assuming
 | 
			
		||||
    that the returned exception class will be added as an attribute to the
 | 
			
		||||
    'attached_to' class.
 | 
			
		||||
    """
 | 
			
		||||
    return type(name, bases, {
 | 
			
		||||
        '__module__': module,
 | 
			
		||||
        '__qualname__': '%s.%s' % (attached_to.__qualname__, name),
 | 
			
		||||
    })
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _has_contribute_to_class(value):
 | 
			
		||||
    # Only call contribute_to_class() if it's bound.
 | 
			
		||||
    return not inspect.isclass(value) and hasattr(value, 'contribute_to_class')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TokenEscaper:
 | 
			
		||||
    """
 | 
			
		||||
    Escape punctuation within an input string.
 | 
			
		||||
| 
						 | 
				
			
			@ -167,21 +187,29 @@ NUMERIC_TYPES = (float, int, decimal.Decimal)
 | 
			
		|||
DEFAULT_PAGE_SIZE = 10
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FindQuery(BaseModel):
 | 
			
		||||
    expressions: Sequence[ExpressionOrNegated]
 | 
			
		||||
    model: Type['RedisModel']
 | 
			
		||||
    offset: int = 0
 | 
			
		||||
    limit: int = DEFAULT_PAGE_SIZE
 | 
			
		||||
    page_size: int = DEFAULT_PAGE_SIZE
 | 
			
		||||
    sort_fields: Optional[List[str]] = Field(default_factory=list)
 | 
			
		||||
class FindQuery:
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 expressions: Sequence[ExpressionOrNegated],
 | 
			
		||||
                 model: Type['RedisModel'],
 | 
			
		||||
                 offset: int = 0,
 | 
			
		||||
                 limit: int = DEFAULT_PAGE_SIZE,
 | 
			
		||||
                 page_size: int = DEFAULT_PAGE_SIZE,
 | 
			
		||||
                 sort_fields: Optional[List[str]] = None):
 | 
			
		||||
        self.expressions = expressions
 | 
			
		||||
        self.model = model
 | 
			
		||||
        self.offset = offset
 | 
			
		||||
        self.limit = limit
 | 
			
		||||
        self.page_size = page_size
 | 
			
		||||
 | 
			
		||||
    _expression: Expression = PrivateAttr(default=None)
 | 
			
		||||
    _query: str = PrivateAttr(default=None)
 | 
			
		||||
    _pagination: List[str] = PrivateAttr(default_factory=list)
 | 
			
		||||
    _model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list)
 | 
			
		||||
        if sort_fields:
 | 
			
		||||
            self.sort_fields = self.validate_sort_fields(sort_fields)
 | 
			
		||||
        else:
 | 
			
		||||
            self.sort_fields = []
 | 
			
		||||
 | 
			
		||||
    class Config:
 | 
			
		||||
        arbitrary_types_allowed = True
 | 
			
		||||
        self._expression = None
 | 
			
		||||
        self._query = None
 | 
			
		||||
        self._pagination = []
 | 
			
		||||
        self._model_cache = []
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def pagination(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -204,19 +232,17 @@ class FindQuery(BaseModel):
 | 
			
		|||
    def query(self):
 | 
			
		||||
        return self.resolve_redisearch_query(self.expression)
 | 
			
		||||
 | 
			
		||||
    @validator("sort_fields")
 | 
			
		||||
    def validate_sort_fields(cls, v, values):
 | 
			
		||||
        model = values['model']
 | 
			
		||||
        for sort_field in v:
 | 
			
		||||
    def validate_sort_fields(self, sort_fields):
 | 
			
		||||
        for sort_field in sort_fields:
 | 
			
		||||
            field_name = sort_field.lstrip("-")
 | 
			
		||||
            if field_name not in model.__fields__:
 | 
			
		||||
            if field_name not in self.model.__fields__:
 | 
			
		||||
                raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
 | 
			
		||||
                                             f"does not exist on the model {model}")
 | 
			
		||||
            field_proxy = getattr(model, field_name)
 | 
			
		||||
                                             f"does not exist on the model {self.model}")
 | 
			
		||||
            field_proxy = getattr(self.model, field_name)
 | 
			
		||||
            if not getattr(field_proxy.field.field_info, 'sortable', False):
 | 
			
		||||
                raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does "
 | 
			
		||||
                raise QueryNotSupportedError(f"You tried sort by {field_name}, but {self.model} does "
 | 
			
		||||
                                             "not define that field as sortable. See docs: XXX")
 | 
			
		||||
        return v
 | 
			
		||||
        return sort_fields
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
 | 
			
		||||
| 
						 | 
				
			
			@ -429,7 +455,7 @@ class FindQuery(BaseModel):
 | 
			
		|||
            return query.execute()
 | 
			
		||||
        return self.execute()
 | 
			
		||||
 | 
			
		||||
    def sort_by(self, *fields):
 | 
			
		||||
    def sort_by(self, *fields: str):
 | 
			
		||||
        if not fields:
 | 
			
		||||
            return self
 | 
			
		||||
        return FindQuery(expressions=self.expressions,
 | 
			
		||||
| 
						 | 
				
			
			@ -612,12 +638,19 @@ class DefaultMeta:
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ModelMeta(ModelMetaclass):
 | 
			
		||||
    def add_to_class(cls, name, value):
 | 
			
		||||
        if _has_contribute_to_class(value):
 | 
			
		||||
            value.contribute_to_class(cls, name)
 | 
			
		||||
        else:
 | 
			
		||||
            setattr(cls, name, value)
 | 
			
		||||
 | 
			
		||||
    def __new__(cls, name, bases, attrs, **kwargs):  # noqa C901
 | 
			
		||||
        meta = attrs.pop('Meta', None)
 | 
			
		||||
        new_class = super().__new__(cls, name, bases, attrs, **kwargs)
 | 
			
		||||
 | 
			
		||||
        meta = meta or getattr(new_class, 'Meta', None)
 | 
			
		||||
        base_meta = getattr(new_class, '_meta', None)
 | 
			
		||||
        parents = [b for b in bases if isinstance(b, ModelMeta)]
 | 
			
		||||
 | 
			
		||||
        if meta and meta != DefaultMeta and meta != base_meta:
 | 
			
		||||
            new_class.Meta = meta
 | 
			
		||||
| 
						 | 
				
			
			@ -638,6 +671,17 @@ class ModelMeta(ModelMetaclass):
 | 
			
		|||
            key = f"{new_class.__module__}.{new_class.__name__}"
 | 
			
		||||
            model_registry[key] = new_class
 | 
			
		||||
 | 
			
		||||
            if not hasattr(new_class, 'NotFoundError'):
 | 
			
		||||
                new_class.add_to_class(
 | 
			
		||||
                    'NotFoundError',
 | 
			
		||||
                    subclass_exception(
 | 
			
		||||
                        'NotFoundError',
 | 
			
		||||
                        tuple(
 | 
			
		||||
                            x.NotFoundError for x in bases if hasattr(x, '_meta') and not issubclass(x, abc.ABC)
 | 
			
		||||
                        ),
 | 
			
		||||
                        attrs['__module__'],
 | 
			
		||||
                        attached_to=new_class))
 | 
			
		||||
 | 
			
		||||
        # Create proxies for each model field so that we can use the field
 | 
			
		||||
        # in queries, like Model.get(Model.field_name == 1)
 | 
			
		||||
        for field_name, field in new_class.__fields__.items():
 | 
			
		||||
| 
						 | 
				
			
			@ -790,7 +834,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def schema(cls):
 | 
			
		||||
    def redisearch_schema(cls):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -818,7 +862,7 @@ class HashModel(RedisModel, abc.ABC):
 | 
			
		|||
    def get(cls, pk: Any) -> 'HashModel':
 | 
			
		||||
        document = cls.db().hgetall(cls.make_primary_key(pk))
 | 
			
		||||
        if not document:
 | 
			
		||||
            raise NotFoundError
 | 
			
		||||
            raise cls.NotFoundError
 | 
			
		||||
        return cls.parse_obj(document)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
| 
						 | 
				
			
			@ -848,7 +892,7 @@ class HashModel(RedisModel, abc.ABC):
 | 
			
		|||
            return f"{name} TAG"
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def schema(cls):
 | 
			
		||||
    def redisearch_schema(cls):
 | 
			
		||||
        hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
 | 
			
		||||
        schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
 | 
			
		||||
        schema_parts = [schema_prefix]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue