Support negated expressions

This commit is contained in:
Andrew Brookins 2021-09-16 12:03:03 -07:00
parent 0990c2e1b4
commit f8c55236c1
2 changed files with 103 additions and 35 deletions

View file

@ -65,16 +65,29 @@ class Operators(Enum):
LIKE = 14 LIKE = 14
@dataclass
class NegatedExpression:
expression: 'Expression'
def __invert__(self):
return self.expression
@dataclass @dataclass
class Expression: class Expression:
op: Operators op: Operators
left: Any left: Any
right: Any right: Any
def __invert__(self):
return NegatedExpression(self)
def __and__(self, other): def __and__(self, other):
return Expression(left=self, op=Operators.AND, right=other) return Expression(left=self, op=Operators.AND, right=other)
ExpressionOrNegated = Union[Expression, NegatedExpression]
class QueryNotSupportedError(Exception): class QueryNotSupportedError(Exception):
"""The attempted query is not supported.""" """The attempted query is not supported."""
@ -121,11 +134,14 @@ class FindQuery:
result = f"@{field_name}:" result = f"@{field_name}:"
if op is Operators.EQ: if op is Operators.EQ:
result += f'"{value}"' result += f'"{value}"'
elif op is Operators.NE:
result = f'-({result}"{value}")'
elif op is Operators.LIKE: elif op is Operators.LIKE:
result += value result += value
else: else:
raise QueryNotSupportedError("Only equals (=) comparisons are currently supported " # TODO: Handling TAG, TEXT switch-offs, etc.
"for TEXT fields. See docs: TODO") raise QueryNotSupportedError("Only equals (=) and not-equals (!=) comparisons are "
"currently supported for TEXT fields. See docs: TODO")
elif field_type is RediSearchFieldTypes.NUMERIC: elif field_type is RediSearchFieldTypes.NUMERIC:
if op is Operators.EQ: if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]" result += f"@{field_name}:[{value} {value}]"
@ -143,11 +159,16 @@ class FindQuery:
return result return result
def resolve_redisearch_query(self, expression: Expression): def resolve_redisearch_query(self, expression: ExpressionOrNegated):
"""Resolve an expression to a string RediSearch query.""" """Resolve an expression to a string RediSearch query."""
field_type = None field_type = None
field_name = None field_name = None
is_negated = False
result = "" result = ""
if isinstance(expression, NegatedExpression):
is_negated = True
expression = expression.expression
if isinstance(expression.left, Expression): if isinstance(expression.left, Expression):
result += f"({self.resolve_redisearch_query(expression.left)})" result += f"({self.resolve_redisearch_query(expression.left)})"
elif isinstance(expression.left, ModelField): elif isinstance(expression.left, ModelField):
@ -158,22 +179,33 @@ class FindQuery:
f"or an expression enclosed in parenthesis. See docs: " f"or an expression enclosed in parenthesis. See docs: "
f"TODO") f"TODO")
if isinstance(expression.right, Expression): right = expression.right
right_is_negated = isinstance(right, NegatedExpression)
if isinstance(right, Expression) or right_is_negated:
if expression.op == Operators.AND: if expression.op == Operators.AND:
result += " (" result += " "
elif expression.op == Operators.OR: elif expression.op == Operators.OR:
result += "| (" result += "| "
elif expression.op == Operators.NOT:
result += " ~("
else: else:
raise QueryNotSupportedError("You can only combine two query expressions with" raise QueryNotSupportedError("You can only combine two query expressions with"
"AND (&), OR (|), or NOT (~). See docs: TODO") "AND (&) or OR (|). See docs: TODO")
result += f"{self.resolve_redisearch_query(expression.right)})" # NOTE: We add the closing paren
if right_is_negated:
result += "-"
# We're handling the RediSearch operator in this call ("-"), so resolve the
# inner expression instead of the NegatedExpression.
right = right.expression
result += f"({self.resolve_redisearch_query(right)})"
else: else:
if isinstance(expression.right, ModelField): if isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO") raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else: else:
result += f"({self.resolve_value(field_name, field_type, expression.op, expression.right)})" result += self.resolve_value(field_name, field_type, expression.op, right)
if is_negated:
result = f"-({result})"
return result return result
@ -214,6 +246,10 @@ class ExpressionProxy:
def __ge__(self, other: Any) -> Expression: def __ge__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.GE, right=other) return Expression(left=self.field, op=Operators.GE, right=other)
def __invert__(self):
import ipdb; ipdb.set_trace()
pass
def __dataclass_transform__( def __dataclass_transform__(
*, *,
@ -368,7 +404,6 @@ class ModelMeta(ModelMetaclass):
if field.field_info.primary_key: if field.field_info.primary_key:
new_class._meta.primary_key = PrimaryKey(name=name, field=field) new_class._meta.primary_key = PrimaryKey(name=name, field=field)
# TODO: Raise exception here, global key prefix required?
if not getattr(new_class._meta, 'global_key_prefix', None): if not getattr(new_class._meta, 'global_key_prefix', None):
new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "") new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "")
if not getattr(new_class._meta, 'model_key_prefix', None): if not getattr(new_class._meta, 'model_key_prefix', None):
@ -391,11 +426,6 @@ class ModelMeta(ModelMetaclass):
class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
"""
TODO: Convert expressions to Redis commands, execute
TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern)
TODO: Generate RediSearch schema
"""
pk: Optional[str] = Field(default=None, primary_key=True) pk: Optional[str] = Field(default=None, primary_key=True)
Meta = DefaultMeta Meta = DefaultMeta

