Avoid using Pydantic for FindQuery
This commit is contained in:
parent
01cab5352b
commit
ef58e854c1
3 changed files with 102 additions and 38 deletions
|
@ -73,7 +73,7 @@ class Migrator:
|
|||
for name, cls in model_registry.items():
|
||||
hash_key = schema_hash_key(cls.Meta.index_name)
|
||||
try:
|
||||
schema = cls.schema()
|
||||
schema = cls.redisearch_schema()
|
||||
except NotImplementedError:
|
||||
log.info("Skipping migrations for %s", name)
|
||||
continue
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import abc
|
||||
import dataclasses
|
||||
import decimal
|
||||
import inspect
|
||||
import operator
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from copy import deepcopy, copy
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import (
|
||||
|
@ -41,6 +42,25 @@ model_registry = {}
|
|||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def subclass_exception(name, bases, module, attached_to):
|
||||
"""
|
||||
Create exception subclass. Used by RedisModel below.
|
||||
|
||||
The exception is created in a way that allows it to be pickled, assuming
|
||||
that the returned exception class will be added as an attribute to the
|
||||
'attached_to' class.
|
||||
"""
|
||||
return type(name, bases, {
|
||||
'__module__': module,
|
||||
'__qualname__': '%s.%s' % (attached_to.__qualname__, name),
|
||||
})
|
||||
|
||||
|
||||
def _has_contribute_to_class(value):
|
||||
# Only call contribute_to_class() if it's bound.
|
||||
return not inspect.isclass(value) and hasattr(value, 'contribute_to_class')
|
||||
|
||||
|
||||
class TokenEscaper:
|
||||
"""
|
||||
Escape punctuation within an input string.
|
||||
|
@ -167,21 +187,29 @@ NUMERIC_TYPES = (float, int, decimal.Decimal)
|
|||
DEFAULT_PAGE_SIZE = 10
|
||||
|
||||
|
||||
class FindQuery(BaseModel):
|
||||
expressions: Sequence[ExpressionOrNegated]
|
||||
model: Type['RedisModel']
|
||||
offset: int = 0
|
||||
limit: int = DEFAULT_PAGE_SIZE
|
||||
page_size: int = DEFAULT_PAGE_SIZE
|
||||
sort_fields: Optional[List[str]] = Field(default_factory=list)
|
||||
class FindQuery:
|
||||
def __init__(self,
|
||||
expressions: Sequence[ExpressionOrNegated],
|
||||
model: Type['RedisModel'],
|
||||
offset: int = 0,
|
||||
limit: int = DEFAULT_PAGE_SIZE,
|
||||
page_size: int = DEFAULT_PAGE_SIZE,
|
||||
sort_fields: Optional[List[str]] = None):
|
||||
self.expressions = expressions
|
||||
self.model = model
|
||||
self.offset = offset
|
||||
self.limit = limit
|
||||
self.page_size = page_size
|
||||
|
||||
_expression: Expression = PrivateAttr(default=None)
|
||||
_query: str = PrivateAttr(default=None)
|
||||
_pagination: List[str] = PrivateAttr(default_factory=list)
|
||||
_model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list)
|
||||
if sort_fields:
|
||||
self.sort_fields = self.validate_sort_fields(sort_fields)
|
||||
else:
|
||||
self.sort_fields = []
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
self._expression = None
|
||||
self._query = None
|
||||
self._pagination = []
|
||||
self._model_cache = []
|
||||
|
||||
@property
|
||||
def pagination(self):
|
||||
|
@ -204,19 +232,17 @@ class FindQuery(BaseModel):
|
|||
def query(self):
|
||||
return self.resolve_redisearch_query(self.expression)
|
||||
|
||||
@validator("sort_fields")
|
||||
def validate_sort_fields(cls, v, values):
|
||||
model = values['model']
|
||||
for sort_field in v:
|
||||
def validate_sort_fields(self, sort_fields):
|
||||
for sort_field in sort_fields:
|
||||
field_name = sort_field.lstrip("-")
|
||||
if field_name not in model.__fields__:
|
||||
if field_name not in self.model.__fields__:
|
||||
raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
|
||||
f"does not exist on the model {model}")
|
||||
field_proxy = getattr(model, field_name)
|
||||
f"does not exist on the model {self.model}")
|
||||
field_proxy = getattr(self.model, field_name)
|
||||
if not getattr(field_proxy.field.field_info, 'sortable', False):
|
||||
raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does "
|
||||
raise QueryNotSupportedError(f"You tried sort by {field_name}, but {self.model} does "
|
||||
"not define that field as sortable. See docs: XXX")
|
||||
return v
|
||||
return sort_fields
|
||||
|
||||
@staticmethod
|
||||
def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
|
||||
|
@ -429,7 +455,7 @@ class FindQuery(BaseModel):
|
|||
return query.execute()
|
||||
return self.execute()
|
||||
|
||||
def sort_by(self, *fields):
|
||||
def sort_by(self, *fields: str):
|
||||
if not fields:
|
||||
return self
|
||||
return FindQuery(expressions=self.expressions,
|
||||
|
@ -612,12 +638,19 @@ class DefaultMeta:
|
|||
|
||||
|
||||
class ModelMeta(ModelMetaclass):
|
||||
def add_to_class(cls, name, value):
|
||||
if _has_contribute_to_class(value):
|
||||
value.contribute_to_class(cls, name)
|
||||
else:
|
||||
setattr(cls, name, value)
|
||||
|
||||
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
|
||||
meta = attrs.pop('Meta', None)
|
||||
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
|
||||
meta = meta or getattr(new_class, 'Meta', None)
|
||||
base_meta = getattr(new_class, '_meta', None)
|
||||
parents = [b for b in bases if isinstance(b, ModelMeta)]
|
||||
|
||||
if meta and meta != DefaultMeta and meta != base_meta:
|
||||
new_class.Meta = meta
|
||||
|
@ -638,6 +671,17 @@ class ModelMeta(ModelMetaclass):
|
|||
key = f"{new_class.__module__}.{new_class.__name__}"
|
||||
model_registry[key] = new_class
|
||||
|
||||
if not hasattr(new_class, 'NotFoundError'):
|
||||
new_class.add_to_class(
|
||||
'NotFoundError',
|
||||
subclass_exception(
|
||||
'NotFoundError',
|
||||
tuple(
|
||||
x.NotFoundError for x in bases if hasattr(x, '_meta') and not issubclass(x, abc.ABC)
|
||||
),
|
||||
attrs['__module__'],
|
||||
attached_to=new_class))
|
||||
|
||||
# Create proxies for each model field so that we can use the field
|
||||
# in queries, like Model.get(Model.field_name == 1)
|
||||
for field_name, field in new_class.__fields__.items():
|
||||
|
@ -790,7 +834,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def schema(cls):
|
||||
def redisearch_schema(cls):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@ -818,7 +862,7 @@ class HashModel(RedisModel, abc.ABC):
|
|||
def get(cls, pk: Any) -> 'HashModel':
|
||||
document = cls.db().hgetall(cls.make_primary_key(pk))
|
||||
if not document:
|
||||
raise NotFoundError
|
||||
raise cls.NotFoundError
|
||||
return cls.parse_obj(document)
|
||||
|
||||
@classmethod
|
||||
|
@ -848,7 +892,7 @@ class HashModel(RedisModel, abc.ABC):
|
|||
return f"{name} TAG"
|
||||
|
||||
@classmethod
|
||||
def schema(cls):
|
||||
def redisearch_schema(cls):
|
||||
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
|
||||
schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
|
||||
schema_parts = [schema_prefix]
|
||||
|
|
|
@ -2,6 +2,7 @@ import abc
|
|||
import decimal
|
||||
import datetime
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
@ -11,7 +12,7 @@ from redis_developer.orm import (
|
|||
HashModel,
|
||||
Field,
|
||||
)
|
||||
from redis_developer.orm.model import RedisModelError
|
||||
from redis_developer.orm.model import RedisModelError, QueryNotSupportedError
|
||||
|
||||
r = redis.Redis()
|
||||
today = datetime.date.today()
|
||||
|
@ -172,14 +173,17 @@ def test_paginate_query(members):
|
|||
|
||||
|
||||
def test_access_result_by_index_cached(members):
|
||||
_, member2, _ = members
|
||||
member1, member2, member3 = members
|
||||
query = Member.find().sort_by('age')
|
||||
# Load the cache, throw away the result.
|
||||
assert query._model_cache == []
|
||||
query.execute()
|
||||
assert query._model_cache == [member2, member1, member3]
|
||||
|
||||
# Access an item that should be in the cache.
|
||||
# TODO: Assert that we didn't make a Redis request.
|
||||
with mock.patch.object(query.model, 'db') as mock_db:
|
||||
assert query[0] == member2
|
||||
assert not mock_db.called
|
||||
|
||||
|
||||
def test_access_result_by_index_not_cached(members):
|
||||
|
@ -284,11 +288,28 @@ def test_numeric_queries(members):
|
|||
actual = Member.find(Member.age >= 100).all()
|
||||
assert actual == [member3]
|
||||
|
||||
import ipdb; ipdb.set_trace()
|
||||
actual = Member.find(~(Member.age == 100)).all()
|
||||
assert sorted(actual) == [member1, member2]
|
||||
|
||||
|
||||
def test_sorting(members):
|
||||
member1, member2, member3 = members
|
||||
|
||||
actual = Member.find(Member.age > 34).sort_by('age').all()
|
||||
assert sorted(actual) == [member3, member1]
|
||||
|
||||
actual = Member.find(Member.age > 34).sort_by('-age').all()
|
||||
assert sorted(actual) == [member1, member3]
|
||||
|
||||
with pytest.raises(QueryNotSupportedError):
|
||||
# This field does not exist.
|
||||
Member.find().sort_by('not-a-real-field').all()
|
||||
|
||||
with pytest.raises(QueryNotSupportedError):
|
||||
# This field is not sortable.
|
||||
Member.find().sort_by('join_date').all()
|
||||
|
||||
|
||||
def test_schema():
|
||||
class Address(BaseHashModel):
|
||||
a_string: str = Field(index=True)
|
||||
|
@ -298,8 +319,7 @@ def test_schema():
|
|||
another_integer: int
|
||||
another_float: float
|
||||
|
||||
# TODO: Fix
|
||||
assert Address.schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \
|
||||
assert Address.redisearch_schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \
|
||||
"SCHEMA pk TAG a_string TAG a_full_text_string TAG " \
|
||||
"a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \
|
||||
"a_float NUMERIC"
|
||||
|
|
Loading…
Reference in a new issue