From ef58e854c198e123de3233aa6c65181c7a3bb763 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 29 Sep 2021 20:23:39 -0700 Subject: [PATCH] Avoid using Pydantic for FindQuery --- redis_developer/orm/migrations/migrator.py | 2 +- redis_developer/orm/model.py | 98 ++++++++++++++++------ tests/test_hash_model.py | 40 ++++++--- 3 files changed, 102 insertions(+), 38 deletions(-) diff --git a/redis_developer/orm/migrations/migrator.py b/redis_developer/orm/migrations/migrator.py index 440d58a..f083218 100644 --- a/redis_developer/orm/migrations/migrator.py +++ b/redis_developer/orm/migrations/migrator.py @@ -73,7 +73,7 @@ class Migrator: for name, cls in model_registry.items(): hash_key = schema_hash_key(cls.Meta.index_name) try: - schema = cls.schema() + schema = cls.redisearch_schema() except NotImplementedError: log.info("Skipping migrations for %s", name) continue diff --git a/redis_developer/orm/model.py b/redis_developer/orm/model.py index 52b84db..29a2d5e 100644 --- a/redis_developer/orm/model.py +++ b/redis_developer/orm/model.py @@ -1,9 +1,10 @@ import abc import dataclasses import decimal +import inspect import operator import re -from copy import deepcopy +from copy import deepcopy, copy from enum import Enum from functools import reduce from typing import ( @@ -41,6 +42,25 @@ model_registry = {} _T = TypeVar("_T") +def subclass_exception(name, bases, module, attached_to): + """ + Create exception subclass. Used by RedisModel below. + + The exception is created in a way that allows it to be pickled, assuming + that the returned exception class will be added as an attribute to the + 'attached_to' class. + """ + return type(name, bases, { + '__module__': module, + '__qualname__': '%s.%s' % (attached_to.__qualname__, name), + }) + + +def _has_contribute_to_class(value): + # Only call contribute_to_class() if it's bound. + return not inspect.isclass(value) and hasattr(value, 'contribute_to_class') + + class TokenEscaper: """ Escape punctuation within an input string. @@ -167,21 +187,29 @@ NUMERIC_TYPES = (float, int, decimal.Decimal) DEFAULT_PAGE_SIZE = 10 -class FindQuery(BaseModel): - expressions: Sequence[ExpressionOrNegated] - model: Type['RedisModel'] - offset: int = 0 - limit: int = DEFAULT_PAGE_SIZE - page_size: int = DEFAULT_PAGE_SIZE - sort_fields: Optional[List[str]] = Field(default_factory=list) +class FindQuery: + def __init__(self, + expressions: Sequence[ExpressionOrNegated], + model: Type['RedisModel'], + offset: int = 0, + limit: int = DEFAULT_PAGE_SIZE, + page_size: int = DEFAULT_PAGE_SIZE, + sort_fields: Optional[List[str]] = None): + self.expressions = expressions + self.model = model + self.offset = offset + self.limit = limit + self.page_size = page_size - _expression: Expression = PrivateAttr(default=None) - _query: str = PrivateAttr(default=None) - _pagination: List[str] = PrivateAttr(default_factory=list) - _model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list) + if sort_fields: + self.sort_fields = self.validate_sort_fields(sort_fields) + else: + self.sort_fields = [] - class Config: - arbitrary_types_allowed = True + self._expression = None + self._query = None + self._pagination = [] + self._model_cache = [] @property def pagination(self): @@ -204,19 +232,17 @@ class FindQuery(BaseModel): def query(self): return self.resolve_redisearch_query(self.expression) - @validator("sort_fields") - def validate_sort_fields(cls, v, values): - model = values['model'] - for sort_field in v: + def validate_sort_fields(self, sort_fields): + for sort_field in sort_fields: field_name = sort_field.lstrip("-") - if field_name not in model.__fields__: + if field_name not in self.model.__fields__: raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field " - f"does not exist on the model {model}") - field_proxy = getattr(model, field_name) + f"does not exist on the model {self.model}") + field_proxy = getattr(self.model, field_name) if not getattr(field_proxy.field.field_info, 'sortable', False): - raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does " + raise QueryNotSupportedError(f"You tried sort by {field_name}, but {self.model} does " "not define that field as sortable. See docs: XXX") - return v + return sort_fields @staticmethod def resolve_field_type(field: ModelField) -> RediSearchFieldTypes: @@ -429,7 +455,7 @@ class FindQuery(BaseModel): return query.execute() return self.execute() - def sort_by(self, *fields): + def sort_by(self, *fields: str): if not fields: return self return FindQuery(expressions=self.expressions, @@ -612,12 +638,19 @@ class DefaultMeta: class ModelMeta(ModelMetaclass): + def add_to_class(cls, name, value): + if _has_contribute_to_class(value): + value.contribute_to_class(cls, name) + else: + setattr(cls, name, value) + def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 meta = attrs.pop('Meta', None) new_class = super().__new__(cls, name, bases, attrs, **kwargs) meta = meta or getattr(new_class, 'Meta', None) base_meta = getattr(new_class, '_meta', None) + parents = [b for b in bases if isinstance(b, ModelMeta)] if meta and meta != DefaultMeta and meta != base_meta: new_class.Meta = meta @@ -638,6 +671,17 @@ class ModelMeta(ModelMetaclass): key = f"{new_class.__module__}.{new_class.__name__}" model_registry[key] = new_class + if not hasattr(new_class, 'NotFoundError'): + new_class.add_to_class( + 'NotFoundError', + subclass_exception( + 'NotFoundError', + tuple( + x.NotFoundError for x in bases if hasattr(x, '_meta') and not issubclass(x, abc.ABC) + ), + attrs['__module__'], + attached_to=new_class)) + # Create proxies for each model field so that we can use the field # in queries, like Model.get(Model.field_name == 1) for field_name, field in new_class.__fields__.items(): @@ -790,7 +834,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): raise NotImplementedError @classmethod - def schema(cls): + def redisearch_schema(cls): raise NotImplementedError @@ -818,7 +862,7 @@ class HashModel(RedisModel, abc.ABC): def get(cls, pk: Any) -> 'HashModel': document = cls.db().hgetall(cls.make_primary_key(pk)) if not document: - raise NotFoundError + raise cls.NotFoundError return cls.parse_obj(document) @classmethod @@ -848,7 +892,7 @@ class HashModel(RedisModel, abc.ABC): return f"{name} TAG" @classmethod - def schema(cls): + def redisearch_schema(cls): hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk="")) schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA" schema_parts = [schema_prefix] diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 1847160..02371cb 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -2,6 +2,7 @@ import abc import decimal import datetime from typing import Optional +from unittest import mock import pytest import redis @@ -11,7 +12,7 @@ from redis_developer.orm import ( HashModel, Field, ) -from redis_developer.orm.model import RedisModelError +from redis_developer.orm.model import RedisModelError, QueryNotSupportedError r = redis.Redis() today = datetime.date.today() @@ -172,14 +173,17 @@ def test_paginate_query(members): def test_access_result_by_index_cached(members): - _, member2, _ = members + member1, member2, member3 = members query = Member.find().sort_by('age') # Load the cache, throw away the result. + assert query._model_cache == [] query.execute() + assert query._model_cache == [member2, member1, member3] # Access an item that should be in the cache. - # TODO: Assert that we didn't make a Redis request. - assert query[0] == member2 + with mock.patch.object(query.model, 'db') as mock_db: + assert query[0] == member2 + assert not mock_db.called def test_access_result_by_index_not_cached(members): @@ -284,11 +288,28 @@ def test_numeric_queries(members): actual = Member.find(Member.age >= 100).all() assert actual == [member3] - import ipdb; ipdb.set_trace() actual = Member.find(~(Member.age == 100)).all() assert sorted(actual) == [member1, member2] +def test_sorting(members): + member1, member2, member3 = members + + actual = Member.find(Member.age > 34).sort_by('age').all() + assert sorted(actual) == [member3, member1] + + actual = Member.find(Member.age > 34).sort_by('-age').all() + assert sorted(actual) == [member1, member3] + + with pytest.raises(QueryNotSupportedError): + # This field does not exist. + Member.find().sort_by('not-a-real-field').all() + + with pytest.raises(QueryNotSupportedError): + # This field is not sortable. + Member.find().sort_by('join_date').all() + + def test_schema(): class Address(BaseHashModel): a_string: str = Field(index=True) @@ -298,8 +319,7 @@ def test_schema(): another_integer: int another_float: float - # TODO: Fix - assert Address.schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \ - "SCHEMA pk TAG a_string TAG a_full_text_string TAG " \ - "a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \ - "a_float NUMERIC" + assert Address.redisearch_schema() == "ON HASH PREFIX 1 redis-developer:tests.test_hash_model.Address: " \ + "SCHEMA pk TAG a_string TAG a_full_text_string TAG " \ + "a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \ + "a_float NUMERIC"