From 01cab5352b2117cf820ade8325ea4b37620d8789 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Sat, 25 Sep 2021 20:38:02 -0700 Subject: [PATCH] WIP on pagination, sorting all()/first() methods --- redis_developer/orm/model.py | 608 +++++++++++++++++--------- redis_developer/orm/query_iterator.py | 55 --- tests/conftest.py | 2 +- tests/test_hash_model.py | 102 +++-- 4 files changed, 473 insertions(+), 294 deletions(-) delete mode 100644 redis_developer/orm/query_iterator.py diff --git a/redis_developer/orm/model.py b/redis_developer/orm/model.py index 4e92eea..52b84db 100644 --- a/redis_developer/orm/model.py +++ b/redis_developer/orm/model.py @@ -2,8 +2,8 @@ import abc import dataclasses import decimal import operator -from copy import copy, deepcopy -from dataclasses import dataclass +import re +from copy import deepcopy from enum import Enum from functools import reduce from typing import ( @@ -20,13 +20,15 @@ from typing import ( Sequence, no_type_check, Protocol, - List, Type + List, + Type, + Pattern ) import uuid import redis from pydantic import BaseModel, validator -from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic.fields import FieldInfo as PydanticFieldInfo, PrivateAttr, Field from pydantic.fields import ModelField, Undefined, UndefinedType from pydantic.main import ModelMetaclass from pydantic.typing import NoArgAnyCallable @@ -34,18 +36,47 @@ from pydantic.utils import Representation from .encoders import jsonable_encoder - model_registry = {} _T = TypeVar("_T") +class TokenEscaper: + """ + Escape punctuation within an input string. + """ + + # Characters that RediSearch requires us to escape during queries. + # Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization + DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]" + + def __init__(self, escape_chars_re: Optional[Pattern] = None): + if escape_chars_re: + self.escaped_chars_re = escape_chars_re + else: + self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS) + + def escape(self, string): + def escape_symbol(match): + value = match.group(0) + return f"\\{value}" + + return self.escaped_chars_re.sub(escape_symbol, string) + + +escaper = TokenEscaper() + + class RedisModelError(Exception): pass class NotFoundError(Exception): - pass + """ + A query found no results. + + TODO: embed in Model class? + """ class Operators(Enum): @@ -61,9 +92,10 @@ class Operators(Enum): IN = 10 NOT_IN = 11 LIKE = 12 + ALL = 13 -@dataclass +@dataclasses.dataclass class NegatedExpression: expression: 'Expression' @@ -77,7 +109,7 @@ class NegatedExpression: return Expression(left=self, op=Operators.OR, right=other) -@dataclass +@dataclasses.dataclass class Expression: op: Operators left: Any @@ -96,186 +128,6 @@ class Expression: ExpressionOrNegated = Union[Expression, NegatedExpression] -class QueryNotSupportedError(Exception): - """The attempted query is not supported.""" - - -class RediSearchFieldTypes(Enum): - TEXT = 'TEXT' - TAG = 'TAG' - NUMERIC = 'NUMERIC' - GEO = 'GEO' - - -# TODO: How to handle Geo fields? -NUMERIC_TYPES = (float, int, decimal.Decimal) - - -@dataclass -class FindQuery: - expressions: Sequence[Expression] - expression: Expression = dataclasses.field(init=False) - query: str = dataclasses.field(init=False) - pagination: List[str] = dataclasses.field(init=False) - model: Type['RedisModel'] - limit: Optional[int] = None - offset: Optional[int] = None - - def __post_init__(self): - self.expression = reduce(operator.and_, self.expressions) - self.query = self.resolve_redisearch_query(self.expression) - self.pagination = self.resolve_redisearch_pagination() - - def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes: - if getattr(field.field_info, 'primary_key', None) is True: - return RediSearchFieldTypes.TAG - elif getattr(field.field_info, 'full_text_search', None) is True: - return RediSearchFieldTypes.TEXT - - field_type = field.outer_type_ - - # TODO: GEO - if any(issubclass(field_type, t) for t in NUMERIC_TYPES): - return RediSearchFieldTypes.NUMERIC - else: - # TAG fields are the default field type. - return RediSearchFieldTypes.TAG - - @staticmethod - def expand_tag_value(value): - err = RedisModelError(f"Using the IN operator requires passing an iterable of " - "possible values. You passed: {value}") - if isinstance(str, value): - raise err - try: - expanded_value = "|".join(value) - except TypeError: - raise err - return expanded_value - - def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes, - op: Operators, value: Any) -> str: - result = "" - if field_type is RediSearchFieldTypes.TEXT: - result = f"@{field_name}:" - if op is Operators.EQ: - result += f'"{value}"' - elif op is Operators.NE: - result = f'-({result}"{value}")' - elif op is Operators.LIKE: - result += value - else: - # TODO: Handling TAG, TEXT switch-offs, etc. - raise QueryNotSupportedError("Only equals (=) and not-equals (!=) comparisons are " - "currently supported for TEXT fields. See docs: TODO") - elif field_type is RediSearchFieldTypes.NUMERIC: - if op is Operators.EQ: - result += f"@{field_name}:[{value} {value}]" - elif op is Operators.NE: - # TODO: Is this enough or do we also need a clause for all values - # ([-inf +inf]) from which we then subtract the undesirable value? - result += f"-(@{field_name}:[{value} {value}])" - elif op is Operators.GT: - result += f"@{field_name}:[({value} +inf]" - elif op is Operators.LT: - result += f"@{field_name}:[-inf ({value}]" - elif op is Operators.GE: - result += f"@{field_name}:[{value} +inf]" - elif op is Operators.LE: - result += f"@{field_name}:[-inf {value}]" - elif field_type is RediSearchFieldTypes.TAG: - if op is Operators.EQ: - result += f"@{field_name}:{{{value}}}" - elif op is Operators.NE: - result += f"-(@{field_name}:{{{value}}})" - elif op is Operators.IN: - expanded_value = self.expand_tag_value(value) - result += f"(@{field_name}:{{{expanded_value}}})" - elif op is Operators.NOT_IN: - expanded_value = self.expand_tag_value(value) - result += f"-(@{field_name}:{{{expanded_value}}})" - - return result - - def resolve_redisearch_pagination(self): - """Resolve pagination options for a query.""" - if not self.limit and not self.offset: - return [] - offset = self.offset or 0 - limit = self.limit or 10 - return ["LIMIT", offset, limit] - - def resolve_redisearch_query(self, expression: ExpressionOrNegated): - """Resolve an expression to a string RediSearch query.""" - field_type = None - field_name = None - encompassing_expression_is_negated = False - result = "" - - if isinstance(expression, NegatedExpression): - encompassing_expression_is_negated = True - expression = expression.expression - - if isinstance(expression.left, Expression) or \ - isinstance(expression.left, NegatedExpression): - result += f"({self.resolve_redisearch_query(expression.left)})" - elif isinstance(expression.left, ModelField): - field_type = self.resolve_field_type(expression.left) - field_name = expression.left.name - else: - import ipdb; ipdb.set_trace() - raise QueryNotSupportedError(f"A query expression should start with either a field " - f"or an expression enclosed in parenthesis. See docs: " - f"TODO") - - right = expression.right - right_is_negated = isinstance(right, NegatedExpression) - - if isinstance(right, Expression) or right_is_negated: - if expression.op == Operators.AND: - result += " " - elif expression.op == Operators.OR: - result += "| " - else: - raise QueryNotSupportedError("You can only combine two query expressions with" - "AND (&) or OR (|). See docs: TODO") - - if right_is_negated: - result += "-" - # We're handling the RediSearch operator in this call ("-"), so resolve the - # inner expression instead of the NegatedExpression. - right = right.expression - - result += f"({self.resolve_redisearch_query(right)})" - else: - if isinstance(right, ModelField): - raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO") - else: - result += self.resolve_value(field_name, field_type, expression.op, right) - - if encompassing_expression_is_negated: - result = f"-({result})" - - return result - - def find(self): - args = ["ft.search", self.model.Meta.index_name, self.query] - # TODO: Do we need self.pagination if we're just appending to query anyway? - if self.pagination: - args.extend(self.pagination) - return self.model.db().execute_command(*args) - - -class PrimaryKeyCreator(Protocol): - def create_pk(self, *args, **kwargs) -> str: - """Create a new primary key""" - - -class Uuid4PrimaryKey: - def create_pk(self) -> str: - return str(uuid.uuid4()) - - class ExpressionProxy: def __init__(self, field: ModelField): self.field = field @@ -299,6 +151,352 @@ class ExpressionProxy: return Expression(left=self.field, op=Operators.GE, right=other) +class QueryNotSupportedError(Exception): + """The attempted query is not supported.""" + + +class RediSearchFieldTypes(Enum): + TEXT = 'TEXT' + TAG = 'TAG' + NUMERIC = 'NUMERIC' + GEO = 'GEO' + + +# TODO: How to handle Geo fields? +NUMERIC_TYPES = (float, int, decimal.Decimal) +DEFAULT_PAGE_SIZE = 10 + + +class FindQuery(BaseModel): + expressions: Sequence[ExpressionOrNegated] + model: Type['RedisModel'] + offset: int = 0 + limit: int = DEFAULT_PAGE_SIZE + page_size: int = DEFAULT_PAGE_SIZE + sort_fields: Optional[List[str]] = Field(default_factory=list) + + _expression: Expression = PrivateAttr(default=None) + _query: str = PrivateAttr(default=None) + _pagination: List[str] = PrivateAttr(default_factory=list) + _model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list) + + class Config: + arbitrary_types_allowed = True + + @property + def pagination(self): + if self._pagination: + return self._pagination + self._pagination = self.resolve_redisearch_pagination() + return self._pagination + + @property + def expression(self): + if self._expression: + return self._expression + if self.expressions: + self._expression = reduce(operator.and_, self.expressions) + else: + self._expression = Expression(left=None, right=None, op=Operators.ALL) + return self._expression + + @property + def query(self): + return self.resolve_redisearch_query(self.expression) + + @validator("sort_fields") + def validate_sort_fields(cls, v, values): + model = values['model'] + for sort_field in v: + field_name = sort_field.lstrip("-") + if field_name not in model.__fields__: + raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field " + f"does not exist on the model {model}") + field_proxy = getattr(model, field_name) + if not getattr(field_proxy.field.field_info, 'sortable', False): + raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does " + "not define that field as sortable. See docs: XXX") + return v + + @staticmethod + def resolve_field_type(field: ModelField) -> RediSearchFieldTypes: + if getattr(field.field_info, 'primary_key', None) is True: + return RediSearchFieldTypes.TAG + elif getattr(field.field_info, 'full_text_search', None) is True: + return RediSearchFieldTypes.TEXT + + field_type = field.outer_type_ + + # TODO: GEO + if any(issubclass(field_type, t) for t in NUMERIC_TYPES): + return RediSearchFieldTypes.NUMERIC + else: + # TAG fields are the default field type. + # TODO: A ListField or ArrayField that supports multiple values + # and contains logic. + return RediSearchFieldTypes.TAG + + @staticmethod + def expand_tag_value(value): + err = RedisModelError(f"Using the IN operator requires passing a sequence of " + "possible values. You passed: {value}") + if isinstance(str, value): + raise err + try: + expanded_value = "|".join([escaper.escape(v) for v in value]) + except TypeError: + raise err + return expanded_value + + @classmethod + def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes, + op: Operators, value: Any) -> str: + result = "" + if field_type is RediSearchFieldTypes.TEXT: + result = f"@{field_name}:" + if op is Operators.EQ: + result += f'"{value}"' + elif op is Operators.NE: + result = f'-({result}"{value}")' + elif op is Operators.LIKE: + result += value + else: + raise QueryNotSupportedError("Only equals (=), not-equals (!=), and like() " + "comparisons are supported for TEXT fields. See " + "docs: TODO.") + elif field_type is RediSearchFieldTypes.NUMERIC: + if op is Operators.EQ: + result += f"@{field_name}:[{value} {value}]" + elif op is Operators.NE: + # TODO: Is this enough or do we also need a clause for all values + # ([-inf +inf]) from which we then subtract the undesirable value? + result += f"-(@{field_name}:[{value} {value}])" + elif op is Operators.GT: + result += f"@{field_name}:[({value} +inf]" + elif op is Operators.LT: + result += f"@{field_name}:[-inf ({value}]" + elif op is Operators.GE: + result += f"@{field_name}:[{value} +inf]" + elif op is Operators.LE: + result += f"@{field_name}:[-inf {value}]" + elif field_type is RediSearchFieldTypes.TAG: + if op is Operators.EQ: + value = escaper.escape(value) + result += f"@{field_name}:{{{value}}}" + elif op is Operators.NE: + value = escaper.escape(value) + result += f"-(@{field_name}:{{{value}}})" + elif op is Operators.IN: + expanded_value = cls.expand_tag_value(value) + result += f"(@{field_name}:{{{expanded_value}}})" + elif op is Operators.NOT_IN: + expanded_value = cls.expand_tag_value(value) + result += f"-(@{field_name}:{{{expanded_value}}})" + + return result + + def resolve_redisearch_pagination(self): + """Resolve pagination options for a query.""" + return ["LIMIT", self.offset, self.limit] + + def resolve_redisearch_sort_fields(self): + """Resolve sort options for a query.""" + if not self.sort_fields: + return + fields = [] + for f in self.sort_fields: + direction = "desc" if f.startswith('-') else 'asc' + fields.extend([f.lstrip('-'), direction]) + if self.sort_fields: + return ["SORTBY", *fields] + + @classmethod + def resolve_redisearch_query(cls, expression: ExpressionOrNegated): + """Resolve an expression to a string RediSearch query.""" + field_type = None + field_name = None + encompassing_expression_is_negated = False + result = "" + + if isinstance(expression, NegatedExpression): + encompassing_expression_is_negated = True + expression = expression.expression + + if expression.op is Operators.ALL: + if encompassing_expression_is_negated: + # TODO: Is there a use case for this, perhaps for dynamic + # scoring purposes? + raise QueryNotSupportedError("You cannot negate a query for all results.") + return "*" + + if isinstance(expression.left, Expression) or \ + isinstance(expression.left, NegatedExpression): + result += f"({cls.resolve_redisearch_query(expression.left)})" + elif isinstance(expression.left, ModelField): + field_type = cls.resolve_field_type(expression.left) + field_name = expression.left.name + else: + import ipdb; ipdb.set_trace() + raise QueryNotSupportedError(f"A query expression should start with either a field " + f"or an expression enclosed in parenthesis. See docs: " + f"TODO") + + right = expression.right + + if isinstance(right, Expression) or isinstance(right, NegatedExpression): + if expression.op == Operators.AND: + result += " " + elif expression.op == Operators.OR: + result += "| " + else: + raise QueryNotSupportedError("You can only combine two query expressions with" + "AND (&) or OR (|). See docs: TODO") + + if isinstance(right, NegatedExpression): + result += "-" + # We're handling the RediSearch operator in this call ("-"), so resolve the + # inner expression instead of the NegatedExpression. + right = right.expression + + result += f"({cls.resolve_redisearch_query(right)})" + else: + if isinstance(right, ModelField): + raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO") + else: + # TODO: Optionals causing IDE errors here + result += cls.resolve_value(field_name, field_type, expression.op, right) + + if encompassing_expression_is_negated: + result = f"-({result})" + + return result + + def execute(self, exhaust_results=True): + args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination] + if self.sort_fields: + args += self.resolve_redisearch_sort_fields() + + # Reset the cache if we're executing from offset 0. + if self.offset == 0: + self._model_cache.clear() + + # If the offset is greater than 0, we're paginating through a result set, + # so append the new results to results already in the cache. + raw_result = self.model.db().execute_command(*args) + count = raw_result[0] + results = self.model.from_redis(raw_result) + self._model_cache += results + + if not exhaust_results: + return self._model_cache + + # The query returned all results, so we have no more work to do. + if count <= len(results): + return self._model_cache + + # Transparently (to the user) make subsequent requests to paginate + # through the results and finally return them all. + query = self + while True: + # Make a query for each pass of the loop, with a new offset equal to the + # current offset plus `page_size`, until we stop getting results back. + query = FindQuery(expressions=query.expressions, + model=query.model, + offset=query.offset + query.page_size, + page_size=query.page_size, + limit=query.limit) + _results = query.execute(exhaust_results=False) + if not _results: + break + self._model_cache += _results + return self._model_cache + + def first(self): + query = FindQuery(expressions=self.expressions, model=self.model, + offset=0, limit=1, sort_fields=self.sort_fields) + return query.execute()[0] + + def all(self, batch_size=10): + if batch_size != self.page_size: + # TODO: There's probably a copy-with-change mechanism in Pydantic, + # or can we use one from dataclasses? + query = FindQuery(expressions=self.expressions, + model=self.model, + offset=self.offset, + page_size=batch_size, + limit=batch_size, + sort_fields=self.sort_fields) + return query.execute() + return self.execute() + + def sort_by(self, *fields): + if not fields: + return self + return FindQuery(expressions=self.expressions, + model=self.model, + offset=self.offset, + page_size=self.page_size, + limit=self.limit, + sort_fields=list(fields)) + + def update(self, **kwargs): + """Update all matching records in this query.""" + # TODO + + def delete(cls, **field_values): + """Delete all matching records in this query.""" + for field_name, value in field_values: + valid_attr = hasattr(cls.model, field_name) + if not valid_attr: + raise RedisModelError(f"Can't update field {field_name} because " + f"the field does not exist on the model {cls}") + + return cls + + def __iter__(self): + if self._model_cache: + for m in self._model_cache: + yield m + else: + for m in self.execute(): + yield m + + def __getitem__(self, item: int): + """ + Given this code: + Model.find()[1000] + + We should return only the 1000th result. + + 1. If the result is loaded in the query cache for this query, + we can return it directly from the cache. + + 2. If the query cache does not have enough elements to return + that result, then we should clone the current query and + give it a new offset and limit: offset=n, limit=1. + """ + if self._model_cache and len(self._model_cache) >= item: + return self._model_cache[item] + + query = FindQuery(expressions=self.expressions, + model=self.model, + offset=item, + sort_fields=self.sort_fields, + limit=1) + + return query.execute()[0] + + +class PrimaryKeyCreator(Protocol): + def create_pk(self, *args, **kwargs) -> str: + """Create a new primary key""" + + +class Uuid4PrimaryKey: + def create_pk(self, *args, **kwargs) -> str: + return str(uuid.uuid4()) + + def __dataclass_transform__( *, eq_default: bool = True, @@ -395,21 +593,22 @@ def Field( return field_info -@dataclass +@dataclasses.dataclass class PrimaryKey: name: str field: ModelField class DefaultMeta: + # TODO: Should this really be optional here? global_key_prefix: Optional[str] = None model_key_prefix: Optional[str] = None primary_key_pattern: Optional[str] = None database: Optional[redis.Redis] = None primary_key: Optional[PrimaryKey] = None - primary_key_creator_cls: Type[PrimaryKeyCreator] = None - index_name: str = None - abstract: bool = False + primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None + index_name: Optional[str] = None + abstract: Optional[bool] = False class ModelMeta(ModelMetaclass): @@ -473,6 +672,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): pk: Optional[str] = Field(default=None, primary_key=True) Meta = DefaultMeta + # TODO: Missing _meta here is causing IDE warnings. class Config: orm_mode = True @@ -484,6 +684,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): __pydantic_self__.validate_primary_key() def __lt__(self, other): + """Default sort: compare all shared model fields.""" my_keys = set(self.__fields__.keys()) other_keys = set(other.__fields__.keys()) shared_keys = list(my_keys & other_keys) @@ -528,8 +729,13 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): def db(cls): return cls._meta.database + @classmethod + def find(cls, *expressions: Expression): + return FindQuery(expressions=expressions, model=cls) + @classmethod def from_redis(cls, res: Any): + # TODO: Parsing logic borrowed from redisearch-py. Evaluate. import six from six.moves import xrange, zip as izip @@ -537,20 +743,20 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): if isinstance(s, six.string_types): return s elif isinstance(s, six.binary_type): - return s.decode('utf-8','ignore') + return s.decode('utf-8', 'ignore') else: return s # Not a string we care about docs = [] step = 2 # Because the result has content - offset = 1 + offset = 1 # The first item is the count of total matches. for i in xrange(1, len(res), step): fields_offset = offset fields = dict( dict(izip(map(to_string, res[i + fields_offset][::2]), - map(to_string, res[i + fields_offset][1::2]))) + map(to_string, res[i + fields_offset][1::2]))) ) try: @@ -562,17 +768,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): docs.append(doc) return docs - @classmethod - def find(cls, *expressions: Expression): - query = FindQuery(expressions=expressions, model=cls) - raw_result = query.find() - return cls.from_redis(raw_result) - - @classmethod - def find_one(cls, *expressions: Expression): - query = FindQuery(expressions=expressions, model=cls, limit=1, offset=0) - raw_result = query.find() - return cls.from_redis(raw_result)[0] @classmethod def add(cls, models: Sequence['RedisModel']) -> Sequence['RedisModel']: @@ -580,6 +775,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): @classmethod def update(cls, **field_values): + """Update this model instance.""" return cls @classmethod diff --git a/redis_developer/orm/query_iterator.py b/redis_developer/orm/query_iterator.py deleted file mode 100644 index 698b743..0000000 --- a/redis_developer/orm/query_iterator.py +++ /dev/null @@ -1,55 +0,0 @@ -from redis_developer.orm.model import Expression - - -class QueryIterator: - """ - A lazy iterator that yields results from a RediSearch query. - - Examples: - - results = Model.filter(email == "a@example.com") - - # Consume all results. - for r in results: - print(r) - - # Consume an item at an index. - print(results[100]) - - # Consume a slice. - print(results[0:100]) - - # Alternative notation to consume all items. - print(results[0:-1]) - - # Specify the batch size: - results = Model.filter(email == "a@example.com", batch_size=1000) - ... - """ - def __init__(self, client, query, batch_size=100): - self.client = client - self.query = query - self.batch_size = batch_size - - def __iter__(self): - pass - - def __getattr__(self, item): - """Support getting a single value or a slice.""" - - # TODO: Query mixin? - - def filter(self, *expressions: Expression): - pass - - def exclude(self, *expressions: Expression): - pass - - def and_(self, *expressions: Expression): - pass - - def or_(self, *expressions: Expression): - pass - - def not_(self, *expressions: Expression): - pass diff --git a/tests/conftest.py b/tests/conftest.py index 237f174..3224ba7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ def redis(): def key_prefix(): # TODO yield "redis-developer" - + def _delete_test_keys(prefix: str, conn: Redis): for key in conn.scan_iter(f"{prefix}:*"): diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index c1027f2..1847160 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -151,51 +151,73 @@ def test_saves_many(): Member.add(members) -@pytest.mark.skip("No implemented yet") -def test_updates_a_model(): - member = Member( - first_name="Andrew", - last_name="Brookins", - email="a@example.com", - join_date=today - ) - - # Update a model instance in Redis - member.first_name = "Andrew" - member.last_name = "Brookins" - member.save() +@pytest.mark.skip("Not ready yet") +def test_updates_a_model(members): + member1, member2, member3 = members # Or, with an implicit save: - member.update(last_name="Smith") + member1.update(last_name="Smith") + assert Member.find(Member.pk == member1.pk).first() == member1 # Or, affecting multiple model instances with an implicit save: Member.find(Member.last_name == "Brookins").update(last_name="Smith") + results = Member.find(Member.last_name == "Smith") + assert sorted(results) == members + + +def test_paginate_query(members): + member1, member2, member3 = members + actual = Member.find().all(batch_size=1) + assert sorted(actual) == [member1, member2, member3] + + +def test_access_result_by_index_cached(members): + _, member2, _ = members + query = Member.find().sort_by('age') + # Load the cache, throw away the result. + query.execute() + + # Access an item that should be in the cache. + # TODO: Assert that we didn't make a Redis request. + assert query[0] == member2 + + +def test_access_result_by_index_not_cached(members): + member1, member2, member3 = members + query = Member.find().sort_by('age') + + # Assert that we don't have any models in the cache yet -- we + # haven't made any requests of Redis. + assert query._model_cache == [] + assert query[0] == member2 + assert query[1] == member1 + assert query[2] == member3 def test_exact_match_queries(members): member1, member2, member3 = members - actual = Member.find(Member.last_name == "Brookins") + actual = Member.find(Member.last_name == "Brookins").all() assert sorted(actual) == [member1, member2] actual = Member.find( - (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")) + (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all() assert actual == [member2] - actual = Member.find(~(Member.last_name == "Brookins")) + actual = Member.find(~(Member.last_name == "Brookins")).all() assert actual == [member3] - actual = Member.find(Member.last_name != "Brookins") + actual = Member.find(Member.last_name != "Brookins").all() assert actual == [member3] actual = Member.find( (Member.last_name == "Brookins") & (Member.first_name == "Andrew") | (Member.first_name == "Kim") - ) + ).all() assert actual == [member2, member1] - actual = Member.find_one(Member.last_name == "Brookins") - assert actual == member2 + actual = Member.find(Member.first_name == "Kim", Member.last_name == "Brookins").all() + assert actual == [member2] def test_recursive_query_resolution(members): @@ -203,7 +225,7 @@ def test_recursive_query_resolution(members): actual = Member.find((Member.last_name == "Brookins") | ( Member.age == 100 - ) & (Member.last_name == "Smith")) + ) & (Member.last_name == "Smith")).all() assert sorted(actual) == [member1, member2, member3] @@ -212,42 +234,58 @@ def test_tag_queries_boolean_logic(members): actual = Member.find( (Member.first_name == "Andrew") & - (Member.last_name == "Brookins") | (Member.last_name == "Smith")) + (Member.last_name == "Brookins") | (Member.last_name == "Smith")).all() assert sorted(actual) == [member1, member3] +def test_tag_queries_punctuation(): + member = Member( + first_name="Andrew the Michael", + last_name="St. Brookins-on-Pier", + email="a@example.com", + age=38, + join_date=today + ) + member.save() + + assert Member.find(Member.first_name == "Andrew the Michael").first() == member + assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member + assert Member.find(Member.email == "a@example.com").first() == member + + def test_tag_queries_negation(members): member1, member2, member3 = members actual = Member.find( ~(Member.first_name == "Andrew") & - (Member.last_name == "Brookins") | (Member.last_name == "Smith")) + (Member.last_name == "Brookins") | (Member.last_name == "Smith")).all() assert sorted(actual) == [member2, member3] actual = Member.find( - (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")) + (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")).all() assert sorted(actual) == [member3] def test_numeric_queries(members): member1, member2, member3 = members - actual = Member.find_one(Member.age == 34) - assert actual == member2 + actual = Member.find(Member.age == 34).all() + assert actual == [member2] - actual = Member.find(Member.age > 34) + actual = Member.find(Member.age > 34).all() assert sorted(actual) == [member1, member3] - actual = Member.find(Member.age < 35) + actual = Member.find(Member.age < 35).all() assert actual == [member2] - actual = Member.find(Member.age <= 34) + actual = Member.find(Member.age <= 34).all() assert actual == [member2] - actual = Member.find(Member.age >= 100) + actual = Member.find(Member.age >= 100).all() assert actual == [member3] - actual = Member.find(~(Member.age == 100)) + import ipdb; ipdb.set_trace() + actual = Member.find(~(Member.age == 100)).all() assert sorted(actual) == [member1, member2]