From 0990c2e1b47204b5a4b9171ded0b9d6e787e3487 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Wed, 15 Sep 2021 17:41:45 -0700 Subject: [PATCH] Add basic migrations, query expression resolver --- poetry.lock | 43 ++- pyproject.toml | 5 + redis_developer/{orm => }/connections.py | 2 +- redis_developer/orm/cli/__init__.py | 0 redis_developer/orm/cli/migrate.py | 16 + redis_developer/orm/migrations/__init__.py | 0 redis_developer/orm/migrations/migrator.py | 116 +++++++ redis_developer/orm/model.py | 352 +++++++++++++++++---- redis_developer/orm/models.py | 30 ++ redis_developer/orm/query_iterator.py | 55 ++++ redis_developer/orm/query_resolver.py | 106 +++++++ tests/__init__.py | 0 tests/conftest.py | 31 ++ tests/test_hash_model.py | 79 +++-- tests/test_json_model.py | 5 +- 15 files changed, 752 insertions(+), 88 deletions(-) rename redis_developer/{orm => }/connections.py (55%) create mode 100644 redis_developer/orm/cli/__init__.py create mode 100644 redis_developer/orm/cli/migrate.py create mode 100644 redis_developer/orm/migrations/__init__.py create mode 100644 redis_developer/orm/migrations/migrator.py create mode 100644 redis_developer/orm/models.py create mode 100644 redis_developer/orm/query_iterator.py create mode 100644 redis_developer/orm/query_resolver.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py diff --git a/poetry.lock b/poetry.lock index 244a320..e6b534d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -59,17 +59,28 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "click" +version = "8.0.1" +description = "Composable command line interface toolkit" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "colorama" version = "0.4.4" description = "Cross-platform colored terminal text." -category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "decorator" -version = "5.0.9" +version = "5.1.0" description = "Decorators for Humans" category = "dev" optional = false @@ -145,7 +156,7 @@ testing = ["Django (<3.1)", "colorama", "docopt", "pytest (<6.0.0)"] [[package]] name = "matplotlib-inline" -version = "0.1.2" +version = "0.1.3" description = "Inline Matplotlib backend for Jupyter" category = "dev" optional = false @@ -298,6 +309,14 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.extras] hiredis = ["hiredis (>=0.1.3)"] +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" + [[package]] name = "toml" version = "0.10.2" @@ -336,7 +355,7 @@ python-versions = "*" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "e5ac777000236190a585bef489a7fbe744a0d0dc328c001cb198a34746bd24a3" +content-hash = "b3f0c7c5701bb2c317df7f2f42218ef6c38d9b035ed49c9a9469df6a6727973c" [metadata.files] aioredis = [ @@ -363,13 +382,17 @@ backcall = [ {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"}, {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, ] +click = [ + {file = "click-8.0.1-py3-none-any.whl", hash = "sha256:fba402a4a47334742d782209a7c79bc448911afe1149d07bdabdf480b3e2f4b6"}, + {file = "click-8.0.1.tar.gz", hash = "sha256:8c04c11192119b1ef78ea049e0a6f0463e4c48ef00a30160c704337586f3ad7a"}, +] colorama = [ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] decorator = [ - {file = "decorator-5.0.9-py3-none-any.whl", hash = "sha256:6e5c199c16f7a9f0e3a61a4a54b3d27e7dad0dbdde92b944426cb20914376323"}, - {file = "decorator-5.0.9.tar.gz", hash = "sha256:72ecfba4320a893c53f9706bebb2d55c270c1e51a28789361aa93e4a21319ed5"}, + {file = "decorator-5.1.0-py3-none-any.whl", hash = "sha256:7b12e7c3c6ab203a29e157335e9122cb03de9ab7264b137594103fd4a683b374"}, + {file = "decorator-5.1.0.tar.gz", hash = "sha256:e59913af105b9860aa2c8d3272d9de5a56a4e608db9a2f167a8480b323d529a7"}, ] iniconfig = [ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, @@ -387,8 +410,8 @@ jedi = [ {file = "jedi-0.18.0.tar.gz", hash = "sha256:92550a404bad8afed881a137ec9a461fed49eca661414be45059329614ed0707"}, ] matplotlib-inline = [ - {file = "matplotlib-inline-0.1.2.tar.gz", hash = "sha256:f41d5ff73c9f5385775d5c0bc13b424535c8402fe70ea8210f93e11f3683993e"}, - {file = "matplotlib_inline-0.1.2-py3-none-any.whl", hash = "sha256:5cf1176f554abb4fa98cb362aa2b55c500147e4bdbb07e3fda359143e1da0811"}, + {file = "matplotlib-inline-0.1.3.tar.gz", hash = "sha256:a04bfba22e0d1395479f866853ec1ee28eea1485c1d69a6faf00dc3e24ff34ee"}, + {file = "matplotlib_inline-0.1.3-py3-none-any.whl", hash = "sha256:aed605ba3b72462d64d475a21a9296f400a19c4f74a31b59103d2a99ffd5aa5c"}, ] packaging = [ {file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"}, @@ -462,6 +485,10 @@ redis = [ {file = "redis-3.5.3-py2.py3-none-any.whl", hash = "sha256:432b788c4530cfe16d8d943a09d40ca6c16149727e4afe8c2c9d5580c59d9f24"}, {file = "redis-3.5.3.tar.gz", hash = "sha256:0e7e0cfca8660dea8b7d5cd8c4f6c5e29e11f31158c0b0ae91a397f00e5a05a2"}, ] +six = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] toml = [ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, diff --git a/pyproject.toml b/pyproject.toml index 2fab746..206bc5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,11 +10,16 @@ python = "^3.8" redis = "^3.5.3" aioredis = "^2.0.0" pydantic = "^1.8.2" +click = "^8.0.1" +six = "^1.16.0" [tool.poetry.dev-dependencies] pytest = "^6.2.4" ipdb = "^0.13.9" +[tool.poetry.scripts] +migrate = "redis_developer.orm.cli.migrate:migrate" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/redis_developer/orm/connections.py b/redis_developer/connections.py similarity index 55% rename from redis_developer/orm/connections.py rename to redis_developer/connections.py index 0124d33..f8ba77f 100644 --- a/redis_developer/orm/connections.py +++ b/redis_developer/connections.py @@ -2,4 +2,4 @@ import redis def get_redis_connection() -> redis.Redis: - return redis.Redis() + return redis.Redis(decode_responses=True) diff --git a/redis_developer/orm/cli/__init__.py b/redis_developer/orm/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/redis_developer/orm/cli/migrate.py b/redis_developer/orm/cli/migrate.py new file mode 100644 index 0000000..324f9c5 --- /dev/null +++ b/redis_developer/orm/cli/migrate.py @@ -0,0 +1,16 @@ +import click +from redis_developer.orm.migrations.migrator import Migrator + + +@click.command() +@click.option("--module", default="redis_developer") +def migrate(module): + migrator = Migrator(module) + + if migrator.migrations: + print("Pending migrations:") + for migration in migrator.migrations: + print(migration) + + if input(f"Run migrations? (y/n) ") == "y": + migrator.run() diff --git a/redis_developer/orm/migrations/__init__.py b/redis_developer/orm/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/redis_developer/orm/migrations/migrator.py b/redis_developer/orm/migrations/migrator.py new file mode 100644 index 0000000..173bd33 --- /dev/null +++ b/redis_developer/orm/migrations/migrator.py @@ -0,0 +1,116 @@ +import hashlib +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from redis import ResponseError + +from redis_developer.connections import get_redis_connection +from redis_developer.orm.model import model_registry + +redis = get_redis_connection() +log = logging.getLogger(__name__) + + +import importlib +import pkgutil + + +def import_submodules(root_module_name: str): + """Import all submodules of a module, recursively.""" + # TODO: Call this without specifying a module name, to import everything? + root_module = importlib.import_module(root_module_name) + for loader, module_name, is_pkg in pkgutil.walk_packages( + root_module.__path__, root_module.__name__ + '.'): + importlib.import_module(module_name) + + +def schema_hash_key(index_name): + return f"{index_name}:hash" + + +def create_index(index_name, schema, current_hash): + redis.execute_command(f"ft.create {index_name} " + f"{schema}") + redis.set(schema_hash_key(index_name), current_hash) + + +class MigrationAction(Enum): + CREATE = 2 + DROP = 1 + + +@dataclass +class IndexMigration: + model_name: str + index_name: str + schema: str + hash: str + action: MigrationAction + previous_hash: Optional[str] = None + + def run(self): + if self.action is MigrationAction.CREATE: + self.create() + elif self.action is MigrationAction.DROP: + self.drop() + + def create(self): + return create_index(self.index_name, self.schema, self.hash) + + def drop(self): + redis.execute_command(f"FT.DROPINDEX {self.index_name}") + + +class Migrator: + def __init__(self, module=None): + # Try to load any modules found under the given path or module name. + if module: + import_submodules(module) + + self.migrations = [] + + for name, cls in model_registry.items(): + hash_key = schema_hash_key(cls.Meta.index_name) + try: + schema = cls.schema() + except NotImplementedError: + log.info("Skipping migrations for %s", name) + continue + current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() + + try: + redis.execute_command("ft.info", cls.Meta.index_name) + except ResponseError: + self.migrations.append( + IndexMigration(name, cls.Meta.index_name, schema, current_hash, + MigrationAction.CREATE)) + + stored_hash = redis.get(hash_key) + schema_out_of_date = current_hash != stored_hash + + if schema_out_of_date: + # TODO: Switch out schema with an alias to avoid downtime -- separate migration? + self.migrations.append( + IndexMigration(name, cls.Meta.index_name, schema, current_hash, + MigrationAction.DROP, stored_hash)) + self.migrations.append( + IndexMigration(name, cls.Meta.index_name, schema, current_hash, + MigrationAction.CREATE, stored_hash)) + + @property + def valid_migrations(self): + return self.missing_indexes.keys() + self.out_of_date_indexes.keys() + + def validate_migration(self, model_class_name): + if model_class_name not in self.valid_migrations: + migrations = ", ".join(self.valid_migrations) + raise RuntimeError(f"No migration found for {model_class_name}." + f"Valid migrations are: {migrations}") + + def run(self): + # TODO: Migration history + # TODO: Dry run with output + for migration in self.migrations: + migration.run() diff --git a/redis_developer/orm/model.py b/redis_developer/orm/model.py index b5501cb..cb89786 100644 --- a/redis_developer/orm/model.py +++ b/redis_developer/orm/model.py @@ -1,5 +1,11 @@ +import abc +import dataclasses +import decimal +import operator +from copy import copy from dataclasses import dataclass from enum import Enum +from functools import reduce from typing import ( AbstractSet, Any, @@ -22,11 +28,15 @@ import redis from pydantic import BaseModel, validator from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import ModelField, Undefined, UndefinedType +from pydantic.main import ModelMetaclass from pydantic.typing import NoArgAnyCallable from pydantic.utils import Representation from .encoders import jsonable_encoder + +model_registry = {} + _T = TypeVar("_T") @@ -38,17 +48,138 @@ class NotFoundError(Exception): pass -class Operations(Enum): +class Operators(Enum): EQ = 1 - LT = 2 - GT = 3 + NE = 2 + LT = 3 + LE = 4 + GT = 5 + GE = 6 + OR = 7 + AND = 8 + NOT = 9 + IN = 10 + NOT_IN = 11 + GTE = 12 + LTE = 13 + LIKE = 14 @dataclass class Expression: - field: ModelField - op: Operations - right_value: Any + op: Operators + left: Any + right: Any + + def __and__(self, other): + return Expression(left=self, op=Operators.AND, right=other) + + + +class QueryNotSupportedError(Exception): + """The attempted query is not supported.""" + + +class RediSearchFieldTypes(Enum): + TEXT = 'TEXT' + TAG = 'TAG' + NUMERIC = 'NUMERIC' + GEO = 'GEO' + + +# TODO: How to handle Geo fields? +NUMERIC_TYPES = (float, int, decimal.Decimal) + + +@dataclass +class FindQuery: + expressions: Sequence[Expression] + expression: Expression = dataclasses.field(init=False) + query: str = dataclasses.field(init=False) + model: Type['RedisModel'] + + def __post_init__(self): + self.expression = reduce(operator.and_, self.expressions) + self.query = self.resolve_redisearch_query(self.expression) + + def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes: + if getattr(field.field_info, 'primary_key', None): + return RediSearchFieldTypes.TAG + field_type = field.outer_type_ + + # TODO: GEO + # TODO: TAG (other than PK) + if any(isinstance(field_type, t) for t in NUMERIC_TYPES): + return RediSearchFieldTypes.NUMERIC + else: + return RediSearchFieldTypes.TEXT + + def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes, + op: Operators, value: Any) -> str: + result = "" + if field_type is RediSearchFieldTypes.TEXT: + result = f"@{field_name}:" + if op is Operators.EQ: + result += f'"{value}"' + elif op is Operators.LIKE: + result += value + else: + raise QueryNotSupportedError("Only equals (=) comparisons are currently supported " + "for TEXT fields. See docs: TODO") + elif field_type is RediSearchFieldTypes.NUMERIC: + if op is Operators.EQ: + result += f"@{field_name}:[{value} {value}]" + elif op is Operators.NE: + # TODO: Is this enough or do we also need a clause for all values ([-inf +inf])? + result += f"~(@{field_name}:[{value} {value}])" + elif op is Operators.GT: + result += f"@{field_name}:[({value} +inf]" + elif op is Operators.LT: + result += f"@{field_name}:[-inf ({value}]" + elif op is Operators.GTE: + result += f"@{field_name}:[{value} +inf]" + elif op is Operators.LTE: + result += f"@{field_name}:[-inf {value}]" + + return result + + def resolve_redisearch_query(self, expression: Expression): + """Resolve an expression to a string RediSearch query.""" + field_type = None + field_name = None + result = "" + if isinstance(expression.left, Expression): + result += f"({self.resolve_redisearch_query(expression.left)})" + elif isinstance(expression.left, ModelField): + field_type = self.resolve_field_type(expression.left) + field_name = expression.left.name + else: + raise QueryNotSupportedError(f"A query expression should start with either a field" + f"or an expression enclosed in parenthesis. See docs: " + f"TODO") + + if isinstance(expression.right, Expression): + if expression.op == Operators.AND: + result += " (" + elif expression.op == Operators.OR: + result += "| (" + elif expression.op == Operators.NOT: + result += " ~(" + else: + raise QueryNotSupportedError("You can only combine two query expressions with" + "AND (&), OR (|), or NOT (~). See docs: TODO") + result += f"{self.resolve_redisearch_query(expression.right)})" # NOTE: We add the closing paren + else: + if isinstance(expression.right, ModelField): + raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO") + else: + result += f"({self.resolve_value(field_name, field_type, expression.op, expression.right)})" + + return result + + def find(self): + return self.model.db().execute_command("ft.search", self.model.Meta.index_name, + self.query) class PrimaryKeyCreator(Protocol): @@ -66,13 +197,22 @@ class ExpressionProxy: self.field = field def __eq__(self, other: Any) -> Expression: - return Expression(field=self.field, op=Operations.EQ, right_value=other) + return Expression(left=self.field, op=Operators.EQ, right=other) + + def __ne__(self, other: Any) -> Expression: + return Expression(left=self.field, op=Operators.NE, right=other) def __lt__(self, other: Any) -> Expression: - return Expression(field=self.field, op=Operations.LT, right_value=other) + return Expression(left=self.field, op=Operators.LT, right=other) + + def __le__(self, other: Any) -> Expression: + return Expression(left=self.field, op=Operators.LE, right=other) def __gt__(self, other: Any) -> Expression: - return Expression(field=self.field, op=Operations.GT, right_value=other) + return Expression(left=self.field, op=Operators.GT, right=other) + + def __ge__(self, other: Any) -> Expression: + return Expression(left=self.field, op=Operators.GE, right=other) def __dataclass_transform__( @@ -88,13 +228,13 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) - nullable = kwargs.pop("nullable", Undefined) + sortable = kwargs.pop("sortable", Undefined) foreign_key = kwargs.pop("foreign_key", Undefined) index = kwargs.pop("index", Undefined) unique = kwargs.pop("unique", Undefined) super().__init__(default=default, **kwargs) self.primary_key = primary_key - self.nullable = nullable + self.sortable = sortable self.foreign_key = foreign_key self.index = index self.unique = unique @@ -139,7 +279,7 @@ def Field( primary_key: bool = False, unique: bool = False, foreign_key: Optional[Any] = None, - nullable: Union[bool, UndefinedType] = Undefined, + sortable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: @@ -167,7 +307,7 @@ def Field( primary_key=primary_key, unique=unique, foreign_key=foreign_key, - nullable=nullable, + sortable=sortable, index=index, **current_schema_extra, ) @@ -188,48 +328,83 @@ class DefaultMeta: database: Optional[redis.Redis] = None primary_key: Optional[PrimaryKey] = None primary_key_creator_cls: Type[PrimaryKeyCreator] = None + index_name: str = None + abstract: bool = False -class RedisModel(BaseModel): +class ModelMeta(ModelMetaclass): + def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 + meta = attrs.pop('Meta', None) + new_class = super().__new__(cls, name, bases, attrs, **kwargs) + + meta = meta or getattr(new_class, 'Meta', None) + base_meta = getattr(new_class, '_meta', None) + + if meta and meta is not DefaultMeta: + new_class.Meta = meta + new_class._meta = meta + elif base_meta: + new_class._meta = copy(base_meta) + new_class.Meta = new_class._meta + # Unset inherited values we don't want to reuse (typically based on the model name). + new_class._meta.abstract = False + new_class._meta.model_key_prefix = None + new_class._meta.index_name = None + else: + new_class._meta = copy(DefaultMeta) + new_class.Meta = new_class._meta + + # Not an abstract model class + if abc.ABC not in bases: + key = f"{new_class.__module__}.{new_class.__qualname__}" + model_registry[key] = new_class + + # Create proxies for each model field so that we can use the field + # in queries, like Model.get(Model.field_name == 1) + for name, field in new_class.__fields__.items(): + setattr(new_class, name, ExpressionProxy(field)) + # Check if this is our FieldInfo version with extended ORM metadata. + if isinstance(field.field_info, FieldInfo): + if field.field_info.primary_key: + new_class._meta.primary_key = PrimaryKey(name=name, field=field) + + # TODO: Raise exception here, global key prefix required? + if not getattr(new_class._meta, 'global_key_prefix', None): + new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "") + if not getattr(new_class._meta, 'model_key_prefix', None): + # Don't look at the base class for this. + new_class._meta.model_key_prefix = f"{new_class.__name__.lower()}" + if not getattr(new_class._meta, 'primary_key_pattern', None): + new_class._meta.primary_key_pattern = getattr(base_meta, "primary_key_pattern", + "{pk}") + if not getattr(new_class._meta, 'database', None): + new_class._meta.database = getattr(base_meta, "database", + redis.Redis(decode_responses=True)) + if not getattr(new_class._meta, 'primary_key_creator_cls', None): + new_class._meta.primary_key_creator_cls = getattr(base_meta, "primary_key_creator_cls", + Uuid4PrimaryKey) + if not getattr(new_class._meta, 'index_name', None): + new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \ + f"{new_class._meta.model_key_prefix}:index" + + return new_class + + +class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): """ TODO: Convert expressions to Redis commands, execute TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern) - TODO: Default key prefix is model name lowercase - TODO: Build primary key pattern from PK field name, model prefix - TODO: Default PK pattern is model name:pk field + TODO: Generate RediSearch schema """ pk: Optional[str] = Field(default=None, primary_key=True) + Meta = DefaultMeta + class Config: orm_mode = True arbitrary_types_allowed = True extra = 'allow' - Meta = DefaultMeta - - def __init_subclass__(cls, **kwargs): - # Create proxies for each model field so that we can use the field - # in queries, like Model.get(Model.field_name == 1) - super().__init_subclass__(**kwargs) - - for name, field in cls.__fields__.items(): - setattr(cls, name, ExpressionProxy(field)) - # Check if this is our FieldInfo version with extended ORM metadata. - if isinstance(field.field_info, FieldInfo): - if field.field_info.primary_key: - cls.Meta.primary_key = PrimaryKey(name=name, field=field) - # TODO: Raise exception here, global key prefix required? - if not getattr(cls.Meta, 'global_key_prefix'): - cls.Meta.global_key_prefix = "" - if not getattr(cls.Meta, 'model_key_prefix'): - cls.Meta.model_key_prefix = f"{cls.__name__.lower()}" - if not getattr(cls.Meta, 'primary_key_pattern'): - cls.Meta.primary_key_pattern = "{pk}" - if not getattr(cls.Meta, 'database'): - cls.Meta.database = redis.Redis(decode_responses=True) - if not getattr(cls.Meta, 'primary_key_creator_cls'): - cls.Meta.primary_key_creator_cls = Uuid4PrimaryKey - def __init__(__pydantic_self__, **data: Any) -> None: super().__init__(**data) __pydantic_self__.validate_primary_key() @@ -237,7 +412,7 @@ class RedisModel(BaseModel): @validator("pk", always=True) def validate_pk(cls, v): if not v: - v = cls.Meta.primary_key_creator_cls().create_pk() + v = cls._meta.primary_key_creator_cls().create_pk() return v @classmethod @@ -254,30 +429,66 @@ class RedisModel(BaseModel): @classmethod def make_key(cls, part: str): - global_prefix = getattr(cls.Meta, 'global_key_prefix', '').strip(":") - model_prefix = getattr(cls.Meta, 'model_key_prefix', '').strip(":") + global_prefix = getattr(cls._meta, 'global_key_prefix', '').strip(":") + model_prefix = getattr(cls._meta, 'model_key_prefix', '').strip(":") return f"{global_prefix}:{model_prefix}:{part}" @classmethod def make_primary_key(cls, pk: Any): """Return the Redis key for this model.""" - return cls.make_key(cls.Meta.primary_key_pattern.format(pk=pk)) + return cls.make_key(cls._meta.primary_key_pattern.format(pk=pk)) def key(self): """Return the Redis key for this model.""" - pk = getattr(self, self.Meta.primary_key.field.name) + pk = getattr(self, self._meta.primary_key.field.name) return self.make_primary_key(pk) @classmethod def db(cls): - return cls.Meta.database + return cls._meta.database + + @classmethod + def from_redis(cls, res: Any): + import six + from six.moves import xrange, zip as izip + + def to_string(s): + if isinstance(s, six.string_types): + return s + elif isinstance(s, six.binary_type): + return s.decode('utf-8','ignore') + else: + return s # Not a string we care about + + docs = [] + step = 2 # Because the result has content + offset = 1 + + for i in xrange(1, len(res), step): + fields_offset = offset + + fields = dict( + dict(izip(map(to_string, res[i + fields_offset][::2]), + map(to_string, res[i + fields_offset][1::2]))) + ) + + try: + del fields['id'] + except KeyError: + pass + + doc = cls(**fields) + docs.append(doc) + return docs @classmethod - def filter(cls, *expressions: Sequence[Expression]): - return cls + def find(cls, *expressions: Expression): + query = FindQuery(expressions=expressions, model=cls) + raw_result = query.find() + return cls.from_redis(raw_result) @classmethod - def exclude(cls, *expressions: Sequence[Expression]): + def find_one(cls, *expressions: Sequence[Expression]): return cls @classmethod @@ -296,16 +507,15 @@ class RedisModel(BaseModel): def delete(self): return self.db().delete(self.key()) - # TODO: Protocol - @classmethod - def get(cls, pk: Any): - raise NotImplementedError - def save(self, *args, **kwargs) -> 'RedisModel': raise NotImplementedError + @classmethod + def schema(cls): + raise NotImplementedError -class HashModel(RedisModel): + +class HashModel(RedisModel, abc.ABC): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -346,8 +556,34 @@ class HashModel(RedisModel): return "" return val + @classmethod + def schema_for_type(cls, name, typ: Type): + if any(issubclass(typ, t) for t in NUMERIC_TYPES): + return f"{name} NUMERIC" + else: + return f"{name} TEXT" + + @classmethod + def schema(cls): + hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk="")) + schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA" + schema_parts = [schema_prefix] + for name, field in cls.__fields__.items(): + _type = field.outer_type_ + if getattr(field.field_info, 'primary_key', None): + if issubclass(_type, str): + redisearch_field = f"{name} TAG" + else: + redisearch_field = cls.schema_for_type(name, _type) + schema_parts.append(redisearch_field) + else: + schema_parts.append(cls.schema_for_type(name, _type)) + if getattr(field.field_info, 'sortable', False): + schema_parts.append("SORTABLE") + return " ".join(schema_parts) -class JsonModel(RedisModel): + +class JsonModel(RedisModel, abc.ABC): def save(self, *args, **kwargs) -> 'JsonModel': success = self.db().execute_command('JSON.SET', self.key(), ".", self.json()) return success diff --git a/redis_developer/orm/models.py b/redis_developer/orm/models.py new file mode 100644 index 0000000..664a7bc --- /dev/null +++ b/redis_developer/orm/models.py @@ -0,0 +1,30 @@ +import abc +from typing import Optional + +from redis_developer.orm.model import JsonModel, HashModel + + +class BaseJsonModel(JsonModel, abc.ABC): + class Meta: + global_key_prefix = "redis-developer" + + +class BaseHashModel(HashModel, abc.ABC): + class Meta: + global_key_prefix = "redis-developer" + + +# class AddressJson(BaseJsonModel): +# address_line_1: str +# address_line_2: Optional[str] +# city: str +# country: str +# postal_code: str +# + +class AddressHash(BaseHashModel): + address_line_1: str + address_line_2: Optional[str] + city: str + country: str + postal_code: str diff --git a/redis_developer/orm/query_iterator.py b/redis_developer/orm/query_iterator.py new file mode 100644 index 0000000..698b743 --- /dev/null +++ b/redis_developer/orm/query_iterator.py @@ -0,0 +1,55 @@ +from redis_developer.orm.model import Expression + + +class QueryIterator: + """ + A lazy iterator that yields results from a RediSearch query. + + Examples: + + results = Model.filter(email == "a@example.com") + + # Consume all results. + for r in results: + print(r) + + # Consume an item at an index. + print(results[100]) + + # Consume a slice. + print(results[0:100]) + + # Alternative notation to consume all items. + print(results[0:-1]) + + # Specify the batch size: + results = Model.filter(email == "a@example.com", batch_size=1000) + ... + """ + def __init__(self, client, query, batch_size=100): + self.client = client + self.query = query + self.batch_size = batch_size + + def __iter__(self): + pass + + def __getattr__(self, item): + """Support getting a single value or a slice.""" + + # TODO: Query mixin? + + def filter(self, *expressions: Expression): + pass + + def exclude(self, *expressions: Expression): + pass + + def and_(self, *expressions: Expression): + pass + + def or_(self, *expressions: Expression): + pass + + def not_(self, *expressions: Expression): + pass diff --git a/redis_developer/orm/query_resolver.py b/redis_developer/orm/query_resolver.py new file mode 100644 index 0000000..18e1b2b --- /dev/null +++ b/redis_developer/orm/query_resolver.py @@ -0,0 +1,106 @@ +from collections import Sequence +from typing import Any, Dict, Mapping, Union, List + +from redis_developer.orm.model import Expression + + +class LogicalOperatorForListOfExpressions(Expression): + operator: str = "" + + def __init__(self, *expressions: Expression): + self.expressions = list(expressions) + + @property + def query(self) -> Mapping[str, List[Expression]]: + if not self.expressions: + raise AttributeError("At least one expression must be provided") + # TODO: This needs to return a RediSearch string. + # Use the values in each expression object to build the string. + # Determine the type of query based on the field (numeric range, + # tag field, etc.). + return {self.operator: self.expressions} + + +class Or(LogicalOperatorForListOfExpressions): + """ + Logical OR query operator + + Example: + + ```python + class Product(JsonModel): + price: float + category: str + + Or(Product.price < 10, Product.category == "Sweets") + ``` + + Will return RediSearch query string like: + + ``` + (@price:[-inf 10]) | (@category:{Sweets}) + ``` + """ + + operator = "|" + + +class And(LogicalOperatorForListOfExpressions): + """ + Logical AND query operator + + Example: + + ```python + class Product(Document): + price: float + category: str + + And(Product.price < 10, Product.category == "Sweets") + ``` + + Will return a query string like: + + ``` + (@price:[-inf 10]) (@category:{Sweets}) + ``` + + Note that in RediSearch, AND is implied with multiple terms. + """ + + operator = " " + + +class Not(LogicalOperatorForListOfExpressions): + """ + Logical NOT query operator + + Example: + + ```python + class Product(Document): + price: float + category: str + + Not(Product.price<10, Product.category=="Sweets") + ``` + + Will return a query string like: + + ``` + -(@price:[-inf 10]) -(@category:{Sweets}) + ``` + """ + @property + def query(self): + return "-(expression1) -(expression2)" + + +class QueryResolver: + def __init__(self, *expressions: Expression): + self.expressions = expressions + + def resolve(self) -> str: + """Resolve expressions to a RediSearch query string.""" + + diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..237f174 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,31 @@ +import pytest +from redis import Redis + +from redis_developer.connections import get_redis_connection +from redis_developer.orm.migrations.migrator import Migrator + + +@pytest.fixture(scope="module", autouse=True) +def migrations(): + Migrator().run() + + +@pytest.fixture +def redis(): + yield get_redis_connection() + + +@pytest.fixture +def key_prefix(): + # TODO + yield "redis-developer" + + +def _delete_test_keys(prefix: str, conn: Redis): + for key in conn.scan_iter(f"{prefix}:*"): + conn.delete(key) + + +@pytest.fixture(scope="function", autouse=True) +def delete_test_keys(redis, request, key_prefix): + _delete_test_keys(key_prefix, redis) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index c6363ce..fbb7ad2 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -1,6 +1,7 @@ +import abc import decimal import datetime -from typing import Optional, List +from typing import Optional import pytest import redis @@ -16,19 +17,11 @@ r = redis.Redis() today = datetime.date.today() -class BaseHashModel(HashModel): - class Meta(HashModel.Meta): +class BaseHashModel(HashModel, abc.ABC): + class Meta: global_key_prefix = "redis-developer" -class Address(BaseHashModel): - address_line_1: str - address_line_2: Optional[str] - city: str - country: str - postal_code: str - - class Order(BaseHashModel): total: decimal.Decimal currency: str @@ -41,7 +34,7 @@ class Member(BaseHashModel): email: str = Field(unique=True, index=True) join_date: datetime.date - class Meta(BaseHashModel.Meta): + class Meta: model_key_prefix = "member" primary_key_pattern = "" @@ -68,13 +61,6 @@ def test_validates_field(): # Passes validation def test_validation_passes(): - address = Address( - address_line_1="1 Main St.", - city="Happy Town", - state="WY", - postal_code=11111, - country="USA" - ) member = Member( first_name="Andrew", last_name="Brookins", @@ -99,6 +85,13 @@ def test_saves_model_and_creates_pk(): def test_raises_error_with_embedded_models(): + class Address(BaseHashModel): + address_line_1: str + address_line_2: Optional[str] + city: str + country: str + postal_code: str + with pytest.raises(RedisModelError): class InvalidMember(BaseHashModel): address: Address @@ -142,3 +135,51 @@ def test_updates_a_model(): # Or, affecting multiple model instances with an implicit save: Member.filter(Member.last_name == "Brookins").update(last_name="Sam-Bodden") + + +def test_exact_match_queries(): + member1 = Member( + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today + ) + + member2 = Member( + first_name="Kim", + last_name="Brookins", + email="k@example.com", + join_date=today + ) + member1.save() + member2.save() + + actual = Member.find(Member.last_name == "Brookins") + assert actual == [member2, member1] + + + # actual = Member.find( + # (Member.last_name == "Brookins") & (~Member.first_name == "Andrew")) + # assert actual == [member2] + + # actual = Member.find(~Member.last_name == "Brookins") + # assert actual == [] + + # actual = Member.find( + # (Member.last_name == "Brookins") & (Member.first_name == "Andrew") + # | (Member.first_name == "Kim") + # ) + # assert actual == [member1, member2] + + # actual = Member.find_one(Member.last_name == "Brookins") + # assert actual == member1 + + +def test_schema(): + class Address(BaseHashModel): + a_string: str + an_integer: int + a_float: float + + assert Address.schema() == "SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \ + "a_float NUMERIC" diff --git a/tests/test_json_model.py b/tests/test_json_model.py index ef34383..98b24aa 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1,3 +1,4 @@ +import abc import decimal import datetime from typing import Optional, List @@ -15,8 +16,8 @@ r = redis.Redis() today = datetime.datetime.today() -class BaseJsonModel(JsonModel): - class Meta(JsonModel.Meta): +class BaseJsonModel(JsonModel, abc.ABC): + class Meta: global_key_prefix = "redis-developer"