WIP on pagination, sorting all()/first() methods

Andrew Brookins 2021-09-25 20:38:02 -07:00
parent 3ba45b7c87
commit 01cab5352b
4 changed files with 473 additions and 294 deletions


@@ -2,8 +2,8 @@ import abc
import dataclasses
import decimal
import operator
-from copy import copy, deepcopy
+import re
-from dataclasses import dataclass
+from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import (
@@ -20,13 +20,15 @@ from typing import (
Sequence,
no_type_check,
Protocol,
-List, Type
+List,
+Type,
+Pattern
)
import uuid
import redis
from pydantic import BaseModel, validator
-from pydantic.fields import FieldInfo as PydanticFieldInfo
+from pydantic.fields import FieldInfo as PydanticFieldInfo, PrivateAttr, Field
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass
from pydantic.typing import NoArgAnyCallable
@@ -34,18 +36,47 @@ from pydantic.utils import Representation
from .encoders import jsonable_encoder
model_registry = {}
_T = TypeVar("_T")
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, string):
def escape_symbol(match):
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, string)
escaper = TokenEscaper()
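To illustrate what the escaper above is for (an editorial sketch, not part of the diff): RediSearch treats punctuation as token separators, so values queried against TAG fields need each special character prefixed with a backslash.

    escaper = TokenEscaper()
    # Every character matched by DEFAULT_ESCAPED_CHARS gets a leading backslash.
    print(escaper.escape("a@example.com"))         # a\@example\.com
    print(escaper.escape("St. Brookins-on-Pier"))  # St\.\ Brookins\-on\-Pier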
class RedisModelError(Exception):
pass
class NotFoundError(Exception):
-pass
+"""
+A query found no results.
+TODO: embed in Model class?
+"""
class Operators(Enum):
@@ -61,9 +92,10 @@ class Operators(Enum):
IN = 10
NOT_IN = 11
LIKE = 12
+ALL = 13
-@dataclass
+@dataclasses.dataclass
class NegatedExpression:
expression: 'Expression'
@@ -77,7 +109,7 @@ class NegatedExpression:
return Expression(left=self, op=Operators.OR, right=other)
-@dataclass
+@dataclasses.dataclass
class Expression:
op: Operators
left: Any
@@ -96,186 +128,6 @@ class Expression:
ExpressionOrNegated = Union[Expression, NegatedExpression]
class QueryNotSupportedError(Exception):
"""The attempted query is not supported."""
class RediSearchFieldTypes(Enum):
TEXT = 'TEXT'
TAG = 'TAG'
NUMERIC = 'NUMERIC'
GEO = 'GEO'
# TODO: How to handle Geo fields?
NUMERIC_TYPES = (float, int, decimal.Decimal)
@dataclass
class FindQuery:
expressions: Sequence[Expression]
expression: Expression = dataclasses.field(init=False)
query: str = dataclasses.field(init=False)
pagination: List[str] = dataclasses.field(init=False)
model: Type['RedisModel']
limit: Optional[int] = None
offset: Optional[int] = None
def __post_init__(self):
self.expression = reduce(operator.and_, self.expressions)
self.query = self.resolve_redisearch_query(self.expression)
self.pagination = self.resolve_redisearch_pagination()
def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
if getattr(field.field_info, 'primary_key', None) is True:
return RediSearchFieldTypes.TAG
elif getattr(field.field_info, 'full_text_search', None) is True:
return RediSearchFieldTypes.TEXT
field_type = field.outer_type_
# TODO: GEO
if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
return RediSearchFieldTypes.NUMERIC
else:
# TAG fields are the default field type.
return RediSearchFieldTypes.TAG
@staticmethod
def expand_tag_value(value):
err = RedisModelError(f"Using the IN operator requires passing an iterable of "
"possible values. You passed: {value}")
if isinstance(str, value):
raise err
try:
expanded_value = "|".join(value)
except TypeError:
raise err
return expanded_value
def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
op: Operators, value: Any) -> str:
result = ""
if field_type is RediSearchFieldTypes.TEXT:
result = f"@{field_name}:"
if op is Operators.EQ:
result += f'"{value}"'
elif op is Operators.NE:
result = f'-({result}"{value}")'
elif op is Operators.LIKE:
result += value
else:
# TODO: Handling TAG, TEXT switch-offs, etc.
raise QueryNotSupportedError("Only equals (=) and not-equals (!=) comparisons are "
"currently supported for TEXT fields. See docs: TODO")
elif field_type is RediSearchFieldTypes.NUMERIC:
if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]"
elif op is Operators.NE:
# TODO: Is this enough or do we also need a clause for all values
# ([-inf +inf]) from which we then subtract the undesirable value?
result += f"-(@{field_name}:[{value} {value}])"
elif op is Operators.GT:
result += f"@{field_name}:[({value} +inf]"
elif op is Operators.LT:
result += f"@{field_name}:[-inf ({value}]"
elif op is Operators.GE:
result += f"@{field_name}:[{value} +inf]"
elif op is Operators.LE:
result += f"@{field_name}:[-inf {value}]"
elif field_type is RediSearchFieldTypes.TAG:
if op is Operators.EQ:
result += f"@{field_name}:{{{value}}}"
elif op is Operators.NE:
result += f"-(@{field_name}:{{{value}}})"
elif op is Operators.IN:
expanded_value = self.expand_tag_value(value)
result += f"(@{field_name}:{{{expanded_value}}})"
elif op is Operators.NOT_IN:
expanded_value = self.expand_tag_value(value)
result += f"-(@{field_name}:{{{expanded_value}}})"
return result
def resolve_redisearch_pagination(self):
"""Resolve pagination options for a query."""
if not self.limit and not self.offset:
return []
offset = self.offset or 0
limit = self.limit or 10
return ["LIMIT", offset, limit]
def resolve_redisearch_query(self, expression: ExpressionOrNegated):
"""Resolve an expression to a string RediSearch query."""
field_type = None
field_name = None
encompassing_expression_is_negated = False
result = ""
if isinstance(expression, NegatedExpression):
encompassing_expression_is_negated = True
expression = expression.expression
if isinstance(expression.left, Expression) or \
isinstance(expression.left, NegatedExpression):
result += f"({self.resolve_redisearch_query(expression.left)})"
elif isinstance(expression.left, ModelField):
field_type = self.resolve_field_type(expression.left)
field_name = expression.left.name
else:
import ipdb; ipdb.set_trace()
raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: "
f"TODO")
right = expression.right
right_is_negated = isinstance(right, NegatedExpression)
if isinstance(right, Expression) or right_is_negated:
if expression.op == Operators.AND:
result += " "
elif expression.op == Operators.OR:
result += "| "
else:
raise QueryNotSupportedError("You can only combine two query expressions with"
"AND (&) or OR (|). See docs: TODO")
if right_is_negated:
result += "-"
# We're handling the RediSearch operator in this call ("-"), so resolve the
# inner expression instead of the NegatedExpression.
right = right.expression
result += f"({self.resolve_redisearch_query(right)})"
else:
if isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else:
result += self.resolve_value(field_name, field_type, expression.op, right)
if encompassing_expression_is_negated:
result = f"-({result})"
return result
def find(self):
args = ["ft.search", self.model.Meta.index_name, self.query]
# TODO: Do we need self.pagination if we're just appending to query anyway?
if self.pagination:
args.extend(self.pagination)
return self.model.db().execute_command(*args)
class PrimaryKeyCreator(Protocol):
def create_pk(self, *args, **kwargs) -> str:
"""Create a new primary key"""
class Uuid4PrimaryKey:
def create_pk(self) -> str:
return str(uuid.uuid4())
class ExpressionProxy:
def __init__(self, field: ModelField):
self.field = field
@@ -299,6 +151,352 @@ class ExpressionProxy:
return Expression(left=self.field, op=Operators.GE, right=other)
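For context, an illustrative sketch (not from the diff): each comparison on a model field goes through ExpressionProxy and yields an Expression node, and the &, |, and ~ operators combine those nodes into the tree that FindQuery later resolves. Member here is the model defined in the test suite.

    # Builds roughly: Expression(op=AND,
    #                            left=<last_name == "Brookins">,
    #                            right=NegatedExpression(<age == 100>))
    expr = (Member.last_name == "Brookins") & ~(Member.age == 100)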
class QueryNotSupportedError(Exception):
"""The attempted query is not supported."""
class RediSearchFieldTypes(Enum):
TEXT = 'TEXT'
TAG = 'TAG'
NUMERIC = 'NUMERIC'
GEO = 'GEO'
# TODO: How to handle Geo fields?
NUMERIC_TYPES = (float, int, decimal.Decimal)
DEFAULT_PAGE_SIZE = 10
class FindQuery(BaseModel):
expressions: Sequence[ExpressionOrNegated]
model: Type['RedisModel']
offset: int = 0
limit: int = DEFAULT_PAGE_SIZE
page_size: int = DEFAULT_PAGE_SIZE
sort_fields: Optional[List[str]] = Field(default_factory=list)
_expression: Expression = PrivateAttr(default=None)
_query: str = PrivateAttr(default=None)
_pagination: List[str] = PrivateAttr(default_factory=list)
_model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list)
class Config:
arbitrary_types_allowed = True
@property
def pagination(self):
if self._pagination:
return self._pagination
self._pagination = self.resolve_redisearch_pagination()
return self._pagination
@property
def expression(self):
if self._expression:
return self._expression
if self.expressions:
self._expression = reduce(operator.and_, self.expressions)
else:
self._expression = Expression(left=None, right=None, op=Operators.ALL)
return self._expression
@property
def query(self):
return self.resolve_redisearch_query(self.expression)
@validator("sort_fields")
def validate_sort_fields(cls, v, values):
model = values['model']
for sort_field in v:
field_name = sort_field.lstrip("-")
if field_name not in model.__fields__:
raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
f"does not exist on the model {model}")
field_proxy = getattr(model, field_name)
if not getattr(field_proxy.field.field_info, 'sortable', False):
raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does "
"not define that field as sortable. See docs: XXX")
return v
@staticmethod
def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
if getattr(field.field_info, 'primary_key', None) is True:
return RediSearchFieldTypes.TAG
elif getattr(field.field_info, 'full_text_search', None) is True:
return RediSearchFieldTypes.TEXT
field_type = field.outer_type_
# TODO: GEO
if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
return RediSearchFieldTypes.NUMERIC
else:
# TAG fields are the default field type.
# TODO: A ListField or ArrayField that supports multiple values
# and contains logic.
return RediSearchFieldTypes.TAG
@staticmethod
def expand_tag_value(value):
err = RedisModelError(f"Using the IN operator requires passing a sequence of "
"possible values. You passed: {value}")
if isinstance(str, value):
raise err
try:
expanded_value = "|".join([escaper.escape(v) for v in value])
except TypeError:
raise err
return expanded_value
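A quick illustration of the expansion above (editorial, with assumed values): the sequence is joined into a RediSearch TAG alternation, with each element run through the escaper.

    FindQuery.expand_tag_value(["Brookins", "Smith"])  # "Brookins|Smith"
    FindQuery.expand_tag_value(["a@example.com"])      # "a\@example\.com"
    # Passing a plain string raises RedisModelError: IN expects a sequence of values.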
@classmethod
def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
op: Operators, value: Any) -> str:
result = ""
if field_type is RediSearchFieldTypes.TEXT:
result = f"@{field_name}:"
if op is Operators.EQ:
result += f'"{value}"'
elif op is Operators.NE:
result = f'-({result}"{value}")'
elif op is Operators.LIKE:
result += value
else:
raise QueryNotSupportedError("Only equals (=), not-equals (!=), and like() "
"comparisons are supported for TEXT fields. See "
"docs: TODO.")
elif field_type is RediSearchFieldTypes.NUMERIC:
if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]"
elif op is Operators.NE:
# TODO: Is this enough or do we also need a clause for all values
# ([-inf +inf]) from which we then subtract the undesirable value?
result += f"-(@{field_name}:[{value} {value}])"
elif op is Operators.GT:
result += f"@{field_name}:[({value} +inf]"
elif op is Operators.LT:
result += f"@{field_name}:[-inf ({value}]"
elif op is Operators.GE:
result += f"@{field_name}:[{value} +inf]"
elif op is Operators.LE:
result += f"@{field_name}:[-inf {value}]"
elif field_type is RediSearchFieldTypes.TAG:
if op is Operators.EQ:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
elif op is Operators.NE:
value = escaper.escape(value)
result += f"-(@{field_name}:{{{value}}})"
elif op is Operators.IN:
expanded_value = cls.expand_tag_value(value)
result += f"(@{field_name}:{{{expanded_value}}})"
elif op is Operators.NOT_IN:
expanded_value = cls.expand_tag_value(value)
result += f"-(@{field_name}:{{{expanded_value}}})"
return result
def resolve_redisearch_pagination(self):
"""Resolve pagination options for a query."""
return ["LIMIT", self.offset, self.limit]
def resolve_redisearch_sort_fields(self):
"""Resolve sort options for a query."""
if not self.sort_fields:
return
fields = []
for f in self.sort_fields:
direction = "desc" if f.startswith('-') else 'asc'
fields.extend([f.lstrip('-'), direction])
if self.sort_fields:
return ["SORTBY", *fields]
@classmethod
def resolve_redisearch_query(cls, expression: ExpressionOrNegated):
"""Resolve an expression to a string RediSearch query."""
field_type = None
field_name = None
encompassing_expression_is_negated = False
result = ""
if isinstance(expression, NegatedExpression):
encompassing_expression_is_negated = True
expression = expression.expression
if expression.op is Operators.ALL:
if encompassing_expression_is_negated:
# TODO: Is there a use case for this, perhaps for dynamic
# scoring purposes?
raise QueryNotSupportedError("You cannot negate a query for all results.")
return "*"
if isinstance(expression.left, Expression) or \
isinstance(expression.left, NegatedExpression):
result += f"({cls.resolve_redisearch_query(expression.left)})"
elif isinstance(expression.left, ModelField):
field_type = cls.resolve_field_type(expression.left)
field_name = expression.left.name
else:
raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: "
f"TODO")
right = expression.right
if isinstance(right, Expression) or isinstance(right, NegatedExpression):
if expression.op == Operators.AND:
result += " "
elif expression.op == Operators.OR:
result += "| "
else:
raise QueryNotSupportedError("You can only combine two query expressions with"
"AND (&) or OR (|). See docs: TODO")
if isinstance(right, NegatedExpression):
result += "-"
# We're handling the RediSearch operator in this call ("-"), so resolve the
# inner expression instead of the NegatedExpression.
right = right.expression
result += f"({cls.resolve_redisearch_query(right)})"
else:
if isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else:
# TODO: Optionals causing IDE errors here
result += cls.resolve_value(field_name, field_type, expression.op, right)
if encompassing_expression_is_negated:
result = f"-({result})"
return result
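As a sketch of the resolver's output (editorial; Member, last_name, and age come from the test suite, and the string shown is what the code above appears to produce, not output verified against RediSearch):

    query = FindQuery(expressions=[(Member.last_name == "Brookins") & (Member.age > 34)],
                      model=Member)
    print(query.query)
    # (@last_name:{Brookins}) (@age:[(34 +inf])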
def execute(self, exhaust_results=True):
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
if self.sort_fields:
args += self.resolve_redisearch_sort_fields()
# Reset the cache if we're executing from offset 0.
if self.offset == 0:
self._model_cache.clear()
# If the offset is greater than 0, we're paginating through a result set,
# so append the new results to results already in the cache.
raw_result = self.model.db().execute_command(*args)
count = raw_result[0]
results = self.model.from_redis(raw_result)
self._model_cache += results
if not exhaust_results:
return self._model_cache
# The query returned all results, so we have no more work to do.
if count <= len(results):
return self._model_cache
# Transparently (to the user) make subsequent requests to paginate
# through the results and finally return them all.
query = self
while True:
# Make a query for each pass of the loop, with a new offset equal to the
# current offset plus `page_size`, until we stop getting results back.
query = FindQuery(expressions=query.expressions,
model=query.model,
offset=query.offset + query.page_size,
page_size=query.page_size,
limit=query.limit,
sort_fields=query.sort_fields)
_results = query.execute(exhaust_results=False)
if not _results:
break
self._model_cache += _results
return self._model_cache
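Roughly, the paging behavior execute() is aiming for (an illustrative trace, assuming an index that matches 25 documents and the default page size of 10; the index name is whatever Member.Meta.index_name resolves to):

    results = Member.find().all()
    # Issues something like:
    #   FT.SEARCH <index_name> * LIMIT 0 10    -> 10 models (count says 25, keep going)
    #   FT.SEARCH <index_name> * LIMIT 10 10   -> 10 more models
    #   FT.SEARCH <index_name> * LIMIT 20 10   -> 5 models
    #   FT.SEARCH <index_name> * LIMIT 30 10   -> empty page, stop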
def first(self):
query = FindQuery(expressions=self.expressions, model=self.model,
offset=0, limit=1, sort_fields=self.sort_fields)
return query.execute()[0]
def all(self, batch_size=10):
if batch_size != self.page_size:
# TODO: There's probably a copy-with-change mechanism in Pydantic,
# or can we use one from dataclasses?
query = FindQuery(expressions=self.expressions,
model=self.model,
offset=self.offset,
page_size=batch_size,
limit=batch_size,
sort_fields=self.sort_fields)
return query.execute()
return self.execute()
def sort_by(self, *fields):
if not fields:
return self
return FindQuery(expressions=self.expressions,
model=self.model,
offset=self.offset,
page_size=self.page_size,
limit=self.limit,
sort_fields=list(fields))
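Putting the pieces together, the chainable API this commit is building toward looks roughly like this (editorial example; Member is the test model, and any field passed to sort_by must be marked sortable):

    brookinses = Member.find(Member.last_name == "Brookins").sort_by("-age").all()
    youngest = Member.find().sort_by("age").first()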
def update(self, **kwargs):
"""Update all matching records in this query."""
# TODO
def delete(cls, **field_values):
"""Delete all matching records in this query."""
for field_name, value in field_values.items():
valid_attr = hasattr(cls.model, field_name)
if not valid_attr:
raise RedisModelError(f"Can't delete field {field_name} because "
f"the field does not exist on the model {cls}")
return cls
def __iter__(self):
if self._model_cache:
for m in self._model_cache:
yield m
else:
for m in self.execute():
yield m
def __getitem__(self, item: int):
"""
Given this code:
Model.find()[1000]
We should return only the 1000th result.
1. If the result is loaded in the query cache for this query,
we can return it directly from the cache.
2. If the query cache does not have enough elements to return
that result, then we should clone the current query and
give it a new offset and limit: offset=n, limit=1.
"""
if self._model_cache and len(self._model_cache) > item:
return self._model_cache[item]
query = FindQuery(expressions=self.expressions,
model=self.model,
offset=item,
sort_fields=self.sort_fields,
limit=1)
return query.execute()[0]
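Illustrative usage of the indexing path described in the docstring above (assuming at least two matching documents):

    query = Member.find().sort_by("age")
    query.execute()                   # warms _model_cache
    second = query[1]                 # served from the cache
    hundredth = Member.find()[100]    # cold cache: issues a fresh query starting at LIMIT 100 1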
class PrimaryKeyCreator(Protocol):
def create_pk(self, *args, **kwargs) -> str:
"""Create a new primary key"""
class Uuid4PrimaryKey:
def create_pk(self, *args, **kwargs) -> str:
return str(uuid.uuid4())
def __dataclass_transform__(
*,
eq_default: bool = True,
@@ -395,21 +593,22 @@ def Field(
return field_info
-@dataclass
+@dataclasses.dataclass
class PrimaryKey:
name: str
field: ModelField
class DefaultMeta:
+# TODO: Should this really be optional here?
global_key_prefix: Optional[str] = None
model_key_prefix: Optional[str] = None
primary_key_pattern: Optional[str] = None
database: Optional[redis.Redis] = None
primary_key: Optional[PrimaryKey] = None
-primary_key_creator_cls: Type[PrimaryKeyCreator] = None
+primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
-index_name: str = None
+index_name: Optional[str] = None
-abstract: bool = False
+abstract: Optional[bool] = False
class ModelMeta(ModelMetaclass):
@@ -473,6 +672,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
pk: Optional[str] = Field(default=None, primary_key=True)
Meta = DefaultMeta
+# TODO: Missing _meta here is causing IDE warnings.
class Config:
orm_mode = True
@@ -484,6 +684,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
__pydantic_self__.validate_primary_key()
def __lt__(self, other):
+"""Default sort: compare all shared model fields."""
my_keys = set(self.__fields__.keys())
other_keys = set(other.__fields__.keys())
shared_keys = list(my_keys & other_keys)
@@ -528,8 +729,13 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
def db(cls):
return cls._meta.database
+@classmethod
+def find(cls, *expressions: Expression):
+return FindQuery(expressions=expressions, model=cls)
@classmethod
def from_redis(cls, res: Any):
+# TODO: Parsing logic borrowed from redisearch-py. Evaluate.
import six
from six.moves import xrange, zip as izip
@@ -537,20 +743,20 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
if isinstance(s, six.string_types):
return s
elif isinstance(s, six.binary_type):
-return s.decode('utf-8','ignore')
+return s.decode('utf-8', 'ignore')
else:
return s  # Not a string we care about
docs = []
step = 2  # Because the result has content
-offset = 1
+offset = 1  # The first item is the count of total matches.
for i in xrange(1, len(res), step):
fields_offset = offset
fields = dict(
dict(izip(map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2])))
)
try:
@@ -562,17 +768,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
docs.append(doc)
return docs
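For reference, an assumed example (not output captured from Redis): the parser above expects the standard FT.SEARCH reply layout, a total count followed by alternating document IDs and flat field/value arrays.

    raw = [
        2,
        b"member:1", [b"first_name", b"Andrew", b"last_name", b"Brookins"],
        b"member:2", [b"first_name", b"Kim", b"last_name", b"Brookins"],
    ]
    # Member.from_redis(raw) decodes the byte strings and builds one Member per document,
    # assuming the stored hash contains every field the model requires.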
@classmethod
def find(cls, *expressions: Expression):
query = FindQuery(expressions=expressions, model=cls)
raw_result = query.find()
return cls.from_redis(raw_result)
@classmethod
def find_one(cls, *expressions: Expression):
query = FindQuery(expressions=expressions, model=cls, limit=1, offset=0)
raw_result = query.find()
return cls.from_redis(raw_result)[0]
@classmethod
def add(cls, models: Sequence['RedisModel']) -> Sequence['RedisModel']:
@@ -580,6 +775,7 @@ @classmethod
@classmethod
def update(cls, **field_values):
+"""Update this model instance."""
return cls
@classmethod


@@ -1,55 +0,0 @@
from redis_developer.orm.model import Expression
class QueryIterator:
"""
A lazy iterator that yields results from a RediSearch query.
Examples:
results = Model.filter(email == "a@example.com")
# Consume all results.
for r in results:
print(r)
# Consume an item at an index.
print(results[100])
# Consume a slice.
print(results[0:100])
# Alternative notation to consume all items.
print(results[0:-1])
# Specify the batch size:
results = Model.filter(email == "a@example.com", batch_size=1000)
...
"""
def __init__(self, client, query, batch_size=100):
self.client = client
self.query = query
self.batch_size = batch_size
def __iter__(self):
pass
def __getattr__(self, item):
"""Support getting a single value or a slice."""
# TODO: Query mixin?
def filter(self, *expressions: Expression):
pass
def exclude(self, *expressions: Expression):
pass
def and_(self, *expressions: Expression):
pass
def or_(self, *expressions: Expression):
pass
def not_(self, *expressions: Expression):
pass


@ -19,7 +19,7 @@ def redis():
def key_prefix(): def key_prefix():
# TODO # TODO
yield "redis-developer" yield "redis-developer"
def _delete_test_keys(prefix: str, conn: Redis): def _delete_test_keys(prefix: str, conn: Redis):
for key in conn.scan_iter(f"{prefix}:*"): for key in conn.scan_iter(f"{prefix}:*"):


@@ -151,51 +151,73 @@ def test_saves_many():
Member.add(members)
-@pytest.mark.skip("No implemented yet")
+@pytest.mark.skip("Not ready yet")
-def test_updates_a_model():
+def test_updates_a_model(members):
-member = Member(
+member1, member2, member3 = members
-first_name="Andrew",
-last_name="Brookins",
-email="a@example.com",
-join_date=today
-)
-# Update a model instance in Redis
-member.first_name = "Andrew"
-member.last_name = "Brookins"
-member.save()
# Or, with an implicit save:
-member.update(last_name="Smith")
+member1.update(last_name="Smith")
+assert Member.find(Member.pk == member1.pk).first() == member1
# Or, affecting multiple model instances with an implicit save:
Member.find(Member.last_name == "Brookins").update(last_name="Smith")
+results = Member.find(Member.last_name == "Smith")
+assert sorted(results) == members
+def test_paginate_query(members):
+member1, member2, member3 = members
+actual = Member.find().all(batch_size=1)
+assert sorted(actual) == [member1, member2, member3]
+def test_access_result_by_index_cached(members):
+_, member2, _ = members
+query = Member.find().sort_by('age')
+# Load the cache, throw away the result.
+query.execute()
+# Access an item that should be in the cache.
+# TODO: Assert that we didn't make a Redis request.
+assert query[0] == member2
+def test_access_result_by_index_not_cached(members):
+member1, member2, member3 = members
+query = Member.find().sort_by('age')
+# Assert that we don't have any models in the cache yet -- we
+# haven't made any requests of Redis.
+assert query._model_cache == []
+assert query[0] == member2
+assert query[1] == member1
+assert query[2] == member3
def test_exact_match_queries(members):
member1, member2, member3 = members
-actual = Member.find(Member.last_name == "Brookins")
+actual = Member.find(Member.last_name == "Brookins").all()
assert sorted(actual) == [member1, member2]
actual = Member.find(
-(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
+(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all()
assert actual == [member2]
-actual = Member.find(~(Member.last_name == "Brookins"))
+actual = Member.find(~(Member.last_name == "Brookins")).all()
assert actual == [member3]
-actual = Member.find(Member.last_name != "Brookins")
+actual = Member.find(Member.last_name != "Brookins").all()
assert actual == [member3]
actual = Member.find(
(Member.last_name == "Brookins") & (Member.first_name == "Andrew")
| (Member.first_name == "Kim")
-)
+).all()
assert actual == [member2, member1]
-actual = Member.find_one(Member.last_name == "Brookins")
+actual = Member.find(Member.first_name == "Kim", Member.last_name == "Brookins").all()
-assert actual == member2
+assert actual == [member2]
def test_recursive_query_resolution(members):
@@ -203,7 +225,7 @@ def test_recursive_query_resolution(members):
actual = Member.find((Member.last_name == "Brookins") | (
Member.age == 100
-) & (Member.last_name == "Smith"))
+) & (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member2, member3]
@@ -212,42 +234,58 @@ def test_tag_queries_boolean_logic(members):
actual = Member.find(
(Member.first_name == "Andrew") &
-(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
+(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member3]
+def test_tag_queries_punctuation():
+member = Member(
+first_name="Andrew the Michael",
+last_name="St. Brookins-on-Pier",
+email="a@example.com",
+age=38,
+join_date=today
+)
+member.save()
+assert Member.find(Member.first_name == "Andrew the Michael").first() == member
+assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member
+assert Member.find(Member.email == "a@example.com").first() == member
def test_tag_queries_negation(members):
member1, member2, member3 = members
actual = Member.find(
~(Member.first_name == "Andrew") &
-(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
+(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all()
assert sorted(actual) == [member2, member3]
actual = Member.find(
-(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins"))
+(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")).all()
assert sorted(actual) == [member3]
def test_numeric_queries(members):
member1, member2, member3 = members
-actual = Member.find_one(Member.age == 34)
+actual = Member.find(Member.age == 34).all()
-assert actual == member2
+assert actual == [member2]
-actual = Member.find(Member.age > 34)
+actual = Member.find(Member.age > 34).all()
assert sorted(actual) == [member1, member3]
-actual = Member.find(Member.age < 35)
+actual = Member.find(Member.age < 35).all()
assert actual == [member2]
-actual = Member.find(Member.age <= 34)
+actual = Member.find(Member.age <= 34).all()
assert actual == [member2]
-actual = Member.find(Member.age >= 100)
+actual = Member.find(Member.age >= 100).all()
assert actual == [member3]
-actual = Member.find(~(Member.age == 100))
+actual = Member.find(~(Member.age == 100)).all()
assert sorted(actual) == [member1, member2]