WIP on pagination, sorting all()/first() methods

This commit is contained in:
Andrew Brookins 2021-09-25 20:38:02 -07:00
parent 3ba45b7c87
commit 01cab5352b
4 changed files with 473 additions and 294 deletions

View file

@ -2,8 +2,8 @@ import abc
import dataclasses
import decimal
import operator
from copy import copy, deepcopy
from dataclasses import dataclass
import re
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import (
@ -20,13 +20,15 @@ from typing import (
Sequence,
no_type_check,
Protocol,
List, Type
List,
Type,
Pattern
)
import uuid
import redis
from pydantic import BaseModel, validator
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import FieldInfo as PydanticFieldInfo, PrivateAttr, Field
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import ModelMetaclass
from pydantic.typing import NoArgAnyCallable
@ -34,18 +36,47 @@ from pydantic.utils import Representation
from .encoders import jsonable_encoder
model_registry = {}
_T = TypeVar("_T")
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, string):
def escape_symbol(match):
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, string)
escaper = TokenEscaper()
class RedisModelError(Exception):
pass
class NotFoundError(Exception):
pass
"""
A query found no results.
TODO: embed in Model class?
"""
class Operators(Enum):
@ -61,9 +92,10 @@ class Operators(Enum):
IN = 10
NOT_IN = 11
LIKE = 12
ALL = 13
@dataclass
@dataclasses.dataclass
class NegatedExpression:
expression: 'Expression'
@ -77,7 +109,7 @@ class NegatedExpression:
return Expression(left=self, op=Operators.OR, right=other)
@dataclass
@dataclasses.dataclass
class Expression:
op: Operators
left: Any
@ -96,186 +128,6 @@ class Expression:
ExpressionOrNegated = Union[Expression, NegatedExpression]
class QueryNotSupportedError(Exception):
"""The attempted query is not supported."""
class RediSearchFieldTypes(Enum):
TEXT = 'TEXT'
TAG = 'TAG'
NUMERIC = 'NUMERIC'
GEO = 'GEO'
# TODO: How to handle Geo fields?
NUMERIC_TYPES = (float, int, decimal.Decimal)
@dataclass
class FindQuery:
expressions: Sequence[Expression]
expression: Expression = dataclasses.field(init=False)
query: str = dataclasses.field(init=False)
pagination: List[str] = dataclasses.field(init=False)
model: Type['RedisModel']
limit: Optional[int] = None
offset: Optional[int] = None
def __post_init__(self):
self.expression = reduce(operator.and_, self.expressions)
self.query = self.resolve_redisearch_query(self.expression)
self.pagination = self.resolve_redisearch_pagination()
def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
if getattr(field.field_info, 'primary_key', None) is True:
return RediSearchFieldTypes.TAG
elif getattr(field.field_info, 'full_text_search', None) is True:
return RediSearchFieldTypes.TEXT
field_type = field.outer_type_
# TODO: GEO
if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
return RediSearchFieldTypes.NUMERIC
else:
# TAG fields are the default field type.
return RediSearchFieldTypes.TAG
@staticmethod
def expand_tag_value(value):
err = RedisModelError(f"Using the IN operator requires passing an iterable of "
"possible values. You passed: {value}")
if isinstance(str, value):
raise err
try:
expanded_value = "|".join(value)
except TypeError:
raise err
return expanded_value
def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
op: Operators, value: Any) -> str:
result = ""
if field_type is RediSearchFieldTypes.TEXT:
result = f"@{field_name}:"
if op is Operators.EQ:
result += f'"{value}"'
elif op is Operators.NE:
result = f'-({result}"{value}")'
elif op is Operators.LIKE:
result += value
else:
# TODO: Handling TAG, TEXT switch-offs, etc.
raise QueryNotSupportedError("Only equals (=) and not-equals (!=) comparisons are "
"currently supported for TEXT fields. See docs: TODO")
elif field_type is RediSearchFieldTypes.NUMERIC:
if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]"
elif op is Operators.NE:
# TODO: Is this enough or do we also need a clause for all values
# ([-inf +inf]) from which we then subtract the undesirable value?
result += f"-(@{field_name}:[{value} {value}])"
elif op is Operators.GT:
result += f"@{field_name}:[({value} +inf]"
elif op is Operators.LT:
result += f"@{field_name}:[-inf ({value}]"
elif op is Operators.GE:
result += f"@{field_name}:[{value} +inf]"
elif op is Operators.LE:
result += f"@{field_name}:[-inf {value}]"
elif field_type is RediSearchFieldTypes.TAG:
if op is Operators.EQ:
result += f"@{field_name}:{{{value}}}"
elif op is Operators.NE:
result += f"-(@{field_name}:{{{value}}})"
elif op is Operators.IN:
expanded_value = self.expand_tag_value(value)
result += f"(@{field_name}:{{{expanded_value}}})"
elif op is Operators.NOT_IN:
expanded_value = self.expand_tag_value(value)
result += f"-(@{field_name}:{{{expanded_value}}})"
return result
def resolve_redisearch_pagination(self):
"""Resolve pagination options for a query."""
if not self.limit and not self.offset:
return []
offset = self.offset or 0
limit = self.limit or 10
return ["LIMIT", offset, limit]
def resolve_redisearch_query(self, expression: ExpressionOrNegated):
"""Resolve an expression to a string RediSearch query."""
field_type = None
field_name = None
encompassing_expression_is_negated = False
result = ""
if isinstance(expression, NegatedExpression):
encompassing_expression_is_negated = True
expression = expression.expression
if isinstance(expression.left, Expression) or \
isinstance(expression.left, NegatedExpression):
result += f"({self.resolve_redisearch_query(expression.left)})"
elif isinstance(expression.left, ModelField):
field_type = self.resolve_field_type(expression.left)
field_name = expression.left.name
else:
import ipdb; ipdb.set_trace()
raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: "
f"TODO")
right = expression.right
right_is_negated = isinstance(right, NegatedExpression)
if isinstance(right, Expression) or right_is_negated:
if expression.op == Operators.AND:
result += " "
elif expression.op == Operators.OR:
result += "| "
else:
raise QueryNotSupportedError("You can only combine two query expressions with"
"AND (&) or OR (|). See docs: TODO")
if right_is_negated:
result += "-"
# We're handling the RediSearch operator in this call ("-"), so resolve the
# inner expression instead of the NegatedExpression.
right = right.expression
result += f"({self.resolve_redisearch_query(right)})"
else:
if isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else:
result += self.resolve_value(field_name, field_type, expression.op, right)
if encompassing_expression_is_negated:
result = f"-({result})"
return result
def find(self):
args = ["ft.search", self.model.Meta.index_name, self.query]
# TODO: Do we need self.pagination if we're just appending to query anyway?
if self.pagination:
args.extend(self.pagination)
return self.model.db().execute_command(*args)
class PrimaryKeyCreator(Protocol):
def create_pk(self, *args, **kwargs) -> str:
"""Create a new primary key"""
class Uuid4PrimaryKey:
def create_pk(self) -> str:
return str(uuid.uuid4())
class ExpressionProxy:
def __init__(self, field: ModelField):
self.field = field
@ -299,6 +151,352 @@ class ExpressionProxy:
return Expression(left=self.field, op=Operators.GE, right=other)
class QueryNotSupportedError(Exception):
"""The attempted query is not supported."""
class RediSearchFieldTypes(Enum):
TEXT = 'TEXT'
TAG = 'TAG'
NUMERIC = 'NUMERIC'
GEO = 'GEO'
# TODO: How to handle Geo fields?
NUMERIC_TYPES = (float, int, decimal.Decimal)
DEFAULT_PAGE_SIZE = 10
class FindQuery(BaseModel):
expressions: Sequence[ExpressionOrNegated]
model: Type['RedisModel']
offset: int = 0
limit: int = DEFAULT_PAGE_SIZE
page_size: int = DEFAULT_PAGE_SIZE
sort_fields: Optional[List[str]] = Field(default_factory=list)
_expression: Expression = PrivateAttr(default=None)
_query: str = PrivateAttr(default=None)
_pagination: List[str] = PrivateAttr(default_factory=list)
_model_cache: Optional[List['RedisModel']] = PrivateAttr(default_factory=list)
class Config:
arbitrary_types_allowed = True
@property
def pagination(self):
if self._pagination:
return self._pagination
self._pagination = self.resolve_redisearch_pagination()
return self._pagination
@property
def expression(self):
if self._expression:
return self._expression
if self.expressions:
self._expression = reduce(operator.and_, self.expressions)
else:
self._expression = Expression(left=None, right=None, op=Operators.ALL)
return self._expression
@property
def query(self):
return self.resolve_redisearch_query(self.expression)
@validator("sort_fields")
def validate_sort_fields(cls, v, values):
model = values['model']
for sort_field in v:
field_name = sort_field.lstrip("-")
if field_name not in model.__fields__:
raise QueryNotSupportedError(f"You tried sort by {field_name}, but that field "
f"does not exist on the model {model}")
field_proxy = getattr(model, field_name)
if not getattr(field_proxy.field.field_info, 'sortable', False):
raise QueryNotSupportedError(f"You tried sort by {field_name}, but {cls} does "
"not define that field as sortable. See docs: XXX")
return v
@staticmethod
def resolve_field_type(field: ModelField) -> RediSearchFieldTypes:
if getattr(field.field_info, 'primary_key', None) is True:
return RediSearchFieldTypes.TAG
elif getattr(field.field_info, 'full_text_search', None) is True:
return RediSearchFieldTypes.TEXT
field_type = field.outer_type_
# TODO: GEO
if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
return RediSearchFieldTypes.NUMERIC
else:
# TAG fields are the default field type.
# TODO: A ListField or ArrayField that supports multiple values
# and contains logic.
return RediSearchFieldTypes.TAG
@staticmethod
def expand_tag_value(value):
err = RedisModelError(f"Using the IN operator requires passing a sequence of "
"possible values. You passed: {value}")
if isinstance(str, value):
raise err
try:
expanded_value = "|".join([escaper.escape(v) for v in value])
except TypeError:
raise err
return expanded_value
@classmethod
def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
op: Operators, value: Any) -> str:
result = ""
if field_type is RediSearchFieldTypes.TEXT:
result = f"@{field_name}:"
if op is Operators.EQ:
result += f'"{value}"'
elif op is Operators.NE:
result = f'-({result}"{value}")'
elif op is Operators.LIKE:
result += value
else:
raise QueryNotSupportedError("Only equals (=), not-equals (!=), and like() "
"comparisons are supported for TEXT fields. See "
"docs: TODO.")
elif field_type is RediSearchFieldTypes.NUMERIC:
if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]"
elif op is Operators.NE:
# TODO: Is this enough or do we also need a clause for all values
# ([-inf +inf]) from which we then subtract the undesirable value?
result += f"-(@{field_name}:[{value} {value}])"
elif op is Operators.GT:
result += f"@{field_name}:[({value} +inf]"
elif op is Operators.LT:
result += f"@{field_name}:[-inf ({value}]"
elif op is Operators.GE:
result += f"@{field_name}:[{value} +inf]"
elif op is Operators.LE:
result += f"@{field_name}:[-inf {value}]"
elif field_type is RediSearchFieldTypes.TAG:
if op is Operators.EQ:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
elif op is Operators.NE:
value = escaper.escape(value)
result += f"-(@{field_name}:{{{value}}})"
elif op is Operators.IN:
expanded_value = cls.expand_tag_value(value)
result += f"(@{field_name}:{{{expanded_value}}})"
elif op is Operators.NOT_IN:
expanded_value = cls.expand_tag_value(value)
result += f"-(@{field_name}:{{{expanded_value}}})"
return result
def resolve_redisearch_pagination(self):
"""Resolve pagination options for a query."""
return ["LIMIT", self.offset, self.limit]
def resolve_redisearch_sort_fields(self):
"""Resolve sort options for a query."""
if not self.sort_fields:
return
fields = []
for f in self.sort_fields:
direction = "desc" if f.startswith('-') else 'asc'
fields.extend([f.lstrip('-'), direction])
if self.sort_fields:
return ["SORTBY", *fields]
@classmethod
def resolve_redisearch_query(cls, expression: ExpressionOrNegated):
"""Resolve an expression to a string RediSearch query."""
field_type = None
field_name = None
encompassing_expression_is_negated = False
result = ""
if isinstance(expression, NegatedExpression):
encompassing_expression_is_negated = True
expression = expression.expression
if expression.op is Operators.ALL:
if encompassing_expression_is_negated:
# TODO: Is there a use case for this, perhaps for dynamic
# scoring purposes?
raise QueryNotSupportedError("You cannot negate a query for all results.")
return "*"
if isinstance(expression.left, Expression) or \
isinstance(expression.left, NegatedExpression):
result += f"({cls.resolve_redisearch_query(expression.left)})"
elif isinstance(expression.left, ModelField):
field_type = cls.resolve_field_type(expression.left)
field_name = expression.left.name
else:
import ipdb; ipdb.set_trace()
raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: "
f"TODO")
right = expression.right
if isinstance(right, Expression) or isinstance(right, NegatedExpression):
if expression.op == Operators.AND:
result += " "
elif expression.op == Operators.OR:
result += "| "
else:
raise QueryNotSupportedError("You can only combine two query expressions with"
"AND (&) or OR (|). See docs: TODO")
if isinstance(right, NegatedExpression):
result += "-"
# We're handling the RediSearch operator in this call ("-"), so resolve the
# inner expression instead of the NegatedExpression.
right = right.expression
result += f"({cls.resolve_redisearch_query(right)})"
else:
if isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else:
# TODO: Optionals causing IDE errors here
result += cls.resolve_value(field_name, field_type, expression.op, right)
if encompassing_expression_is_negated:
result = f"-({result})"
return result
def execute(self, exhaust_results=True):
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
if self.sort_fields:
args += self.resolve_redisearch_sort_fields()
# Reset the cache if we're executing from offset 0.
if self.offset == 0:
self._model_cache.clear()
# If the offset is greater than 0, we're paginating through a result set,
# so append the new results to results already in the cache.
raw_result = self.model.db().execute_command(*args)
count = raw_result[0]
results = self.model.from_redis(raw_result)
self._model_cache += results
if not exhaust_results:
return self._model_cache
# The query returned all results, so we have no more work to do.
if count <= len(results):
return self._model_cache
# Transparently (to the user) make subsequent requests to paginate
# through the results and finally return them all.
query = self
while True:
# Make a query for each pass of the loop, with a new offset equal to the
# current offset plus `page_size`, until we stop getting results back.
query = FindQuery(expressions=query.expressions,
model=query.model,
offset=query.offset + query.page_size,
page_size=query.page_size,
limit=query.limit)
_results = query.execute(exhaust_results=False)
if not _results:
break
self._model_cache += _results
return self._model_cache
def first(self):
query = FindQuery(expressions=self.expressions, model=self.model,
offset=0, limit=1, sort_fields=self.sort_fields)
return query.execute()[0]
def all(self, batch_size=10):
if batch_size != self.page_size:
# TODO: There's probably a copy-with-change mechanism in Pydantic,
# or can we use one from dataclasses?
query = FindQuery(expressions=self.expressions,
model=self.model,
offset=self.offset,
page_size=batch_size,
limit=batch_size,
sort_fields=self.sort_fields)
return query.execute()
return self.execute()
def sort_by(self, *fields):
if not fields:
return self
return FindQuery(expressions=self.expressions,
model=self.model,
offset=self.offset,
page_size=self.page_size,
limit=self.limit,
sort_fields=list(fields))
def update(self, **kwargs):
"""Update all matching records in this query."""
# TODO
def delete(cls, **field_values):
"""Delete all matching records in this query."""
for field_name, value in field_values:
valid_attr = hasattr(cls.model, field_name)
if not valid_attr:
raise RedisModelError(f"Can't update field {field_name} because "
f"the field does not exist on the model {cls}")
return cls
def __iter__(self):
if self._model_cache:
for m in self._model_cache:
yield m
else:
for m in self.execute():
yield m
def __getitem__(self, item: int):
"""
Given this code:
Model.find()[1000]
We should return only the 1000th result.
1. If the result is loaded in the query cache for this query,
we can return it directly from the cache.
2. If the query cache does not have enough elements to return
that result, then we should clone the current query and
give it a new offset and limit: offset=n, limit=1.
"""
if self._model_cache and len(self._model_cache) >= item:
return self._model_cache[item]
query = FindQuery(expressions=self.expressions,
model=self.model,
offset=item,
sort_fields=self.sort_fields,
limit=1)
return query.execute()[0]
class PrimaryKeyCreator(Protocol):
def create_pk(self, *args, **kwargs) -> str:
"""Create a new primary key"""
class Uuid4PrimaryKey:
def create_pk(self, *args, **kwargs) -> str:
return str(uuid.uuid4())
def __dataclass_transform__(
*,
eq_default: bool = True,
@ -395,21 +593,22 @@ def Field(
return field_info
@dataclass
@dataclasses.dataclass
class PrimaryKey:
name: str
field: ModelField
class DefaultMeta:
# TODO: Should this really be optional here?
global_key_prefix: Optional[str] = None
model_key_prefix: Optional[str] = None
primary_key_pattern: Optional[str] = None
database: Optional[redis.Redis] = None
primary_key: Optional[PrimaryKey] = None
primary_key_creator_cls: Type[PrimaryKeyCreator] = None
index_name: str = None
abstract: bool = False
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
index_name: Optional[str] = None
abstract: Optional[bool] = False
class ModelMeta(ModelMetaclass):
@ -473,6 +672,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
pk: Optional[str] = Field(default=None, primary_key=True)
Meta = DefaultMeta
# TODO: Missing _meta here is causing IDE warnings.
class Config:
orm_mode = True
@ -484,6 +684,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
__pydantic_self__.validate_primary_key()
def __lt__(self, other):
"""Default sort: compare all shared model fields."""
my_keys = set(self.__fields__.keys())
other_keys = set(other.__fields__.keys())
shared_keys = list(my_keys & other_keys)
@ -528,8 +729,13 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
def db(cls):
return cls._meta.database
@classmethod
def find(cls, *expressions: Expression):
return FindQuery(expressions=expressions, model=cls)
@classmethod
def from_redis(cls, res: Any):
# TODO: Parsing logic borrowed from redisearch-py. Evaluate.
import six
from six.moves import xrange, zip as izip
@ -543,7 +749,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
docs = []
step = 2 # Because the result has content
offset = 1
offset = 1 # The first item is the count of total matches.
for i in xrange(1, len(res), step):
fields_offset = offset
@ -562,17 +768,6 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
docs.append(doc)
return docs
@classmethod
def find(cls, *expressions: Expression):
query = FindQuery(expressions=expressions, model=cls)
raw_result = query.find()
return cls.from_redis(raw_result)
@classmethod
def find_one(cls, *expressions: Expression):
query = FindQuery(expressions=expressions, model=cls, limit=1, offset=0)
raw_result = query.find()
return cls.from_redis(raw_result)[0]
@classmethod
def add(cls, models: Sequence['RedisModel']) -> Sequence['RedisModel']:
@ -580,6 +775,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
@classmethod
def update(cls, **field_values):
"""Update this model instance."""
return cls
@classmethod

View file

@ -1,55 +0,0 @@
from redis_developer.orm.model import Expression
class QueryIterator:
"""
A lazy iterator that yields results from a RediSearch query.
Examples:
results = Model.filter(email == "a@example.com")
# Consume all results.
for r in results:
print(r)
# Consume an item at an index.
print(results[100])
# Consume a slice.
print(results[0:100])
# Alternative notation to consume all items.
print(results[0:-1])
# Specify the batch size:
results = Model.filter(email == "a@example.com", batch_size=1000)
...
"""
def __init__(self, client, query, batch_size=100):
self.client = client
self.query = query
self.batch_size = batch_size
def __iter__(self):
pass
def __getattr__(self, item):
"""Support getting a single value or a slice."""
# TODO: Query mixin?
def filter(self, *expressions: Expression):
pass
def exclude(self, *expressions: Expression):
pass
def and_(self, *expressions: Expression):
pass
def or_(self, *expressions: Expression):
pass
def not_(self, *expressions: Expression):
pass

View file

@ -151,51 +151,73 @@ def test_saves_many():
Member.add(members)
@pytest.mark.skip("No implemented yet")
def test_updates_a_model():
member = Member(
first_name="Andrew",
last_name="Brookins",
email="a@example.com",
join_date=today
)
# Update a model instance in Redis
member.first_name = "Andrew"
member.last_name = "Brookins"
member.save()
@pytest.mark.skip("Not ready yet")
def test_updates_a_model(members):
member1, member2, member3 = members
# Or, with an implicit save:
member.update(last_name="Smith")
member1.update(last_name="Smith")
assert Member.find(Member.pk == member1.pk).first() == member1
# Or, affecting multiple model instances with an implicit save:
Member.find(Member.last_name == "Brookins").update(last_name="Smith")
results = Member.find(Member.last_name == "Smith")
assert sorted(results) == members
def test_paginate_query(members):
member1, member2, member3 = members
actual = Member.find().all(batch_size=1)
assert sorted(actual) == [member1, member2, member3]
def test_access_result_by_index_cached(members):
_, member2, _ = members
query = Member.find().sort_by('age')
# Load the cache, throw away the result.
query.execute()
# Access an item that should be in the cache.
# TODO: Assert that we didn't make a Redis request.
assert query[0] == member2
def test_access_result_by_index_not_cached(members):
member1, member2, member3 = members
query = Member.find().sort_by('age')
# Assert that we don't have any models in the cache yet -- we
# haven't made any requests of Redis.
assert query._model_cache == []
assert query[0] == member2
assert query[1] == member1
assert query[2] == member3
def test_exact_match_queries(members):
member1, member2, member3 = members
actual = Member.find(Member.last_name == "Brookins")
actual = Member.find(Member.last_name == "Brookins").all()
assert sorted(actual) == [member1, member2]
actual = Member.find(
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all()
assert actual == [member2]
actual = Member.find(~(Member.last_name == "Brookins"))
actual = Member.find(~(Member.last_name == "Brookins")).all()
assert actual == [member3]
actual = Member.find(Member.last_name != "Brookins")
actual = Member.find(Member.last_name != "Brookins").all()
assert actual == [member3]
actual = Member.find(
(Member.last_name == "Brookins") & (Member.first_name == "Andrew")
| (Member.first_name == "Kim")
)
).all()
assert actual == [member2, member1]
actual = Member.find_one(Member.last_name == "Brookins")
assert actual == member2
actual = Member.find(Member.first_name == "Kim", Member.last_name == "Brookins").all()
assert actual == [member2]
def test_recursive_query_resolution(members):
@ -203,7 +225,7 @@ def test_recursive_query_resolution(members):
actual = Member.find((Member.last_name == "Brookins") | (
Member.age == 100
) & (Member.last_name == "Smith"))
) & (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member2, member3]
@ -212,42 +234,58 @@ def test_tag_queries_boolean_logic(members):
actual = Member.find(
(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member3]
def test_tag_queries_punctuation():
member = Member(
first_name="Andrew the Michael",
last_name="St. Brookins-on-Pier",
email="a@example.com",
age=38,
join_date=today
)
member.save()
assert Member.find(Member.first_name == "Andrew the Michael").first() == member
assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member
assert Member.find(Member.email == "a@example.com").first() == member
def test_tag_queries_negation(members):
member1, member2, member3 = members
actual = Member.find(
~(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith"))
(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all()
assert sorted(actual) == [member2, member3]
actual = Member.find(
(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins"))
(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")).all()
assert sorted(actual) == [member3]
def test_numeric_queries(members):
member1, member2, member3 = members
actual = Member.find_one(Member.age == 34)
assert actual == member2
actual = Member.find(Member.age == 34).all()
assert actual == [member2]
actual = Member.find(Member.age > 34)
actual = Member.find(Member.age > 34).all()
assert sorted(actual) == [member1, member3]
actual = Member.find(Member.age < 35)
actual = Member.find(Member.age < 35).all()
assert actual == [member2]
actual = Member.find(Member.age <= 34)
actual = Member.find(Member.age <= 34).all()
assert actual == [member2]
actual = Member.find(Member.age >= 100)
actual = Member.find(Member.age >= 100).all()
assert actual == [member3]
actual = Member.find(~(Member.age == 100))
import ipdb; ipdb.set_trace()
actual = Member.find(~(Member.age == 100)).all()
assert sorted(actual) == [member1, member2]