Avoid using Pydantic for FindQuery

This commit is contained in:
Andrew Brookins 2021-09-29 20:23:39 -07:00
parent 01cab5352b
commit ef58e854c1
3 changed files with 102 additions and 38 deletions

View file

@ -73,7 +73,7 @@ class Migrator:
for name, cls in model_registry.items(): for name, cls in model_registry.items():
hash_key = schema_hash_key(cls.Meta.index_name) hash_key = schema_hash_key(cls.Meta.index_name)
try: try:
schema = cls.schema() schema = cls.redisearch_schema()
except NotImplementedError: except NotImplementedError:
log.info("Skipping migrations for %s", name) log.info("Skipping migrations for %s", name)
continue continue

View file

@ -1,9 +1,10 @@
import abc import abc
import dataclasses import dataclasses
import decimal import decimal
import inspect
import operator import operator
import re import re
from copy import deepcopy from copy import deepcopy, copy
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from typing import ( from typing import (
@ -41,6 +42,25 @@ model_registry = {}
_T = TypeVar("_T") _T = TypeVar("_T")
def subclass_exception(name, bases, module, attached_to):
"""
Create exception subclass. Used by RedisModel below.
The exception is created in a way that allows it to be pickled, assuming
that the returned exception class will be added as an attribute to the
'attached_to' class.
"""
return type(name, bases, {
'__module__': module,
'__qualname__': '%s.%s' % (attached_to.__qualname__, name),
})
def _has_contribute_to_class(value):
# Only call contribute_to_class() if it's bound.
return not inspect.isclass(value) and hasattr(value, 'contribute_to_class')
class TokenEscaper: class TokenEscaper:
""" """
Escape punctuation within an input string. Escape punctuation within an input string.
@ -167,21 +187,29 @@ NUMERIC_TYPES = (float, int, decimal.Decimal)
DEFAULT_PAGE_SIZE = 10 DEFAULT_PAGE_SIZE = 10
class FindQuery(BaseModel): class FindQuery:
expressions: Sequence[ExpressionOrNegated] def __init__(self,
model: Type['RedisModel'] expressions: Sequence[ExpressionOrNegated],
offset: int = 0 model: Type['RedisModel'],
limit: int = DEFAULT_PAGE_SIZE offset: int = 0,
page_size: int = DEFAULT_PAGE_SIZE limit: int = DEFAULT_PAGE_SIZE,
sort_fields: Optional[List[str]] = Field(default_factory=list) page_size: int = DEFAULT_PAGE_SIZE,
sort_fields: Optional[List[str]] = None):
self.expressions = expressions
self.model = model
self.offset = offset
self.limit = limit
self.page_size = page_size
_expression: Expression = PrivateAttr(default=None) if sort_fields:
_query: str = PrivateAttr(default=None) self.sort_fields = self.validate_sort_fields(sort_fields)
_pagination: List[str] = PrivateAttr(default_factory=list) else:
_model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list) self.sort_fields = []
class Config: self._expression = None
arbitrary_types_allowed = True self._query = None
self._pagination = []
self._model_cache = []
@property @property
def pagination(self): def pagination(self):
@ -204,19 +232,17 @@ class FindQuery(BaseModel):
def query(self): def query(self):
return self.resolve_redisearch_query(self.expression) return self.resolve_redisearch_query(self.expression)
@validator("sort_fields") def validate_sort_fields(self, sort_fields):
def validate_sort_fields(cls, v, values): for sort_field in sort_fields:
model = values['model']
for sort_field in v:
field_name = sort_field.lstrip("-") field_name = sort_field.lstrip("-")
if field_name not in model.__fields__: if field_name not in self.model.__fields__:
raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field " raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
f"does not exist on the model {model}") f"does not exist on the model {self.model}")
field_proxy = getattr(model, field_name) field_proxy = getattr(self.model, field_name)
if not getattr(field_proxy.field.field_info, 'sortable', False): if not getattr(field_proxy.field.field_info, 'sortable', False):
raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does " raise QueryNotSupportedError(f"You tried sort by {field_name}, but {self.model} does "
"not define that field as sortable. See docs: XXX") "not define that field as sortable. See docs: XXX")
return v return sort_fields
@staticmethod @staticmethod
def resolve_field_type(field: ModelField) -> RediSearchFieldTypes: def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
@ -429,7 +455,7 @@ class FindQuery(BaseModel):
return query.execute() return query.execute()
return self.execute() return self.execute()
def sort_by(self, *fields): def sort_by(self, *fields: str):
if not fields: if not fields:
return self return self
return FindQuery(expressions=self.expressions, return FindQuery(expressions=self.expressions,
@ -612,12 +638,19 @@ class DefaultMeta:
class ModelMeta(ModelMetaclass): class ModelMeta(ModelMetaclass):
def add_to_class(cls, name, value):
if _has_contribute_to_class(value):
value.contribute_to_class(cls, name)
else:
setattr(cls, name, value)
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
meta = attrs.pop('Meta', None) meta = attrs.pop('Meta', None)
new_class = super().__new__(cls, name, bases, attrs, **kwargs) new_class = super().__new__(cls, name, bases, attrs, **kwargs)
meta = meta or getattr(new_class, 'Meta', None) meta = meta or getattr(new_class, 'Meta', None)
base_meta = getattr(new_class, '_meta', None) base_meta = getattr(new_class, '_meta', None)
parents = [b for b in bases if isinstance(b, ModelMeta)]
if meta and meta != DefaultMeta and meta != base_meta: if meta and meta != DefaultMeta and meta != base_meta:
new_class.Meta = meta new_class.Meta = meta
@ -638,6 +671,17 @@ class ModelMeta(ModelMetaclass):
key = f"{new_class.__module__}.{new_class.__name__}" key = f"{new_class.__module__}.{new_class.__name__}"
model_registry[key] = new_class model_registry[key] = new_class
if not hasattr(new_class, 'NotFoundError'):
new_class.add_to_class(
'NotFoundError',
subclass_exception(
'NotFoundError',
tuple(
x.NotFoundError for x in bases if hasattr(x, '_meta') and not issubclass(x, abc.ABC)
),
attrs['__module__'],
attached_to=new_class))
# Create proxies for each model field so that we can use the field # Create proxies for each model field so that we can use the field
# in queries, like Model.get(Model.field_name == 1) # in queries, like Model.get(Model.field_name == 1)
for field_name, field in new_class.__fields__.items(): for field_name, field in new_class.__fields__.items():
@ -790,7 +834,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def schema(cls): def redisearch_schema(cls):
raise NotImplementedError raise NotImplementedError
@ -818,7 +862,7 @@ class HashModel(RedisModel, abc.ABC):
def get(cls, pk: Any) -> 'HashModel': def get(cls, pk: Any) -> 'HashModel':
document = cls.db().hgetall(cls.make_primary_key(pk)) document = cls.db().hgetall(cls.make_primary_key(pk))
if not document: if not document:
raise NotFoundError raise cls.NotFoundError
return cls.parse_obj(document) return cls.parse_obj(document)
@classmethod @classmethod
@ -848,7 +892,7 @@ class HashModel(RedisModel, abc.ABC):
return f"{name} TAG" return f"{name} TAG"
@classmethod @classmethod
def schema(cls): def redisearch_schema(cls):
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk="")) hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA" schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
schema_parts = [schema_prefix] schema_parts = [schema_prefix]

