Support negated expressions
This commit is contained in:
parent
0990c2e1b4
commit
f8c55236c1
2 changed files with 103 additions and 35 deletions
|
@ -65,16 +65,29 @@ class Operators(Enum):
|
||||||
LIKE = 14
|
LIKE = 14
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NegatedExpression:
|
||||||
|
expression: 'Expression'
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
return self.expression
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Expression:
|
class Expression:
|
||||||
op: Operators
|
op: Operators
|
||||||
left: Any
|
left: Any
|
||||||
right: Any
|
right: Any
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
return NegatedExpression(self)
|
||||||
|
|
||||||
def __and__(self, other):
|
def __and__(self, other):
|
||||||
return Expression(left=self, op=Operators.AND, right=other)
|
return Expression(left=self, op=Operators.AND, right=other)
|
||||||
|
|
||||||
|
|
||||||
|
ExpressionOrNegated = Union[Expression, NegatedExpression]
|
||||||
|
|
||||||
|
|
||||||
class QueryNotSupportedError(Exception):
|
class QueryNotSupportedError(Exception):
|
||||||
"""The attempted query is not supported."""
|
"""The attempted query is not supported."""
|
||||||
|
@ -120,12 +133,15 @@ class FindQuery:
|
||||||
if field_type is RediSearchFieldTypes.TEXT:
|
if field_type is RediSearchFieldTypes.TEXT:
|
||||||
result = f"@{field_name}:"
|
result = f"@{field_name}:"
|
||||||
if op is Operators.EQ:
|
if op is Operators.EQ:
|
||||||
result += f'"{value}"'
|
result += f'"{value}"'
|
||||||
|
elif op is Operators.NE:
|
||||||
|
result = f'-({result}"{value}")'
|
||||||
elif op is Operators.LIKE:
|
elif op is Operators.LIKE:
|
||||||
result += value
|
result += value
|
||||||
else:
|
else:
|
||||||
raise QueryNotSupportedError("Only equals (=) comparisons are currently supported "
|
# TODO: Handling TAG, TEXT switch-offs, etc.
|
||||||
"for TEXT fields. See docs: TODO")
|
raise QueryNotSupportedError("Only equals (=) and not-equals (!=) comparisons are "
|
||||||
|
"currently supported for TEXT fields. See docs: TODO")
|
||||||
elif field_type is RediSearchFieldTypes.NUMERIC:
|
elif field_type is RediSearchFieldTypes.NUMERIC:
|
||||||
if op is Operators.EQ:
|
if op is Operators.EQ:
|
||||||
result += f"@{field_name}:[{value} {value}]"
|
result += f"@{field_name}:[{value} {value}]"
|
||||||
|
@ -140,14 +156,19 @@ class FindQuery:
|
||||||
result += f"@{field_name}:[{value} +inf]"
|
result += f"@{field_name}:[{value} +inf]"
|
||||||
elif op is Operators.LTE:
|
elif op is Operators.LTE:
|
||||||
result += f"@{field_name}:[-inf {value}]"
|
result += f"@{field_name}:[-inf {value}]"
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def resolve_redisearch_query(self, expression: Expression):
|
def resolve_redisearch_query(self, expression: ExpressionOrNegated):
|
||||||
"""Resolve an expression to a string RediSearch query."""
|
"""Resolve an expression to a string RediSearch query."""
|
||||||
field_type = None
|
field_type = None
|
||||||
field_name = None
|
field_name = None
|
||||||
|
is_negated = False
|
||||||
result = ""
|
result = ""
|
||||||
|
|
||||||
|
if isinstance(expression, NegatedExpression):
|
||||||
|
is_negated = True
|
||||||
|
expression = expression.expression
|
||||||
if isinstance(expression.left, Expression):
|
if isinstance(expression.left, Expression):
|
||||||
result += f"({self.resolve_redisearch_query(expression.left)})"
|
result += f"({self.resolve_redisearch_query(expression.left)})"
|
||||||
elif isinstance(expression.left, ModelField):
|
elif isinstance(expression.left, ModelField):
|
||||||
|
@ -158,22 +179,33 @@ class FindQuery:
|
||||||
f"or an expression enclosed in parenthesis. See docs: "
|
f"or an expression enclosed in parenthesis. See docs: "
|
||||||
f"TODO")
|
f"TODO")
|
||||||
|
|
||||||
if isinstance(expression.right, Expression):
|
right = expression.right
|
||||||
|
right_is_negated = isinstance(right, NegatedExpression)
|
||||||
|
|
||||||
|
if isinstance(right, Expression) or right_is_negated:
|
||||||
if expression.op == Operators.AND:
|
if expression.op == Operators.AND:
|
||||||
result += " ("
|
result += " "
|
||||||
elif expression.op == Operators.OR:
|
elif expression.op == Operators.OR:
|
||||||
result += "| ("
|
result += "| "
|
||||||
elif expression.op == Operators.NOT:
|
|
||||||
result += " ~("
|
|
||||||
else:
|
else:
|
||||||
raise QueryNotSupportedError("You can only combine two query expressions with"
|
raise QueryNotSupportedError("You can only combine two query expressions with"
|
||||||
"AND (&), OR (|), or NOT (~). See docs: TODO")
|
"AND (&) or OR (|). See docs: TODO")
|
||||||
result += f"{self.resolve_redisearch_query(expression.right)})" # NOTE: We add the closing paren
|
|
||||||
|
if right_is_negated:
|
||||||
|
result += "-"
|
||||||
|
# We're handling the RediSearch operator in this call ("-"), so resolve the
|
||||||
|
# inner expression instead of the NegatedExpression.
|
||||||
|
right = right.expression
|
||||||
|
|
||||||
|
result += f"({self.resolve_redisearch_query(right)})"
|
||||||
else:
|
else:
|
||||||
if isinstance(expression.right, ModelField):
|
if isinstance(right, ModelField):
|
||||||
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
|
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
|
||||||
else:
|
else:
|
||||||
result += f"({self.resolve_value(field_name, field_type, expression.op, expression.right)})"
|
result += self.resolve_value(field_name, field_type, expression.op, right)
|
||||||
|
|
||||||
|
if is_negated:
|
||||||
|
result = f"-({result})"
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -214,6 +246,10 @@ class ExpressionProxy:
|
||||||
def __ge__(self, other: Any) -> Expression:
|
def __ge__(self, other: Any) -> Expression:
|
||||||
return Expression(left=self.field, op=Operators.GE, right=other)
|
return Expression(left=self.field, op=Operators.GE, right=other)
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
import ipdb; ipdb.set_trace()
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def __dataclass_transform__(
|
def __dataclass_transform__(
|
||||||
*,
|
*,
|
||||||
|
@ -368,7 +404,6 @@ class ModelMeta(ModelMetaclass):
|
||||||
if field.field_info.primary_key:
|
if field.field_info.primary_key:
|
||||||
new_class._meta.primary_key = PrimaryKey(name=name, field=field)
|
new_class._meta.primary_key = PrimaryKey(name=name, field=field)
|
||||||
|
|
||||||
# TODO: Raise exception here, global key prefix required?
|
|
||||||
if not getattr(new_class._meta, 'global_key_prefix', None):
|
if not getattr(new_class._meta, 'global_key_prefix', None):
|
||||||
new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "")
|
new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "")
|
||||||
if not getattr(new_class._meta, 'model_key_prefix', None):
|
if not getattr(new_class._meta, 'model_key_prefix', None):
|
||||||
|
@ -391,11 +426,6 @@ class ModelMeta(ModelMetaclass):
|
||||||
|
|
||||||
|
|
||||||
class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
"""
|
|
||||||
TODO: Convert expressions to Redis commands, execute
|
|
||||||
TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern)
|
|
||||||
TODO: Generate RediSearch schema
|
|
||||||
"""
|
|
||||||
pk: Optional[str] = Field(default=None, primary_key=True)
|
pk: Optional[str] = Field(default=None, primary_key=True)
|
||||||
|
|
||||||
Meta = DefaultMeta
|
Meta = DefaultMeta
|
||||||
|
@ -446,12 +476,12 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
@classmethod
|
@classmethod
|
||||||
def db(cls):
|
def db(cls):
|
||||||
return cls._meta.database
|
return cls._meta.database
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_redis(cls, res: Any):
|
def from_redis(cls, res: Any):
|
||||||
import six
|
import six
|
||||||
from six.moves import xrange, zip as izip
|
from six.moves import xrange, zip as izip
|
||||||
|
|
||||||
def to_string(s):
|
def to_string(s):
|
||||||
if isinstance(s, six.string_types):
|
if isinstance(s, six.string_types):
|
||||||
return s
|
return s
|
||||||
|
@ -459,7 +489,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||||
return s.decode('utf-8','ignore')
|
return s.decode('utf-8','ignore')
|
||||||
else:
|
else:
|
||||||
return s # Not a string we care about
|
return s # Not a string we care about
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
step = 2 # Because the result has content
|
step = 2 # Because the result has content
|
||||||
offset = 1
|
offset = 1
|
||||||
|
@ -562,7 +592,7 @@ class HashModel(RedisModel, abc.ABC):
|
||||||
return f"{name} NUMERIC"
|
return f"{name} NUMERIC"
|
||||||
else:
|
else:
|
||||||
return f"{name} TEXT"
|
return f"{name} TEXT"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def schema(cls):
|
def schema(cls):
|
||||||
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
|
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
|
||||||
|
|
|
@ -126,15 +126,15 @@ def test_updates_a_model():
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update a model instance in Redis
|
# Update a model instance in Redis
|
||||||
member.first_name = "Brian"
|
member.first_name = "Andrew"
|
||||||
member.last_name = "Sam-Bodden"
|
member.last_name = "Brookins"
|
||||||
member.save()
|
member.save()
|
||||||
|
|
||||||
# Or, with an implicit save:
|
# Or, with an implicit save:
|
||||||
member.update(first_name="Brian", last_name="Sam-Bodden")
|
member.update(last_name="Smith")
|
||||||
|
|
||||||
# Or, affecting multiple model instances with an implicit save:
|
# Or, affecting multiple model instances with an implicit save:
|
||||||
Member.filter(Member.last_name == "Brookins").update(last_name="Sam-Bodden")
|
Member.find(Member.last_name == "Brookins").update(last_name="Smith")
|
||||||
|
|
||||||
|
|
||||||
def test_exact_match_queries():
|
def test_exact_match_queries():
|
||||||
|
@ -151,21 +151,36 @@ def test_exact_match_queries():
|
||||||
email="k@example.com",
|
email="k@example.com",
|
||||||
join_date=today
|
join_date=today
|
||||||
)
|
)
|
||||||
|
|
||||||
|
member3 = Member(
|
||||||
|
first_name="Andrew",
|
||||||
|
last_name="Smith",
|
||||||
|
email="as@example.com",
|
||||||
|
join_date=today
|
||||||
|
)
|
||||||
member1.save()
|
member1.save()
|
||||||
member2.save()
|
member2.save()
|
||||||
|
member3.save()
|
||||||
|
|
||||||
|
import ipdb; ipdb.set_trace()
|
||||||
|
|
||||||
|
# # TODO: How to help IDEs know that last_name is not a str, but a wrapped expression?
|
||||||
actual = Member.find(Member.last_name == "Brookins")
|
actual = Member.find(Member.last_name == "Brookins")
|
||||||
assert actual == [member2, member1]
|
assert actual == [member2, member1]
|
||||||
|
|
||||||
|
|
||||||
# actual = Member.find(
|
actual = Member.find(
|
||||||
# (Member.last_name == "Brookins") & (~Member.first_name == "Andrew"))
|
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
|
||||||
# assert actual == [member2]
|
|
||||||
|
|
||||||
# actual = Member.find(~Member.last_name == "Brookins")
|
assert actual == [member2]
|
||||||
# assert actual == []
|
|
||||||
|
|
||||||
# actual = Member.find(
|
actual = Member.find(~(Member.last_name == "Brookins"))
|
||||||
|
assert actual == [member3]
|
||||||
|
|
||||||
|
actual = Member.find(Member.last_name != "Brookins")
|
||||||
|
assert actual == [member3]
|
||||||
|
|
||||||
|
|
||||||
|
# actual = Member.find(
|
||||||
# (Member.last_name == "Brookins") & (Member.first_name == "Andrew")
|
# (Member.last_name == "Brookins") & (Member.first_name == "Andrew")
|
||||||
# | (Member.first_name == "Kim")
|
# | (Member.first_name == "Kim")
|
||||||
# )
|
# )
|
||||||
|
@ -183,3 +198,26 @@ def test_schema():
|
||||||
|
|
||||||
assert Address.schema() == "SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \
|
assert Address.schema() == "SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \
|
||||||
"a_float NUMERIC"
|
"a_float NUMERIC"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ---
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||||
|
|
||||||
|
class Hero(SQLModel, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
name: str
|
||||||
|
secret_name: str
|
||||||
|
age: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
engine = create_engine("sqlite:///database.db")
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
import ipdb; ipdb.set_trace()
|
||||||
|
statement = select(Hero).where(Hero.name == "Spider-Boy")
|
||||||
|
hero = session.exec(statement).first()
|
||||||
|
print(hero)
|
||||||
|
|
Loading…
Reference in a new issue