Avoid using Pydantic for FindQuery

This commit is contained in:
Andrew Brookins 2021-09-29 20:23:39 -07:00
parent 01cab5352b
commit ef58e854c1
3 changed files with 102 additions and 38 deletions

View file

@ -73,7 +73,7 @@ class Migrator:
for name, cls in model_registry.items():
hash_key = schema_hash_key(cls.Meta.index_name)
try:
schema = cls.schema()
schema = cls.redisearch_schema()
except NotImplementedError:
log.info("Skipping migrations for %s", name)
continue

View file

@ -1,9 +1,10 @@
import abc
import dataclasses
import decimal
import inspect
import operator
import re
from copy import deepcopy
from copy import deepcopy, copy
from enum import Enum
from functools import reduce
from typing import (
@ -41,6 +42,25 @@ model_registry = {}
_T = TypeVar("_T")
def subclass_exception(name, bases, module, attached_to):
"""
Create exception subclass. Used by RedisModel below.
The exception is created in a way that allows it to be pickled, assuming
that the returned exception class will be added as an attribute to the
'attached_to' class.
"""
return type(name, bases, {
'__module__': module,
'__qualname__': '%s.%s' % (attached_to.__qualname__, name),
})
def _has_contribute_to_class(value):
# Only call contribute_to_class() if it's bound.
return not inspect.isclass(value) and hasattr(value, 'contribute_to_class')
class TokenEscaper:
"""
Escape punctuation within an input string.
@ -167,21 +187,29 @@ NUMERIC_TYPES = (float, int, decimal.Decimal)
DEFAULT_PAGE_SIZE = 10
class FindQuery(BaseModel):
expressions: Sequence[ExpressionOrNegated]
model: Type['RedisModel']
offset: int = 0
limit: int = DEFAULT_PAGE_SIZE
page_size: int = DEFAULT_PAGE_SIZE
sort_fields: Optional[List[str]] = Field(default_factory=list)
class FindQuery:
def __init__(self,
expressions: Sequence[ExpressionOrNegated],
model: Type['RedisModel'],
offset: int = 0,
limit: int = DEFAULT_PAGE_SIZE,
page_size: int = DEFAULT_PAGE_SIZE,
sort_fields: Optional[List[str]] = None):
self.expressions = expressions
self.model = model
self.offset = offset
self.limit = limit
self.page_size = page_size
_expression: Expression = PrivateAttr(default=None)
_query: str = PrivateAttr(default=None)
_pagination: List[str] = PrivateAttr(default_factory=list)
_model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list)
if sort_fields:
self.sort_fields = self.validate_sort_fields(sort_fields)
else:
self.sort_fields = []
class Config:
arbitrary_types_allowed = True
self._expression = None
self._query = None
self._pagination = []
self._model_cache = []
@property
def pagination(self):
@ -204,19 +232,17 @@ class FindQuery(BaseModel):
def query(self):
return self.resolve_redisearch_query(self.expression)
@validator("sort_fields")
def validate_sort_fields(cls, v, values):
model = values['model']
for sort_field in v:
def validate_sort_fields(self, sort_fields):
for sort_field in sort_fields:
field_name = sort_field.lstrip("-")
if field_name not in model.__fields__:
if field_name not in self.model.__fields__:
raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
f"does not exist on the model {model}")
field_proxy = getattr(model, field_name)
f"does not exist on the model {self.model}")
field_proxy = getattr(self.model, field_name)
if not getattr(field_proxy.field.field_info, 'sortable', False):
raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does "
raise QueryNotSupportedError(f"You tried sort by {field_name}, but {self.model} does "
"not define that field as sortable. See docs: XXX")
return v
return sort_fields
@staticmethod
def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
@ -429,7 +455,7 @@ class FindQuery(BaseModel):
return query.execute()
return self.execute()
def sort_by(self, *fields):
def sort_by(self, *fields: str):
if not fields:
return self
return FindQuery(expressions=self.expressions,
@ -612,12 +638,19 @@ class DefaultMeta:
class ModelMeta(ModelMetaclass):
def add_to_class(cls, name, value):
if _has_contribute_to_class(value):
value.contribute_to_class(cls, name)
else:
setattr(cls, name, value)
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
meta = attrs.pop('Meta', None)
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
meta = meta or getattr(new_class, 'Meta', None)
base_meta = getattr(new_class, '_meta', None)
parents = [b for b in bases if isinstance(b, ModelMeta)]
if meta and meta != DefaultMeta and meta != base_meta:
new_class.Meta = meta
@ -638,6 +671,17 @@ class ModelMeta(ModelMetaclass):
key = f"{new_class.__module__}.{new_class.__name__}"
model_registry[key] = new_class
if not hasattr(new_class, 'NotFoundError'):
new_class.add_to_class(
'NotFoundError',
subclass_exception(
'NotFoundError',
tuple(
x.NotFoundError for x in bases if hasattr(x, '_meta') and not issubclass(x, abc.ABC)
),
attrs['__module__'],
attached_to=new_class))
# Create proxies for each model field so that we can use the field
# in queries, like Model.get(Model.field_name == 1)
for field_name, field in new_class.__fields__.items():
@ -790,7 +834,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
raise NotImplementedError
@classmethod
def schema(cls):
def redisearch_schema(cls):
raise NotImplementedError
@ -818,7 +862,7 @@ class HashModel(RedisModel, abc.ABC):
def get(cls, pk: Any) -> 'HashModel':
document = cls.db().hgetall(cls.make_primary_key(pk))
if not document:
raise NotFoundError
raise cls.NotFoundError
return cls.parse_obj(document)
@classmethod
@ -848,7 +892,7 @@ class HashModel(RedisModel, abc.ABC):
return f"{name} TAG"
@classmethod
def schema(cls):
def redisearch_schema(cls):
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
schema_parts = [schema_prefix]

View file

@ -2,6 +2,7 @@ import abc
import decimal
import datetime
from typing import Optional
from unittest import mock
import pytest
import redis
@ -11,7 +12,7 @@ from redis_developer.orm import (
HashModel,
Field,
)
from redis_developer.orm.model import RedisModelError
from redis_developer.orm.model import RedisModelError, QueryNotSupportedError
r = redis.Redis()
today = datetime.date.today()
@ -172,14 +173,17 @@ def test_paginate_query(members):
def test_access_result_by_index_cached(members):
_, member2, _ = members
member1, member2, member3 = members
query = Member.find().sort_by('age')
# Load the cache, throw away the result.
assert query._model_cache == []
query.execute()
assert query._model_cache == [member2, member1, member3]
# Access an item that should be in the cache.
# TODO: Assert that we didn't make a Redis request.
assert query[0] == member2
with mock.patch.object(query.model, 'db') as mock_db:
assert query[0] == member2
assert not mock_db.called
def test_access_result_by_index_not_cached(members):
@ -284,11 +288,28 @@ def test_numeric_queries(members):
actual = Member.find(Member.age >= 100).all()
assert actual == [member3]
import ipdb; ipdb.set_trace()
actual = Member.find(~(Member.age == 100)).all()
assert sorted(actual) == [member1, member2]
def test_sorting(members):
member1, member2, member3 = members
actual = Member.find(Member.age > 34).sort_by('age').all()
assert sorted(actual) == [member3, member1]
actual = Member.find(Member.age > 34).sort_by('-age').all()
assert sorted(actual) == [member1, member3]
with pytest.raises(QueryNotSupportedError):
# This field does not exist.
Member.find().sort_by('not-a-real-field').all()
with pytest.raises(QueryNotSupportedError):
# This field is not sortable.
Member.find().sort_by('join_date').all()
def test_schema():
class Address(BaseHashModel):
a_string: str = Field(index=True)
@ -298,8 +319,7 @@ def test_schema():
another_integer: int
another_float: float
# TODO: Fix
assert Address.schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \
"SCHEMA pk TAG a_string TAG a_full_text_string TAG " \
"a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \
"a_float NUMERIC"
assert Address.redisearch_schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \
"SCHEMA pk TAG a_string TAG a_full_text_string TAG " \
"a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \
"a_float NUMERIC"