View file

@ -2,6 +2,7 @@ import abc
import decimal import decimal
import datetime import datetime
from typing import Optional from typing import Optional
from unittest import mock
import pytest import pytest
import redis import redis
@ -11,7 +12,7 @@ from redis_developer.orm import (
HashModel, HashModel,
Field, Field,
) )
from redis_developer.orm.model import RedisModelError from redis_developer.orm.model import RedisModelError, QueryNotSupportedError
r = redis.Redis() r = redis.Redis()
today = datetime.date.today() today = datetime.date.today()
@ -172,14 +173,17 @@ def test_paginate_query(members):
def test_access_result_by_index_cached(members): def test_access_result_by_index_cached(members):
_, member2, _ = members member1, member2, member3 = members
query = Member.find().sort_by('age') query = Member.find().sort_by('age')
# Load the cache, throw away the result. # Load the cache, throw away the result.
assert query._model_cache == []
query.execute() query.execute()
assert query._model_cache == [member2, member1, member3]
# Access an item that should be in the cache. # Access an item that should be in the cache.
# TODO: Assert that we didn't make a Redis request. with mock.patch.object(query.model, 'db') as mock_db:
assert query[0] == member2 assert query[0] == member2
assert not mock_db.called
def test_access_result_by_index_not_cached(members): def test_access_result_by_index_not_cached(members):
@ -284,11 +288,28 @@ def test_numeric_queries(members):
actual = Member.find(Member.age >= 100).all() actual = Member.find(Member.age >= 100).all()
assert actual == [member3] assert actual == [member3]
import ipdb; ipdb.set_trace()
actual = Member.find(~(Member.age == 100)).all() actual = Member.find(~(Member.age == 100)).all()
assert sorted(actual) == [member1, member2] assert sorted(actual) == [member1, member2]
def test_sorting(members):
member1, member2, member3 = members
actual = Member.find(Member.age > 34).sort_by('age').all()
assert sorted(actual) == [member3, member1]
actual = Member.find(Member.age > 34).sort_by('-age').all()
assert sorted(actual) == [member1, member3]
with pytest.raises(QueryNotSupportedError):
# This field does not exist.
Member.find().sort_by('not-a-real-field').all()
with pytest.raises(QueryNotSupportedError):
# This field is not sortable.
Member.find().sort_by('join_date').all()
def test_schema(): def test_schema():
class Address(BaseHashModel): class Address(BaseHashModel):
a_string: str = Field(index=True) a_string: str = Field(index=True)
@ -298,8 +319,7 @@ def test_schema():
another_integer: int another_integer: int
another_float: float another_float: float
# TODO: Fix assert Address.redisearch_schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \
assert Address.schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \ "SCHEMA pk TAG a_string TAG a_full_text_string TAG " \
"SCHEMA pk TAG a_string TAG a_full_text_string TAG " \ "a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \
"a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \ "a_float NUMERIC"
"a_float NUMERIC"