Avoid using Pydantic for FindQuery
This commit is contained in:
parent
01cab5352b
commit
ef58e854c1
3 changed files with 102 additions and 38 deletions
|
@ -73,7 +73,7 @@ class Migrator:
|
||||||
for name, cls in model_registry.items():
|
for name, cls in model_registry.items():
|
||||||
hash_key = schema_hash_key(cls.Meta.index_name)
|
hash_key = schema_hash_key(cls.Meta.index_name)
|
||||||
try:
|
try:
|
||||||
schema = cls.schema()
|
schema = cls.redisearch_schema()
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
log.info("Skipping migrations for %s", name)
|
log.info("Skipping migrations for %s", name)
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import abc
|
import abc
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import decimal
|
import decimal
|
||||||
|
import inspect
|
||||||
import operator
|
import operator
|
||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy, copy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import (
|
from typing import (
|
||||||
|
@ -41,6 +42,25 @@ model_registry = {}
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
def subclass_exception(name, bases, module, attached_to):
|
||||||
|
"""
|
||||||
|
Create exception subclass. Used by RedisModel below.
|
||||||
|
|
||||||
|
The exception is created in a way that allows it to be pickled, assuming
|
||||||
|
that the returned exception class will be added as an attribute to the
|
||||||
|
'attached_to' class.
|
||||||
|
"""
|
||||||
|
return type(name, bases, {
|
||||||
|
'__module__': module,
|
||||||
|
'__qualname__': '%s.%s' % (attached_to.__qualname__, name),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _has_contribute_to_class(value):
|
||||||
|
# Only call contribute_to_class() if it's bound.
|
||||||
|
return not inspect.isclass(value) and hasattr(value, 'contribute_to_class')
|
||||||
|
|
||||||
|
|
||||||
class TokenEscaper:
|
class TokenEscaper:
|
||||||
"""
|
"""
|
||||||
Escape punctuation within an input string.
|
Escape punctuation within an input string.
|
||||||
|
@ -167,21 +187,29 @@ NUMERIC_TYPES = (float, int, decimal.Decimal)
|
||||||
DEFAULT_PAGE_SIZE = 10
|
DEFAULT_PAGE_SIZE = 10
|
||||||
|
|
||||||
|
|
||||||
class FindQuery(BaseModel):
|
class FindQuery:
|
||||||
expressions: Sequence[ExpressionOrNegated]
|
def __init__(self,
|
||||||
model: Type['RedisModel']
|
expressions: Sequence[ExpressionOrNegated],
|
||||||
offset: int = 0
|
model: Type['RedisModel'],
|
||||||
limit: int = DEFAULT_PAGE_SIZE
|
offset: int = 0,
|
||||||
page_size: int = DEFAULT_PAGE_SIZE
|
limit: int = DEFAULT_PAGE_SIZE,
|
||||||
sort_fields: Optional[List[str]] = Field(default_factory=list)
|
page_size: int = DEFAULT_PAGE_SIZE,
|
||||||
|
sort_fields: Optional[List[str]] = None):
|
||||||
|
self.expressions = expressions
|
||||||
|
self.model = model
|
||||||
|
self.offset = offset
|
||||||
|
self.limit = limit
|
||||||
|
self.page_size = page_size
|
||||||
|
|
||||||
_expression: Expression = PrivateAttr(default=None)
|
if sort_fields:
|
||||||
_query: str = PrivateAttr(default=None)
|
self.sort_fields = self.validate_sort_fields(sort_fields)
|
||||||
_pagination: List[str] = PrivateAttr(default_factory=list)
|
else:
|
||||||
_model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list)
|
self.sort_fields = []
|
||||||
|
|
||||||
class Config:
|
self._expression = None
|
||||||
arbitrary_types_allowed = True
|
self._query = None
|
||||||
|
self._pagination = []
|
||||||
|
self._model_cache = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pagination(self):
|
def pagination(self):
|
||||||
|
@ -204,19 +232,17 @@ class FindQuery(BaseModel):
|
||||||
def query(self):
|
def query(self):
|
||||||
return self.resolve_redisearch_query(self.expression)
|
return self.resolve_redisearch_query(self.expression)
|
||||||
|
|
||||||
@validator("sort_fields")
|
def validate_sort_fields(self, sort_fields):
|
||||||
def validate_sort_fields(cls, v, values):
|
for sort_field in sort_fields:
|
||||||
model = values['model']
|
|
||||||
for sort_field in v:
|
|
||||||
field_name = sort_field.lstrip("-")
|
field_name = sort_field.lstrip("-")
|
||||||
if field_name not in model.__fields__:
|
if field_name not in self.model.__fields__:
|
||||||
raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
|
raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
|
||||||
f"does not exist on the model {model}")
|
f"does not exist on the model {self.model}")
|
||||||
field_proxy = getattr(model, field_name)
|
field_proxy = getattr(self.model, field_name)
|
||||||
if not getattr(field_proxy.field.field_info, 'sortable', False):
|
if not getattr(field_proxy.field.field_info, 'sortable', False):
|
||||||
raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does "
|
raise QueryNotSupportedError(f"You tried sort by {field_name}, but {self.model} does "
|
||||||
"not define that field as sortable. See docs: XXX")
|
"not define that field as sortable. See docs: XXX")
|
||||||
return v
|
return sort_fields
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
|
def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
|
||||||
|
@ -429,7 +455,7 @@ class FindQuery(BaseModel):
|
||||||
return query.execute()
|
return query.execute()
|
||||||
return self.execute()
|
return self.execute()
|
||||||
|
|
||||||
def sort_by(self, *fields):
|
def sort_by(self, *fields: str):
|
||||||
if not fields:
|
if not fields:
|
||||||
return self
|
return self
|
||||||
return FindQuery(expressions=self.expressions,
|
return FindQuery(expressions=self.expressions,
|
||||||
|
@ -612,12 +638,19 @@ class DefaultMeta:
|
||||||
|
|
||||||
|
|
||||||
class ModelMeta(ModelMetaclass):
|
class ModelMeta(ModelMetaclass):
|
||||||
|
def add_to_class(cls, name, value):
|
||||||
|
if _has_contribute_to_class(value):
|
||||||
|
value.contribute_to_class(cls, name)
|
||||||
|
else:
|
||||||
|
setattr(cls, name, value)
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
|
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
|
||||||
meta = attrs.pop('Meta', None)
|
meta = attrs.pop('Meta', None)
|
||||||
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
|
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||||
|
|
||||||
meta = meta or getattr(new_class, 'Meta', None)
|
meta = meta or getattr(new_class, 'Meta', None)
|
||||||
base_meta = getattr(new_class, '_meta', None)
|
base_meta = getattr(new_class, '_meta', None)
|
||||||
|
parents = [b for b in bases if isinstance(b, ModelMeta)]
|
||||||
|
|
||||||
if meta and meta != DefaultMeta and meta != base_meta:
|
if meta and meta != DefaultMeta and meta != base_meta:
|
||||||
new_class.Meta = meta
|
new_class.Meta = meta
|
||||||
|
@ -638,6 +671,17 @@ class ModelMeta(ModelMetaclass):
|
||||||
key = f"{new_class.__module__}.{new_class.__name__}"
|
key = f"{new_class.__module__}.{new_class.__name__}"
|
||||||
model_registry[key] = new_class
|
model_registry[key] = new_class
|
||||||
|
|
||||||
|
if not hasattr(new_class, 'NotFoundError'):
|
||||||
|
new_class.add_to_class(
|
||||||
|
'NotFoundError',
|
||||||
|
subclass_exception(
|
||||||
|
'NotFoundError',
|
||||||
|
tuple(
|
||||||
|
x.NotFoundError for x in bases if hasattr(x, '_meta') and not issubclass(x, abc.ABC)
|
||||||
|
),
|
||||||
|
attrs['__module__'],
|
||||||
|
attached_to=new_class))
|
||||||
|
|
||||||
# Create proxies for each model field so that we can use the field
|
# Create proxies for each model field so that we can use the field
|
||||||
# in queries, like Model.get(Model.field_name == 1)
|
# in queries, like Model.get(Model.field_name == 1)
|
||||||
for field_name, field in new_class.__fields__.items():
|
for field_name, field in new_class.__fields__.items():
|
||||||
|
@ -790,7 +834,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def schema(cls):
|
def redisearch_schema(cls):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@ -818,7 +862,7 @@ class HashModel(RedisModel, abc.ABC):
|
||||||
def get(cls, pk: Any) -> 'HashModel':
|
def get(cls, pk: Any) -> 'HashModel':
|
||||||
document = cls.db().hgetall(cls.make_primary_key(pk))
|
document = cls.db().hgetall(cls.make_primary_key(pk))
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFoundError
|
raise cls.NotFoundError
|
||||||
return cls.parse_obj(document)
|
return cls.parse_obj(document)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -848,7 +892,7 @@ class HashModel(RedisModel, abc.ABC):
|
||||||
return f"{name} TAG"
|
return f"{name} TAG"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def schema(cls):
|
def redisearch_schema(cls):
|
||||||
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
|
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
|
||||||
schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
|
schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
|
||||||
schema_parts = [schema_prefix]
|
schema_parts = [schema_prefix]
|
||||||
|
|
|
@ -2,6 +2,7 @@ import abc
|
||||||
import decimal
|
import decimal
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import redis
|
import redis
|
||||||
|
@ -11,7 +12,7 @@ from redis_developer.orm import (
|
||||||
HashModel,
|
HashModel,
|
||||||
Field,
|
Field,
|
||||||
)
|
)
|
||||||
from redis_developer.orm.model import RedisModelError
|
from redis_developer.orm.model import RedisModelError, QueryNotSupportedError
|
||||||
|
|
||||||
r = redis.Redis()
|
r = redis.Redis()
|
||||||
today = datetime.date.today()
|
today = datetime.date.today()
|
||||||
|
@ -172,14 +173,17 @@ def test_paginate_query(members):
|
||||||
|
|
||||||
|
|
||||||
def test_access_result_by_index_cached(members):
|
def test_access_result_by_index_cached(members):
|
||||||
_, member2, _ = members
|
member1, member2, member3 = members
|
||||||
query = Member.find().sort_by('age')
|
query = Member.find().sort_by('age')
|
||||||
# Load the cache, throw away the result.
|
# Load the cache, throw away the result.
|
||||||
|
assert query._model_cache == []
|
||||||
query.execute()
|
query.execute()
|
||||||
|
assert query._model_cache == [member2, member1, member3]
|
||||||
|
|
||||||
# Access an item that should be in the cache.
|
# Access an item that should be in the cache.
|
||||||
# TODO: Assert that we didn't make a Redis request.
|
with mock.patch.object(query.model, 'db') as mock_db:
|
||||||
assert query[0] == member2
|
assert query[0] == member2
|
||||||
|
assert not mock_db.called
|
||||||
|
|
||||||
|
|
||||||
def test_access_result_by_index_not_cached(members):
|
def test_access_result_by_index_not_cached(members):
|
||||||
|
@ -284,11 +288,28 @@ def test_numeric_queries(members):
|
||||||
actual = Member.find(Member.age >= 100).all()
|
actual = Member.find(Member.age >= 100).all()
|
||||||
assert actual == [member3]
|
assert actual == [member3]
|
||||||
|
|
||||||
import ipdb; ipdb.set_trace()
|
|
||||||
actual = Member.find(~(Member.age == 100)).all()
|
actual = Member.find(~(Member.age == 100)).all()
|
||||||
assert sorted(actual) == [member1, member2]
|
assert sorted(actual) == [member1, member2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sorting(members):
|
||||||
|
member1, member2, member3 = members
|
||||||
|
|
||||||
|
actual = Member.find(Member.age > 34).sort_by('age').all()
|
||||||
|
assert sorted(actual) == [member3, member1]
|
||||||
|
|
||||||
|
actual = Member.find(Member.age > 34).sort_by('-age').all()
|
||||||
|
assert sorted(actual) == [member1, member3]
|
||||||
|
|
||||||
|
with pytest.raises(QueryNotSupportedError):
|
||||||
|
# This field does not exist.
|
||||||
|
Member.find().sort_by('not-a-real-field').all()
|
||||||
|
|
||||||
|
with pytest.raises(QueryNotSupportedError):
|
||||||
|
# This field is not sortable.
|
||||||
|
Member.find().sort_by('join_date').all()
|
||||||
|
|
||||||
|
|
||||||
def test_schema():
|
def test_schema():
|
||||||
class Address(BaseHashModel):
|
class Address(BaseHashModel):
|
||||||
a_string: str = Field(index=True)
|
a_string: str = Field(index=True)
|
||||||
|
@ -298,8 +319,7 @@ def test_schema():
|
||||||
another_integer: int
|
another_integer: int
|
||||||
another_float: float
|
another_float: float
|
||||||
|
|
||||||
# TODO: Fix
|
assert Address.redisearch_schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \
|
||||||
assert Address.schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \
|
|
||||||
"SCHEMA pk TAG a_string TAG a_full_text_string TAG " \
|
"SCHEMA pk TAG a_string TAG a_full_text_string TAG " \
|
||||||
"a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \
|
"a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \
|
||||||
"a_float NUMERIC"
|
"a_float NUMERIC"
|
||||||
|
|
Loading…
Reference in a new issue