Avoid using Pydantic for FindQuery
This commit is contained in:
		
							parent
							
								
									01cab5352b
								
							
						
					
					
						commit
						ef58e854c1
					
				
					 3 changed files with 102 additions and 38 deletions
				
			
		| 
						 | 
					@ -73,7 +73,7 @@ class Migrator:
 | 
				
			||||||
        for name, cls in model_registry.items():
 | 
					        for name, cls in model_registry.items():
 | 
				
			||||||
            hash_key = schema_hash_key(cls.Meta.index_name)
 | 
					            hash_key = schema_hash_key(cls.Meta.index_name)
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                schema = cls.schema()
 | 
					                schema = cls.redisearch_schema()
 | 
				
			||||||
            except NotImplementedError:
 | 
					            except NotImplementedError:
 | 
				
			||||||
                log.info("Skipping migrations for %s", name)
 | 
					                log.info("Skipping migrations for %s", name)
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,9 +1,10 @@
 | 
				
			||||||
import abc
 | 
					import abc
 | 
				
			||||||
import dataclasses
 | 
					import dataclasses
 | 
				
			||||||
import decimal
 | 
					import decimal
 | 
				
			||||||
 | 
					import inspect
 | 
				
			||||||
import operator
 | 
					import operator
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
from copy import deepcopy
 | 
					from copy import deepcopy, copy
 | 
				
			||||||
from enum import Enum
 | 
					from enum import Enum
 | 
				
			||||||
from functools import reduce
 | 
					from functools import reduce
 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
| 
						 | 
					@ -41,6 +42,25 @@ model_registry = {}
 | 
				
			||||||
