Support negated expressions

This commit is contained in:
Andrew Brookins 2021-09-16 12:03:03 -07:00
parent 0990c2e1b4
commit f8c55236c1
2 changed files with 103 additions and 35 deletions

View file

@ -65,16 +65,29 @@ class Operators(Enum):
LIKE = 14
@dataclass
class NegatedExpression:
expression: 'Expression'
def __invert__(self):
return self.expression
@dataclass
class Expression:
op: Operators
left: Any
right: Any
def __invert__(self):
return NegatedExpression(self)
def __and__(self, other):
return Expression(left=self, op=Operators.AND, right=other)
ExpressionOrNegated = Union[Expression, NegatedExpression]
class QueryNotSupportedError(Exception):
"""The attempted query is not supported."""
@ -121,11 +134,14 @@ class FindQuery:
result = f"@{field_name}:"
if op is Operators.EQ:
result += f'"{value}"'
elif op is Operators.NE:
result = f'-({result}"{value}")'
elif op is Operators.LIKE:
result += value
else:
raise QueryNotSupportedError("Only equals (=) comparisons are currently supported "
"for TEXT fields. See docs: TODO")
# TODO: Handling TAG, TEXT switch-offs, etc.
raise QueryNotSupportedError("Only equals (=) and not-equals (!=) comparisons are "
"currently supported for TEXT fields. See docs: TODO")
elif field_type is RediSearchFieldTypes.NUMERIC:
if op is Operators.EQ:
result += f"@{field_name}:[{value} {value}]"
@ -143,11 +159,16 @@ class FindQuery:
return result
def resolve_redisearch_query(self, expression: Expression):
def resolve_redisearch_query(self, expression: ExpressionOrNegated):
"""Resolve an expression to a string RediSearch query."""
field_type = None
field_name = None
is_negated = False
result = ""
if isinstance(expression, NegatedExpression):
is_negated = True
expression = expression.expression
if isinstance(expression.left, Expression):
result += f"({self.resolve_redisearch_query(expression.left)})"
elif isinstance(expression.left, ModelField):
@ -158,22 +179,33 @@ class FindQuery:
f"or an expression enclosed in parenthesis. See docs: "
f"TODO")
if isinstance(expression.right, Expression):
right = expression.right
right_is_negated = isinstance(right, NegatedExpression)
if isinstance(right, Expression) or right_is_negated:
if expression.op == Operators.AND:
result += " ("
result += " "
elif expression.op == Operators.OR:
result += "| ("
elif expression.op == Operators.NOT:
result += " ~("
result += "| "
else:
raise QueryNotSupportedError("You can only combine two query expressions with"
"AND (&), OR (|), or NOT (~). See docs: TODO")
result += f"{self.resolve_redisearch_query(expression.right)})" # NOTE: We add the closing paren
"AND (&) or OR (|). See docs: TODO")
if right_is_negated:
result += "-"
# We're handling the RediSearch operator in this call ("-"), so resolve the
# inner expression instead of the NegatedExpression.
right = right.expression
result += f"({self.resolve_redisearch_query(right)})"
else:
if isinstance(expression.right, ModelField):
if isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else:
result += f"({self.resolve_value(field_name, field_type, expression.op, expression.right)})"
result += self.resolve_value(field_name, field_type, expression.op, right)
if is_negated:
result = f"-({result})"
return result
@ -214,6 +246,10 @@ class ExpressionProxy:
def __ge__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.GE, right=other)
def __invert__(self):
import ipdb; ipdb.set_trace()
pass
def __dataclass_transform__(
*,
@ -368,7 +404,6 @@ class ModelMeta(ModelMetaclass):
if field.field_info.primary_key:
new_class._meta.primary_key = PrimaryKey(name=name, field=field)
# TODO: Raise exception here, global key prefix required?
if not getattr(new_class._meta, 'global_key_prefix', None):
new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "")
if not getattr(new_class._meta, 'model_key_prefix', None):
@ -391,11 +426,6 @@ class ModelMeta(ModelMetaclass):
class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
"""
TODO: Convert expressions to Redis commands, execute
TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern)
TODO: Generate RediSearch schema
"""
pk: Optional[str] = Field(default=None, primary_key=True)
Meta = DefaultMeta

View file

@ -126,15 +126,15 @@ def test_updates_a_model():
)
# Update a model instance in Redis
member.first_name = "Brian"
member.last_name = "Sam-Bodden"
member.first_name = "Andrew"
member.last_name = "Brookins"
member.save()
# Or, with an implicit save:
member.update(first_name="Brian", last_name="Sam-Bodden")
member.update(last_name="Smith")
# Or, affecting multiple model instances with an implicit save:
Member.filter(Member.last_name == "Brookins").update(last_name="Sam-Bodden")
Member.find(Member.last_name == "Brookins").update(last_name="Smith")
def test_exact_match_queries():
@ -151,19 +151,34 @@ def test_exact_match_queries():
email="k@example.com",
join_date=today
)
member3 = Member(
first_name="Andrew",
last_name="Smith",
email="as@example.com",
join_date=today
)
member1.save()
member2.save()
member3.save()
import ipdb; ipdb.set_trace()
# # TODO: How to help IDEs know that last_name is not a str, but a wrapped expression?
actual = Member.find(Member.last_name == "Brookins")
assert actual == [member2, member1]
actual = Member.find(
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
# actual = Member.find(
# (Member.last_name == "Brookins") & (~Member.first_name == "Andrew"))
# assert actual == [member2]
assert actual == [member2]
actual = Member.find(~(Member.last_name == "Brookins"))
assert actual == [member3]
actual = Member.find(Member.last_name != "Brookins")
assert actual == [member3]
# actual = Member.find(~Member.last_name == "Brookins")
# assert actual == []
# actual = Member.find(
# (Member.last_name == "Brookins") & (Member.first_name == "Andrew")
@ -183,3 +198,26 @@ def test_schema():
assert Address.schema() == "SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \
"a_float NUMERIC"
# ---
from typing import Optional
from sqlmodel import Field, Session, SQLModel, create_engine, select
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None
engine = create_engine("sqlite:///database.db")
with Session(engine) as session:
import ipdb; ipdb.set_trace()
statement = select(Hero).where(Hero.name == "Spider-Boy")
hero = session.exec(statement).first()
print(hero)