View file

@ -126,15 +126,15 @@ def test_updates_a_model():
) )
# Update a model instance in Redis # Update a model instance in Redis
member.first_name = "Brian" member.first_name = "Andrew"
member.last_name = "Sam-Bodden" member.last_name = "Brookins"
member.save() member.save()
# Or, with an implicit save: # Or, with an implicit save:
member.update(first_name="Brian", last_name="Sam-Bodden") member.update(last_name="Smith")
# Or, affecting multiple model instances with an implicit save: # Or, affecting multiple model instances with an implicit save:
Member.filter(Member.last_name == "Brookins").update(last_name="Sam-Bodden") Member.find(Member.last_name == "Brookins").update(last_name="Smith")
def test_exact_match_queries(): def test_exact_match_queries():
@ -151,19 +151,34 @@ def test_exact_match_queries():
email="k@example.com", email="k@example.com",
join_date=today join_date=today
) )
member3 = Member(
first_name="Andrew",
last_name="Smith",
email="as@example.com",
join_date=today
)
member1.save() member1.save()
member2.save() member2.save()
member3.save()
import ipdb; ipdb.set_trace()
# # TODO: How to help IDEs know that last_name is not a str, but a wrapped expression?
actual = Member.find(Member.last_name == "Brookins") actual = Member.find(Member.last_name == "Brookins")
assert actual == [member2, member1] assert actual == [member2, member1]
actual = Member.find(
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
# actual = Member.find( assert actual == [member2]
# (Member.last_name == "Brookins") & (~Member.first_name == "Andrew"))
# assert actual == [member2] actual = Member.find(~(Member.last_name == "Brookins"))
assert actual == [member3]
actual = Member.find(Member.last_name != "Brookins")
assert actual == [member3]
# actual = Member.find(~Member.last_name == "Brookins")
# assert actual == []
# actual = Member.find( # actual = Member.find(
# (Member.last_name == "Brookins") & (Member.first_name == "Andrew") # (Member.last_name == "Brookins") & (Member.first_name == "Andrew")
@ -183,3 +198,26 @@ def test_schema():
assert Address.schema() == "SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \ assert Address.schema() == "SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \
"a_float NUMERIC" "a_float NUMERIC"
# ---
from typing import Optional
from sqlmodel import Field, Session, SQLModel, create_engine, select
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None
engine = create_engine("sqlite:///database.db")
with Session(engine) as session:
import ipdb; ipdb.set_trace()
statement = select(Hero).where(Hero.name == "Spider-Boy")
hero = session.exec(statement).first()
print(hero)