_T = TypeVar("_T")
 | 
					_T = TypeVar("_T")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def subclass_exception(name, bases, module, attached_to):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Create exception subclass. Used by RedisModel below.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The exception is created in a way that allows it to be pickled, assuming
 | 
				
			||||||
 | 
					    that the returned exception class will be added as an attribute to the
 | 
				
			||||||
 | 
					    'attached_to' class.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return type(name, bases, {
 | 
				
			||||||
 | 
					        '__module__': module,
 | 
				
			||||||
 | 
					        '__qualname__': '%s.%s' % (attached_to.__qualname__, name),
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _has_contribute_to_class(value):
 | 
				
			||||||
 | 
					    # Only call contribute_to_class() if it's bound.
 | 
				
			||||||
 | 
					    return not inspect.isclass(value) and hasattr(value, 'contribute_to_class')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TokenEscaper:
 | 
					class TokenEscaper:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Escape punctuation within an input string.
 | 
					    Escape punctuation within an input string.
 | 
				
			||||||
| 
						 | 
					@ -167,21 +187,29 @@ NUMERIC_TYPES = (float, int, decimal.Decimal)
 | 
				
			||||||
DEFAULT_PAGE_SIZE = 10
 | 
					DEFAULT_PAGE_SIZE = 10
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FindQuery(BaseModel):
 | 
					class FindQuery:
 | 
				
			||||||
    expressions: Sequence[ExpressionOrNegated]
 | 
					    def __init__(self,
 | 
				
			||||||
    model: Type['RedisModel']
 | 
					                 expressions: Sequence[ExpressionOrNegated],
 | 
				
			||||||
    offset: int = 0
 | 
					                 model: Type['RedisModel'],
 | 
				
			||||||
    limit: int = DEFAULT_PAGE_SIZE
 | 
					                 offset: int = 0,
 | 
				
			||||||
    page_size: int = DEFAULT_PAGE_SIZE
 | 
					                 limit: int = DEFAULT_PAGE_SIZE,
 | 
				
			||||||
    sort_fields: Optional[List[str]] = Field(default_factory=list)
 | 
					                 page_size: int = DEFAULT_PAGE_SIZE,
 | 
				
			||||||
 | 
					                 sort_fields: Optional[List[str]] = None):
 | 
				
			||||||
 | 
					        self.expressions = expressions
 | 
				
			||||||
 | 
					        self.model = model
 | 
				
			||||||
 | 
					        self.offset = offset
 | 
				
			||||||
 | 
					        self.limit = limit
 | 
				
			||||||
 | 
					        self.page_size = page_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    _expression: Expression = PrivateAttr(default=None)
 | 
					        if sort_fields:
 | 
				
			||||||
    _query: str = PrivateAttr(default=None)
 | 
					            self.sort_fields = self.validate_sort_fields(sort_fields)
 | 
				
			||||||
    _pagination: List[str] = PrivateAttr(default_factory=list)
 | 
					        else:
 | 
				
			||||||
    _model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list)
 | 
					            self.sort_fields = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Config:
 | 
					        self._expression = None
 | 
				
			||||||
        arbitrary_types_allowed = True
 | 
					        self._query = None
 | 
				
			||||||
 | 
					        self._pagination = []
 | 
				
			||||||
 | 
					        self._model_cache = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def pagination(self):
 | 
					    def pagination(self):
 | 
				
			||||||
| 
						 | 
					@ -204,19 +232,17 @@ class FindQuery(BaseModel):
 | 
				
			||||||
    def query(self):
 | 
					    def query(self):
 | 
				
			||||||
        return self.resolve_redisearch_query(self.expression)
 | 
					        return self.resolve_redisearch_query(self.expression)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @validator("sort_fields")
 | 
					    def validate_sort_fields(self, sort_fields):
 | 
				
			||||||
    def validate_sort_fields(cls, v, values):
 | 
					        for sort_field in sort_fields:
 | 
				
			||||||
        model = values['model']
 | 
					 | 
				
			||||||
        for sort_field in v:
 | 
					 | 
				
			||||||
            field_name = sort_field.lstrip("-")
 | 
					            field_name = sort_field.lstrip("-")
 | 
				
			||||||
            if field_name not in model.__fields__:
 | 
					            if field_name not in self.model.__fields__:
 | 
				
			||||||
                raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
 | 
					                raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
 | 
				
			||||||
                                             f"does not exist on the model {model}")
 | 
					                                             f"does not exist on the model {self.model}")
 | 
				
			||||||
            field_proxy = getattr(model, field_name)
 | 
					            field_proxy = getattr(self.model, field_name)
 | 
				
			||||||
            if not getattr(field_proxy.field.field_info, 'sortable', False):
 | 
					            if not getattr(field_proxy.field.field_info, 'sortable', False):
 | 
				
			||||||
                raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does "
 | 
					                raise QueryNotSupportedError(f"You tried sort by {field_name}, but {self.model} does "
 | 
				
			||||||
                                             "not define that field as sortable. See docs: XXX")
 | 
					                                             "not define that field as sortable. See docs: XXX")
 | 
				
			||||||
        return v
 | 
					        return sort_fields
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
 | 
					    def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
 | 
				
			||||||
