diff --git a/redis_developer/orm/model.py b/redis_developer/orm/model.py index cb89786..1510df2 100644 --- a/redis_developer/orm/model.py +++ b/redis_developer/orm/model.py @@ -65,16 +65,29 @@ class Operators(Enum): LIKE = 14 +@dataclass +class NegatedExpression: + expression: 'Expression' + + def __invert__(self): + return self.expression + + @dataclass class Expression: op: Operators left: Any right: Any + def __invert__(self): + return NegatedExpression(self) + def __and__(self, other): return Expression(left=self, op=Operators.AND, right=other) +ExpressionOrNegated = Union[Expression, NegatedExpression] + class QueryNotSupportedError(Exception): """The attempted query is not supported.""" @@ -120,12 +133,15 @@ class FindQuery: if field_type is RediSearchFieldTypes.TEXT: result = f"@{field_name}:" if op is Operators.EQ: - result += f'"{value}"' + result += f'"{value}"' + elif op is Operators.NE: + result = f'-({result}"{value}")' elif op is Operators.LIKE: result += value else: - raise QueryNotSupportedError("Only equals (=) comparisons are currently supported " - "for TEXT fields. See docs: TODO") + # TODO: Handling TAG, TEXT switch-offs, etc. + raise QueryNotSupportedError("Only equals (=) and not-equals (!=) comparisons are " + "currently supported for TEXT fields. See docs: TODO") elif field_type is RediSearchFieldTypes.NUMERIC: if op is Operators.EQ: result += f"@{field_name}:[{value} {value}]" @@ -140,14 +156,19 @@ class FindQuery: result += f"@{field_name}:[{value} +inf]" elif op is Operators.LTE: result += f"@{field_name}:[-inf {value}]" - + return result - def resolve_redisearch_query(self, expression: Expression): + def resolve_redisearch_query(self, expression: ExpressionOrNegated): """Resolve an expression to a string RediSearch query.""" field_type = None field_name = None + is_negated = False result = "" + + if isinstance(expression, NegatedExpression): + is_negated = True + expression = expression.expression if isinstance(expression.left, Expression): result += f"({self.resolve_redisearch_query(expression.left)})" elif isinstance(expression.left, ModelField): @@ -158,22 +179,33 @@ class FindQuery: f"or an expression enclosed in parenthesis. See docs: " f"TODO") - if isinstance(expression.right, Expression): + right = expression.right + right_is_negated = isinstance(right, NegatedExpression) + + if isinstance(right, Expression) or right_is_negated: if expression.op == Operators.AND: - result += " (" + result += " " elif expression.op == Operators.OR: - result += "| (" - elif expression.op == Operators.NOT: - result += " ~(" + result += "| " else: raise QueryNotSupportedError("You can only combine two query expressions with" - "AND (&), OR (|), or NOT (~). See docs: TODO") - result += f"{self.resolve_redisearch_query(expression.right)})" # NOTE: We add the closing paren + "AND (&) or OR (|). See docs: TODO") + + if right_is_negated: + result += "-" + # We're handling the RediSearch operator in this call ("-"), so resolve the + # inner expression instead of the NegatedExpression. + right = right.expression + + result += f"({self.resolve_redisearch_query(right)})" else: - if isinstance(expression.right, ModelField): + if isinstance(right, ModelField): raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO") else: - result += f"({self.resolve_value(field_name, field_type, expression.op, expression.right)})" + result += self.resolve_value(field_name, field_type, expression.op, right) + + if is_negated: + result = f"-({result})" return result @@ -214,6 +246,10 @@ class ExpressionProxy: def __ge__(self, other: Any) -> Expression: return Expression(left=self.field, op=Operators.GE, right=other) + def __invert__(self): + import ipdb; ipdb.set_trace() + pass + def __dataclass_transform__( *, @@ -368,7 +404,6 @@ class ModelMeta(ModelMetaclass): if field.field_info.primary_key: new_class._meta.primary_key = PrimaryKey(name=name, field=field) - # TODO: Raise exception here, global key prefix required? if not getattr(new_class._meta, 'global_key_prefix', None): new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "") if not getattr(new_class._meta, 'model_key_prefix', None): @@ -391,11 +426,6 @@ class ModelMeta(ModelMetaclass): class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): - """ - TODO: Convert expressions to Redis commands, execute - TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern) - TODO: Generate RediSearch schema - """ pk: Optional[str] = Field(default=None, primary_key=True) Meta = DefaultMeta @@ -446,12 +476,12 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): @classmethod def db(cls): return cls._meta.database - + @classmethod def from_redis(cls, res: Any): import six from six.moves import xrange, zip as izip - + def to_string(s): if isinstance(s, six.string_types): return s @@ -459,7 +489,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): return s.decode('utf-8','ignore') else: return s # Not a string we care about - + docs = [] step = 2 # Because the result has content offset = 1 @@ -562,7 +592,7 @@ class HashModel(RedisModel, abc.ABC): return f"{name} NUMERIC" else: return f"{name} TEXT" - + @classmethod def schema(cls): hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk="")) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index fbb7ad2..00f03ab 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -126,15 +126,15 @@ def test_updates_a_model(): ) # Update a model instance in Redis - member.first_name = "Brian" - member.last_name = "Sam-Bodden" + member.first_name = "Andrew" + member.last_name = "Brookins" member.save() # Or, with an implicit save: - member.update(first_name="Brian", last_name="Sam-Bodden") + member.update(last_name="Smith") # Or, affecting multiple model instances with an implicit save: - Member.filter(Member.last_name == "Brookins").update(last_name="Sam-Bodden") + Member.find(Member.last_name == "Brookins").update(last_name="Smith") def test_exact_match_queries(): @@ -151,21 +151,36 @@ def test_exact_match_queries(): email="k@example.com", join_date=today ) + + member3 = Member( + first_name="Andrew", + last_name="Smith", + email="as@example.com", + join_date=today + ) member1.save() member2.save() + member3.save() + import ipdb; ipdb.set_trace() + + # # TODO: How to help IDEs know that last_name is not a str, but a wrapped expression? actual = Member.find(Member.last_name == "Brookins") assert actual == [member2, member1] - - # actual = Member.find( - # (Member.last_name == "Brookins") & (~Member.first_name == "Andrew")) - # assert actual == [member2] + actual = Member.find( + (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")) - # actual = Member.find(~Member.last_name == "Brookins") - # assert actual == [] + assert actual == [member2] - # actual = Member.find( + actual = Member.find(~(Member.last_name == "Brookins")) + assert actual == [member3] + + actual = Member.find(Member.last_name != "Brookins") + assert actual == [member3] + + +# actual = Member.find( # (Member.last_name == "Brookins") & (Member.first_name == "Andrew") # | (Member.first_name == "Kim") # ) @@ -183,3 +198,26 @@ def test_schema(): assert Address.schema() == "SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \ "a_float NUMERIC" + + + +# --- + +from typing import Optional + +from sqlmodel import Field, Session, SQLModel, create_engine, select + +class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + + +engine = create_engine("sqlite:///database.db") + +with Session(engine) as session: + import ipdb; ipdb.set_trace() + statement = select(Hero).where(Hero.name == "Spider-Boy") + hero = session.exec(statement).first() + print(hero)