Support negated expressions
This commit is contained in:
parent
0990c2e1b4
commit
f8c55236c1
2 changed files with 103 additions and 35 deletions
|
@ -65,16 +65,29 @@ class Operators(Enum):
|
|||
LIKE = 14
|
||||
|
||||
|
||||
@dataclass
|
||||
class NegatedExpression:
|
||||
expression: 'Expression'
|
||||
|
||||
def __invert__(self):
|
||||
return self.expression
|
||||
|
||||
|
||||
@dataclass
|
||||
class Expression:
|
||||
op: Operators
|
||||
left: Any
|
||||
right: Any
|
||||
|
||||
def __invert__(self):
|
||||
return NegatedExpression(self)
|
||||
|
||||
def __and__(self, other):
|
||||
return Expression(left=self, op=Operators.AND, right=other)
|
||||
|
||||
|
||||
ExpressionOrNegated = Union[Expression, NegatedExpression]
|
||||
|
||||
|
||||
class QueryNotSupportedError(Exception):
|
||||
"""The attempted query is not supported."""
|
||||
|
@ -120,12 +133,15 @@ class FindQuery:
|
|||
if field_type is RediSearchFieldTypes.TEXT:
|
||||
result = f"@{field_name}:"
|
||||
if op is Operators.EQ:
|
||||
result += f'"{value}"'
|
||||
result += f'"{value}"'
|
||||
elif op is Operators.NE:
|
||||
result = f'-({result}"{value}")'
|
||||
elif op is Operators.LIKE:
|
||||
result += value
|
||||
else:
|
||||
raise QueryNotSupportedError("Only equals (=) comparisons are currently supported "
|
||||
"for TEXT fields. See docs: TODO")
|
||||
# TODO: Handling TAG, TEXT switch-offs, etc.
|
||||
raise QueryNotSupportedError("Only equals (=) and not-equals (!=) comparisons are "
|
||||
"currently supported for TEXT fields. See docs: TODO")
|
||||
elif field_type is RediSearchFieldTypes.NUMERIC:
|
||||
if op is Operators.EQ:
|
||||
result += f"@{field_name}:[{value} {value}]"
|
||||
|
@ -140,14 +156,19 @@ class FindQuery:
|
|||
result += f"@{field_name}:[{value} +inf]"
|
||||
elif op is Operators.LTE:
|
||||
result += f"@{field_name}:[-inf {value}]"
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def resolve_redisearch_query(self, expression: Expression):
|
||||
def resolve_redisearch_query(self, expression: ExpressionOrNegated):
|
||||
"""Resolve an expression to a string RediSearch query."""
|
||||
field_type = None
|
||||
field_name = None
|
||||
is_negated = False
|
||||
result = ""
|
||||
|
||||
if isinstance(expression, NegatedExpression):
|
||||
is_negated = True
|
||||
expression = expression.expression
|
||||
if isinstance(expression.left, Expression):
|
||||
result += f"({self.resolve_redisearch_query(expression.left)})"
|
||||
elif isinstance(expression.left, ModelField):
|
||||
|
@ -158,22 +179,33 @@ class FindQuery:
|
|||
f"or an expression enclosed in parenthesis. See docs: "
|
||||
f"TODO")
|
||||
|
||||
if isinstance(expression.right, Expression):
|
||||
right = expression.right
|
||||
right_is_negated = isinstance(right, NegatedExpression)
|
||||
|
||||
if isinstance(right, Expression) or right_is_negated:
|
||||
if expression.op == Operators.AND:
|
||||
result += " ("
|
||||
result += " "
|
||||
elif expression.op == Operators.OR:
|
||||
result += "| ("
|
||||
elif expression.op == Operators.NOT:
|
||||
result += " ~("
|
||||
result += "| "
|
||||
else:
|
||||
raise QueryNotSupportedError("You can only combine two query expressions with"
|
||||
"AND (&), OR (|), or NOT (~). See docs: TODO")
|
||||
result += f"{self.resolve_redisearch_query(expression.right)})" # NOTE: We add the closing paren
|
||||
"AND (&) or OR (|). See docs: TODO")
|
||||
|
||||
if right_is_negated:
|
||||
result += "-"
|
||||
# We're handling the RediSearch operator in this call ("-"), so resolve the
|
||||
# inner expression instead of the NegatedExpression.
|
||||
right = right.expression
|
||||
|
||||
result += f"({self.resolve_redisearch_query(right)})"
|
||||
else:
|
||||
if isinstance(expression.right, ModelField):
|
||||
if isinstance(right, ModelField):
|
||||
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
|
||||
else:
|
||||
result += f"({self.resolve_value(field_name, field_type, expression.op, expression.right)})"
|
||||
result += self.resolve_value(field_name, field_type, expression.op, right)
|
||||
|
||||
if is_negated:
|
||||
result = f"-({result})"
|
||||
|
||||
return result
|
||||
|
||||
|
@ -214,6 +246,10 @@ class ExpressionProxy:
|
|||
def __ge__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.GE, right=other)
|
||||
|
||||
def __invert__(self):
|
||||
import ipdb; ipdb.set_trace()
|
||||
pass
|
||||
|
||||
|
||||
def __dataclass_transform__(
|
||||
*,
|
||||
|
@ -368,7 +404,6 @@ class ModelMeta(ModelMetaclass):
|
|||
if field.field_info.primary_key:
|
||||
new_class._meta.primary_key = PrimaryKey(name=name, field=field)
|
||||
|
||||
# TODO: Raise exception here, global key prefix required?
|
||||
if not getattr(new_class._meta, 'global_key_prefix', None):
|
||||
new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "")
|
||||
if not getattr(new_class._meta, 'model_key_prefix', None):
|
||||
|
@ -391,11 +426,6 @@ class ModelMeta(ModelMetaclass):
|
|||
|
||||
|
||||
class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||
"""
|
||||
TODO: Convert expressions to Redis commands, execute
|
||||
TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern)
|
||||
TODO: Generate RediSearch schema
|
||||
"""
|
||||
pk: Optional[str] = Field(default=None, primary_key=True)
|
||||
|
||||
Meta = DefaultMeta
|
||||
|
@ -446,12 +476,12 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
@classmethod
|
||||
def db(cls):
|
||||
return cls._meta.database
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_redis(cls, res: Any):
|
||||
import six
|
||||
from six.moves import xrange, zip as izip
|
||||
|
||||
|
||||
def to_string(s):
|
||||
if isinstance(s, six.string_types):
|
||||
return s
|
||||
|
@ -459,7 +489,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
|||
return s.decode('utf-8','ignore')
|
||||
else:
|
||||
return s # Not a string we care about
|
||||
|
||||
|
||||
docs = []
|
||||
step = 2 # Because the result has content
|
||||
offset = 1
|
||||
|
@ -562,7 +592,7 @@ class HashModel(RedisModel, abc.ABC):
|
|||
return f"{name} NUMERIC"
|
||||
else:
|
||||
return f"{name} TEXT"
|
||||
|
||||
|
||||
@classmethod
|
||||
def schema(cls):
|
||||
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
|
||||
|
|
|
@ -126,15 +126,15 @@ def test_updates_a_model():
|
|||
)
|
||||
|
||||
# Update a model instance in Redis
|
||||
member.first_name = "Brian"
|
||||
member.last_name = "Sam-Bodden"
|
||||
member.first_name = "Andrew"
|
||||
member.last_name = "Brookins"
|
||||
member.save()
|
||||
|
||||
# Or, with an implicit save:
|
||||
member.update(first_name="Brian", last_name="Sam-Bodden")
|
||||
member.update(last_name="Smith")
|
||||
|
||||
# Or, affecting multiple model instances with an implicit save:
|
||||
Member.filter(Member.last_name == "Brookins").update(last_name="Sam-Bodden")
|
||||
Member.find(Member.last_name == "Brookins").update(last_name="Smith")
|
||||
|
||||
|
||||
def test_exact_match_queries():
|
||||
|
@ -151,21 +151,36 @@ def test_exact_match_queries():
|
|||
email="k@example.com",
|
||||
join_date=today
|
||||
)
|
||||
|
||||
member3 = Member(
|
||||
first_name="Andrew",
|
||||
last_name="Smith",
|
||||
email="as@example.com",
|
||||
join_date=today
|
||||
)
|
||||
member1.save()
|
||||
member2.save()
|
||||
member3.save()
|
||||
|
||||
import ipdb; ipdb.set_trace()
|
||||
|
||||
# # TODO: How to help IDEs know that last_name is not a str, but a wrapped expression?
|
||||
actual = Member.find(Member.last_name == "Brookins")
|
||||
assert actual == [member2, member1]
|
||||
|
||||
|
||||
# actual = Member.find(
|
||||
# (Member.last_name == "Brookins") & (~Member.first_name == "Andrew"))
|
||||
# assert actual == [member2]
|
||||
actual = Member.find(
|
||||
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
|
||||
|
||||
# actual = Member.find(~Member.last_name == "Brookins")
|
||||
# assert actual == []
|
||||
assert actual == [member2]
|
||||
|
||||
# actual = Member.find(
|
||||
actual = Member.find(~(Member.last_name == "Brookins"))
|
||||
assert actual == [member3]
|
||||
|
||||
actual = Member.find(Member.last_name != "Brookins")
|
||||
assert actual == [member3]
|
||||
|
||||
|
||||
# actual = Member.find(
|
||||
# (Member.last_name == "Brookins") & (Member.first_name == "Andrew")
|
||||
# | (Member.first_name == "Kim")
|
||||
# )
|
||||
|
@ -183,3 +198,26 @@ def test_schema():
|
|||
|
||||
assert Address.schema() == "SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \
|
||||
"a_float NUMERIC"
|
||||
|
||||
|
||||
|
||||
# ---
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
class Hero(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
name: str
|
||||
secret_name: str
|
||||
age: Optional[int] = None
|
||||
|
||||
|
||||
engine = create_engine("sqlite:///database.db")
|
||||
|
||||
with Session(engine) as session:
|
||||
import ipdb; ipdb.set_trace()
|
||||
statement = select(Hero).where(Hero.name == "Spider-Boy")
|
||||
hero = session.exec(statement).first()
|
||||
print(hero)
|
||||
|
|
Loading…
Reference in a new issue