| 
						 | 
					@ -429,7 +455,7 @@ class FindQuery(BaseModel):
 | 
				
			||||||
            return query.execute()
 | 
					            return query.execute()
 | 
				
			||||||
        return self.execute()
 | 
					        return self.execute()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def sort_by(self, *fields):
 | 
					    def sort_by(self, *fields: str):
 | 
				
			||||||
        if not fields:
 | 
					        if not fields:
 | 
				
			||||||
            return self
 | 
					            return self
 | 
				
			||||||
        return FindQuery(expressions=self.expressions,
 | 
					        return FindQuery(expressions=self.expressions,
 | 
				
			||||||
| 
						 | 
					@ -612,12 +638,19 @@ class DefaultMeta:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ModelMeta(ModelMetaclass):
 | 
					class ModelMeta(ModelMetaclass):
 | 
				
			||||||
 | 
					    def add_to_class(cls, name, value):
 | 
				
			||||||
 | 
					        if _has_contribute_to_class(value):
 | 
				
			||||||
 | 
					            value.contribute_to_class(cls, name)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            setattr(cls, name, value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __new__(cls, name, bases, attrs, **kwargs):  # noqa C901
 | 
					    def __new__(cls, name, bases, attrs, **kwargs):  # noqa C901
 | 
				
			||||||
        meta = attrs.pop('Meta', None)
 | 
					        meta = attrs.pop('Meta', None)
 | 
				
			||||||
        new_class = super().__new__(cls, name, bases, attrs, **kwargs)
 | 
					        new_class = super().__new__(cls, name, bases, attrs, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        meta = meta or getattr(new_class, 'Meta', None)
 | 
					        meta = meta or getattr(new_class, 'Meta', None)
 | 
				
			||||||
        base_meta = getattr(new_class, '_meta', None)
 | 
					        base_meta = getattr(new_class, '_meta', None)
 | 
				
			||||||
 | 
					        parents = [b for b in bases if isinstance(b, ModelMeta)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if meta and meta != DefaultMeta and meta != base_meta:
 | 
					        if meta and meta != DefaultMeta and meta != base_meta:
 | 
				
			||||||
            new_class.Meta = meta
 | 
					            new_class.Meta = meta
 | 
				
			||||||
| 
						 | 
					@ -638,6 +671,17 @@ class ModelMeta(ModelMetaclass):
 | 
				
			||||||
            key = f"{new_class.__module__}.{new_class.__name__}"
 | 
					            key = f"{new_class.__module__}.{new_class.__name__}"
 | 
				
			||||||
            model_registry[key] = new_class
 | 
					            model_registry[key] = new_class
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not hasattr(new_class, 'NotFoundError'):
 | 
				
			||||||
 | 
					                new_class.add_to_class(
 | 
				
			||||||
 | 
					                    'NotFoundError',
 | 
				
			||||||
 | 
					                    subclass_exception(
 | 
				
			||||||
 | 
					                        'NotFoundError',
 | 
				
			||||||
 | 
					                        tuple(
 | 
				
			||||||
 | 
					                            x.NotFoundError for x in bases if hasattr(x, '_meta') and not issubclass(x, abc.ABC)
 | 
				
			||||||
 | 
					                        ),
 | 
				
			||||||
 | 
					                        attrs['__module__'],
 | 
				
			||||||
 | 
					                        attached_to=new_class))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Create proxies for each model field so that we can use the field
 | 
					        # Create proxies for each model field so that we can use the field
 | 
				
			||||||
        # in queries, like Model.get(Model.field_name == 1)
 | 
					        # in queries, like Model.get(Model.field_name == 1)
 | 
				
			||||||
        for field_name, field in new_class.__fields__.items():
 | 
					        for field_name, field in new_class.__fields__.items():
 | 
				
			||||||
| 
						 | 
					@ -790,7 +834,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def schema(cls):
 | 
					    def redisearch_schema(cls):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -818,7 +862,7 @@ class HashModel(RedisModel, abc.ABC):
 | 
				
			||||||
    def get(cls, pk: Any) -> 'HashModel':
 | 
					    def get(cls, pk: Any) -> 'HashModel':
 | 
				
			||||||
        document = cls.db().hgetall(cls.make_primary_key(pk))
 | 
					        document = cls.db().hgetall(cls.make_primary_key(pk))
 | 
				
			||||||
        if not document:
 | 
					        if not document:
 | 
				
			||||||
            raise NotFoundError
 | 
					            raise cls.NotFoundError
 | 
				
			||||||
        return cls.parse_obj(document)
 | 
					        return cls.parse_obj(document)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
| 
						 | 
					@ -848,7 +892,7 @@ class HashModel(RedisModel, abc.ABC):
 | 
				
			||||||
            return f"{name} TAG"
 | 
					            return f"{name} TAG"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def schema(cls):
 | 
					    def redisearch_schema(cls):
 | 
				
			||||||
        hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
 | 
					        hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
 | 
				
			||||||
        schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
 | 
					        schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
 | 
				
			||||||
        schema_parts = [schema_prefix]
 | 
					        schema_parts = [schema_prefix]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,6 +2,7 @@ import abc
 | 
				
			||||||
import decimal
 | 
					import decimal
 | 
				
			||||||
import datetime
 | 
					import datetime
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					from unittest import mock
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
import redis
 | 
					import redis
 | 
				
			||||||
| 
						 | 
					@ -11,7 +12,7 @@ from redis_developer.orm import (
 | 
				
			||||||
    HashModel,
 | 
					    HashModel,
 | 
				
			||||||
    Field,
 | 
					    Field,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from redis_developer.orm.model import RedisModelError
 | 
					from redis_developer.orm.model import RedisModelError, QueryNotSupportedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
r = redis.Redis()
 | 
					r = redis.Redis()
 | 
				
			||||||
today = datetime.date.today()
 | 
					today = datetime.date.today()
 | 
				
			||||||
| 
						 | 
					@ -172,14 +173,17 @@ def test_paginate_query(members):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_access_result_by_index_cached(members):
 | 
					def test_access_result_by_index_cached(members):
 | 
				
			||||||
    _, member2, _ = members
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
    query = Member.find().sort_by('age')
 | 
					    query = Member.find().sort_by('age')
 | 
				
			||||||
    # Load the cache, throw away the result.
 | 
					    # Load the cache, throw away the result.
 | 
				
			||||||
 | 
					    assert query._model_cache == []
 | 
				
			||||||
    query.execute()
 | 
					    query.execute()
 | 
				
			||||||
 | 
					    assert query._model_cache == [member2, member1, member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Access an item that should be in the cache.
 | 
					    # Access an item that should be in the cache.
 | 
				
			||||||
    # TODO: Assert that we didn't make a Redis request.
 | 
					    with mock.patch.object(query.model, 'db') as mock_db:
 | 
				
			||||||
        assert query[0] == member2
 | 
					        assert query[0] == member2
 | 
				
			||||||
 | 
					        assert not mock_db.called
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_access_result_by_index_not_cached(members):
 | 
					def test_access_result_by_index_not_cached(members):
 | 
				
			||||||
| 
						 | 
					@ -284,11 +288,28 @@ def test_numeric_queries(members):
 | 
				
			||||||
    actual = Member.find(Member.age >= 100).all()
 | 
					    actual = Member.find(Member.age >= 100).all()
 | 
				
			||||||
    assert actual == [member3]
 | 
					    assert actual == [member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    import ipdb; ipdb.set_trace()
 | 
					 | 
				
			||||||
    actual = Member.find(~(Member.age == 100)).all()
 | 
					    actual = Member.find(~(Member.age == 100)).all()
 | 
				
			||||||
    assert sorted(actual) == [member1, member2]
 | 
					    assert sorted(actual) == [member1, member2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_sorting(members):
 | 
				
			||||||
 | 
					    member1, member2, member3 = members
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find(Member.age > 34).sort_by('age').all()
 | 
				
			||||||
 | 
					    assert sorted(actual) == [member3, member1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    actual = Member.find(Member.age > 34).sort_by('-age').all()
 | 
				
			||||||
 | 
					    assert sorted(actual) == [member1, member3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(QueryNotSupportedError):
 | 
				
			||||||
 | 
					        # This field does not exist.
 | 
				
			||||||
 | 
					        Member.find().sort_by('not-a-real-field').all()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(QueryNotSupportedError):
 | 
				
			||||||
 | 
					        # This field is not sortable.
 | 
				
			||||||
 | 
					        Member.find().sort_by('join_date').all()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_schema():
 | 
					def test_schema():
 | 
				
			||||||
    class Address(BaseHashModel):
 | 
					    class Address(BaseHashModel):
 | 
				
			||||||
        a_string: str = Field(index=True)
 | 
					        a_string: str = Field(index=True)
 | 
				
			||||||
| 
						 | 
					@ -298,8 +319,7 @@ def test_schema():
 | 
				
			||||||
        another_integer: int
 | 
					        another_integer: int
 | 
				
			||||||
        another_float: float
 | 
					        another_float: float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # TODO: Fix
 | 
					    assert Address.redisearch_schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \
 | 
				
			||||||
    assert Address.schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \
 | 
					 | 
				
			||||||
                                          "SCHEMA pk TAG a_string TAG a_full_text_string TAG " \
 | 
					                                          "SCHEMA pk TAG a_string TAG a_full_text_string TAG " \
 | 
				
			||||||
                                          "a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \
 | 
					                                          "a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \
 | 
				
			||||||
                                          "a_float NUMERIC"
 | 
					                                          "a_float NUMERIC"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue