Add basic migrations, query expression resolver
This commit is contained in:
parent
afe05fb7dd
commit
0990c2e1b4
15 changed files with 752 additions and 88 deletions
43
poetry.lock
generated
43
poetry.lock
generated
|
@ -59,17 +59,28 @@ category = "dev"
|
|||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "click"
|
||||
version = "8.0.1"
|
||||
description = "Composable command line interface toolkit"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.4"
|
||||
description = "Cross-platform colored terminal text."
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
|
||||
[[package]]
|
||||
name = "decorator"
|
||||
version = "5.0.9"
|
||||
version = "5.1.0"
|
||||
description = "Decorators for Humans"
|
||||
category = "dev"
|
||||
optional = false
|
||||
|
@ -145,7 +156,7 @@ testing = ["Django (<3.1)", "colorama", "docopt", "pytest (<6.0.0)"]
|
|||
|
||||
[[package]]
|
||||
name = "matplotlib-inline"
|
||||
version = "0.1.2"
|
||||
version = "0.1.3"
|
||||
description = "Inline Matplotlib backend for Jupyter"
|
||||
category = "dev"
|
||||
optional = false
|
||||
|
@ -298,6 +309,14 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
|||
[package.extras]
|
||||
hiredis = ["hiredis (>=0.1.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "six"
|
||||
version = "1.16.0"
|
||||
description = "Python 2 and 3 compatibility utilities"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.10.2"
|
||||
|
@ -336,7 +355,7 @@ python-versions = "*"
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "e5ac777000236190a585bef489a7fbe744a0d0dc328c001cb198a34746bd24a3"
|
||||
content-hash = "b3f0c7c5701bb2c317df7f2f42218ef6c38d9b035ed49c9a9469df6a6727973c"
|
||||
|
||||
[metadata.files]
|
||||
aioredis = [
|
||||
|
@ -363,13 +382,17 @@ backcall = [
|
|||
{file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"},
|
||||
{file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"},
|
||||
]
|
||||
click = [
|
||||
{file = "click-8.0.1-py3-none-any.whl", hash = "sha256:fba402a4a47334742d782209a7c79bc448911afe1149d07bdabdf480b3e2f4b6"},
|
||||
{file = "click-8.0.1.tar.gz", hash = "sha256:8c04c11192119b1ef78ea049e0a6f0463e4c48ef00a30160c704337586f3ad7a"},
|
||||
]
|
||||
colorama = [
|
||||
{file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
|
||||
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
|
||||
]
|
||||
decorator = [
|
||||
{file = "decorator-5.0.9-py3-none-any.whl", hash = "sha256:6e5c199c16f7a9f0e3a61a4a54b3d27e7dad0dbdde92b944426cb20914376323"},
|
||||
{file = "decorator-5.0.9.tar.gz", hash = "sha256:72ecfba4320a893c53f9706bebb2d55c270c1e51a28789361aa93e4a21319ed5"},
|
||||
{file = "decorator-5.1.0-py3-none-any.whl", hash = "sha256:7b12e7c3c6ab203a29e157335e9122cb03de9ab7264b137594103fd4a683b374"},
|
||||
{file = "decorator-5.1.0.tar.gz", hash = "sha256:e59913af105b9860aa2c8d3272d9de5a56a4e608db9a2f167a8480b323d529a7"},
|
||||
]
|
||||
iniconfig = [
|
||||
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
|
||||
|
@ -387,8 +410,8 @@ jedi = [
|
|||
{file = "jedi-0.18.0.tar.gz", hash = "sha256:92550a404bad8afed881a137ec9a461fed49eca661414be45059329614ed0707"},
|
||||
]
|
||||
matplotlib-inline = [
|
||||
{file = "matplotlib-inline-0.1.2.tar.gz", hash = "sha256:f41d5ff73c9f5385775d5c0bc13b424535c8402fe70ea8210f93e11f3683993e"},
|
||||
{file = "matplotlib_inline-0.1.2-py3-none-any.whl", hash = "sha256:5cf1176f554abb4fa98cb362aa2b55c500147e4bdbb07e3fda359143e1da0811"},
|
||||
{file = "matplotlib-inline-0.1.3.tar.gz", hash = "sha256:a04bfba22e0d1395479f866853ec1ee28eea1485c1d69a6faf00dc3e24ff34ee"},
|
||||
{file = "matplotlib_inline-0.1.3-py3-none-any.whl", hash = "sha256:aed605ba3b72462d64d475a21a9296f400a19c4f74a31b59103d2a99ffd5aa5c"},
|
||||
]
|
||||
packaging = [
|
||||
{file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"},
|
||||
|
@ -462,6 +485,10 @@ redis = [
|
|||
{file = "redis-3.5.3-py2.py3-none-any.whl", hash = "sha256:432b788c4530cfe16d8d943a09d40ca6c16149727e4afe8c2c9d5580c59d9f24"},
|
||||
{file = "redis-3.5.3.tar.gz", hash = "sha256:0e7e0cfca8660dea8b7d5cd8c4f6c5e29e11f31158c0b0ae91a397f00e5a05a2"},
|
||||
]
|
||||
six = [
|
||||
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
|
||||
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
|
||||
]
|
||||
toml = [
|
||||
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
|
||||
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
|
||||
|
|
|
@ -10,11 +10,16 @@ python = "^3.8"
|
|||
redis = "^3.5.3"
|
||||
aioredis = "^2.0.0"
|
||||
pydantic = "^1.8.2"
|
||||
click = "^8.0.1"
|
||||
six = "^1.16.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^6.2.4"
|
||||
ipdb = "^0.13.9"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
migrate = "redis_developer.orm.cli.migrate:migrate"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
|
|
@ -2,4 +2,4 @@ import redis
|
|||
|
||||
|
||||
def get_redis_connection() -> redis.Redis:
|
||||
return redis.Redis()
|
||||
return redis.Redis(decode_responses=True)
|
0
redis_developer/orm/cli/__init__.py
Normal file
0
redis_developer/orm/cli/__init__.py
Normal file
16
redis_developer/orm/cli/migrate.py
Normal file
16
redis_developer/orm/cli/migrate.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import click
|
||||
from redis_developer.orm.migrations.migrator import Migrator
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--module", default="redis_developer")
|
||||
def migrate(module):
|
||||
migrator = Migrator(module)
|
||||
|
||||
if migrator.migrations:
|
||||
print("Pending migrations:")
|
||||
for migration in migrator.migrations:
|
||||
print(migration)
|
||||
|
||||
if input(f"Run migrations? (y/n) ") == "y":
|
||||
migrator.run()
|
0
redis_developer/orm/migrations/__init__.py
Normal file
0
redis_developer/orm/migrations/__init__.py
Normal file
116
redis_developer/orm/migrations/migrator.py
Normal file
116
redis_developer/orm/migrations/migrator.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import hashlib
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from redis import ResponseError
|
||||
|
||||
from redis_developer.connections import get_redis_connection
|
||||
from redis_developer.orm.model import model_registry
|
||||
|
||||
redis = get_redis_connection()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
|
||||
|
||||
def import_submodules(root_module_name: str):
|
||||
"""Import all submodules of a module, recursively."""
|
||||
# TODO: Call this without specifying a module name, to import everything?
|
||||
root_module = importlib.import_module(root_module_name)
|
||||
for loader, module_name, is_pkg in pkgutil.walk_packages(
|
||||
root_module.__path__, root_module.__name__ + '.'):
|
||||
importlib.import_module(module_name)
|
||||
|
||||
|
||||
def schema_hash_key(index_name):
|
||||
return f"{index_name}:hash"
|
||||
|
||||
|
||||
def create_index(index_name, schema, current_hash):
|
||||
redis.execute_command(f"ft.create {index_name} "
|
||||
f"{schema}")
|
||||
redis.set(schema_hash_key(index_name), current_hash)
|
||||
|
||||
|
||||
class MigrationAction(Enum):
|
||||
CREATE = 2
|
||||
DROP = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexMigration:
|
||||
model_name: str
|
||||
index_name: str
|
||||
schema: str
|
||||
hash: str
|
||||
action: MigrationAction
|
||||
previous_hash: Optional[str] = None
|
||||
|
||||
def run(self):
|
||||
if self.action is MigrationAction.CREATE:
|
||||
self.create()
|
||||
elif self.action is MigrationAction.DROP:
|
||||
self.drop()
|
||||
|
||||
def create(self):
|
||||
return create_index(self.index_name, self.schema, self.hash)
|
||||
|
||||
def drop(self):
|
||||
redis.execute_command(f"FT.DROPINDEX {self.index_name}")
|
||||
|
||||
|
||||
class Migrator:
|
||||
def __init__(self, module=None):
|
||||
# Try to load any modules found under the given path or module name.
|
||||
if module:
|
||||
import_submodules(module)
|
||||
|
||||
self.migrations = []
|
||||
|
||||
for name, cls in model_registry.items():
|
||||
hash_key = schema_hash_key(cls.Meta.index_name)
|
||||
try:
|
||||
schema = cls.schema()
|
||||
except NotImplementedError:
|
||||
log.info("Skipping migrations for %s", name)
|
||||
continue
|
||||
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest()
|
||||
|
||||
try:
|
||||
redis.execute_command("ft.info", cls.Meta.index_name)
|
||||
except ResponseError:
|
||||
self.migrations.append(
|
||||
IndexMigration(name, cls.Meta.index_name, schema, current_hash,
|
||||
MigrationAction.CREATE))
|
||||
|
||||
stored_hash = redis.get(hash_key)
|
||||
schema_out_of_date = current_hash != stored_hash
|
||||
|
||||
if schema_out_of_date:
|
||||
# TODO: Switch out schema with an alias to avoid downtime -- separate migration?
|
||||
self.migrations.append(
|
||||
IndexMigration(name, cls.Meta.index_name, schema, current_hash,
|
||||
MigrationAction.DROP, stored_hash))
|
||||
self.migrations.append(
|
||||
IndexMigration(name, cls.Meta.index_name, schema, current_hash,
|
||||
MigrationAction.CREATE, stored_hash))
|
||||
|
||||
@property
|
||||
def valid_migrations(self):
|
||||
return self.missing_indexes.keys() + self.out_of_date_indexes.keys()
|
||||
|
||||
def validate_migration(self, model_class_name):
|
||||
if model_class_name not in self.valid_migrations:
|
||||
migrations = ", ".join(self.valid_migrations)
|
||||
raise RuntimeError(f"No migration found for {model_class_name}."
|
||||
f"Valid migrations are: {migrations}")
|
||||
|
||||
def run(self):
|
||||
# TODO: Migration history
|
||||
# TODO: Dry run with output
|
||||
for migration in self.migrations:
|
||||
migration.run()
|
|
@ -1,5 +1,11 @@
|
|||
import abc
|
||||
import dataclasses
|
||||
import decimal
|
||||
import operator
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
|
@ -22,11 +28,15 @@ import redis
|
|||
from pydantic import BaseModel, validator
|
||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||
from pydantic.fields import ModelField, Undefined, UndefinedType
|
||||
from pydantic.main import ModelMetaclass
|
||||
from pydantic.typing import NoArgAnyCallable
|
||||
from pydantic.utils import Representation
|
||||
|
||||
from .encoders import jsonable_encoder
|
||||
|
||||
|
||||
model_registry = {}
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
|
@ -38,17 +48,138 @@ class NotFoundError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class Operations(Enum):
|
||||
class Operators(Enum):
|
||||
EQ = 1
|
||||
LT = 2
|
||||
GT = 3
|
||||
NE = 2
|
||||
LT = 3
|
||||
LE = 4
|
||||
GT = 5
|
||||
GE = 6
|
||||
OR = 7
|
||||
AND = 8
|
||||
NOT = 9
|
||||
IN = 10
|
||||
NOT_IN = 11
|
||||
GTE = 12
|
||||
LTE = 13
|
||||
LIKE = 14
|
||||
|
||||
|
||||
@dataclass
|
||||
class Expression:
|
||||
field: ModelField
|
||||
op: Operations
|
||||
right_value: Any
|
||||
op: Operators
|
||||
left: Any
|
||||
right: Any
|
||||
|
||||
def __and__(self, other):
|
||||
return Expression(left=self, op=Operators.AND, right=other)
|
||||
|
||||
|
||||
|
||||
class QueryNotSupportedError(Exception):
|
||||
"""The attempted query is not supported."""
|
||||
|
||||
|
||||
class RediSearchFieldTypes(Enum):
|
||||
TEXT = 'TEXT'
|
||||
TAG = 'TAG'
|
||||
NUMERIC = 'NUMERIC'
|
||||
GEO = 'GEO'
|
||||
|
||||
|
||||
# TODO: How to handle Geo fields?
|
||||
NUMERIC_TYPES = (float, int, decimal.Decimal)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FindQuery:
|
||||
expressions: Sequence[Expression]
|
||||
expression: Expression = dataclasses.field(init=False)
|
||||
query: str = dataclasses.field(init=False)
|
||||
model: Type['RedisModel']
|
||||
|
||||
def __post_init__(self):
|
||||
self.expression = reduce(operator.and_, self.expressions)
|
||||
self.query = self.resolve_redisearch_query(self.expression)
|
||||
|
||||
def resolve_field_type(self, field: ModelField) -> RediSearchFieldTypes:
|
||||
if getattr(field.field_info, 'primary_key', None):
|
||||
return RediSearchFieldTypes.TAG
|
||||
field_type = field.outer_type_
|
||||
|
||||
# TODO: GEO
|
||||
# TODO: TAG (other than PK)
|
||||
if any(isinstance(field_type, t) for t in NUMERIC_TYPES):
|
||||
return RediSearchFieldTypes.NUMERIC
|
||||
else:
|
||||
return RediSearchFieldTypes.TEXT
|
||||
|
||||
def resolve_value(self, field_name: str, field_type: RediSearchFieldTypes,
|
||||
op: Operators, value: Any) -> str:
|
||||
result = ""
|
||||
if field_type is RediSearchFieldTypes.TEXT:
|
||||
result = f"@{field_name}:"
|
||||
if op is Operators.EQ:
|
||||
result += f'"{value}"'
|
||||
elif op is Operators.LIKE:
|
||||
result += value
|
||||
else:
|
||||
raise QueryNotSupportedError("Only equals (=) comparisons are currently supported "
|
||||
"for TEXT fields. See docs: TODO")
|
||||
elif field_type is RediSearchFieldTypes.NUMERIC:
|
||||
if op is Operators.EQ:
|
||||
result += f"@{field_name}:[{value} {value}]"
|
||||
elif op is Operators.NE:
|
||||
# TODO: Is this enough or do we also need a clause for all values ([-inf +inf])?
|
||||
result += f"~(@{field_name}:[{value} {value}])"
|
||||
elif op is Operators.GT:
|
||||
result += f"@{field_name}:[({value} +inf]"
|
||||
elif op is Operators.LT:
|
||||
result += f"@{field_name}:[-inf ({value}]"
|
||||
elif op is Operators.GTE:
|
||||
result += f"@{field_name}:[{value} +inf]"
|
||||
elif op is Operators.LTE:
|
||||
result += f"@{field_name}:[-inf {value}]"
|
||||
|
||||
return result
|
||||
|
||||
def resolve_redisearch_query(self, expression: Expression):
|
||||
"""Resolve an expression to a string RediSearch query."""
|
||||
field_type = None
|
||||
field_name = None
|
||||
result = ""
|
||||
if isinstance(expression.left, Expression):
|
||||
result += f"({self.resolve_redisearch_query(expression.left)})"
|
||||
elif isinstance(expression.left, ModelField):
|
||||
field_type = self.resolve_field_type(expression.left)
|
||||
field_name = expression.left.name
|
||||
else:
|
||||
raise QueryNotSupportedError(f"A query expression should start with either a field"
|
||||
f"or an expression enclosed in parenthesis. See docs: "
|
||||
f"TODO")
|
||||
|
||||
if isinstance(expression.right, Expression):
|
||||
if expression.op == Operators.AND:
|
||||
result += " ("
|
||||
elif expression.op == Operators.OR:
|
||||
result += "| ("
|
||||
elif expression.op == Operators.NOT:
|
||||
result += " ~("
|
||||
else:
|
||||
raise QueryNotSupportedError("You can only combine two query expressions with"
|
||||
"AND (&), OR (|), or NOT (~). See docs: TODO")
|
||||
result += f"{self.resolve_redisearch_query(expression.right)})" # NOTE: We add the closing paren
|
||||
else:
|
||||
if isinstance(expression.right, ModelField):
|
||||
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
|
||||
else:
|
||||
result += f"({self.resolve_value(field_name, field_type, expression.op, expression.right)})"
|
||||
|
||||
return result
|
||||
|
||||
def find(self):
|
||||
return self.model.db().execute_command("ft.search", self.model.Meta.index_name,
|
||||
self.query)
|
||||
|
||||
|
||||
class PrimaryKeyCreator(Protocol):
|
||||
|
@ -66,13 +197,22 @@ class ExpressionProxy:
|
|||
self.field = field
|
||||
|
||||
def __eq__(self, other: Any) -> Expression:
|
||||
return Expression(field=self.field, op=Operations.EQ, right_value=other)
|
||||
return Expression(left=self.field, op=Operators.EQ, right=other)
|
||||
|
||||
def __ne__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.NE, right=other)
|
||||
|
||||
def __lt__(self, other: Any) -> Expression:
|
||||
return Expression(field=self.field, op=Operations.LT, right_value=other)
|
||||
return Expression(left=self.field, op=Operators.LT, right=other)
|
||||
|
||||
def __le__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.LE, right=other)
|
||||
|
||||
def __gt__(self, other: Any) -> Expression:
|
||||
return Expression(field=self.field, op=Operations.GT, right_value=other)
|
||||
return Expression(left=self.field, op=Operators.GT, right=other)
|
||||
|
||||
def __ge__(self, other: Any) -> Expression:
|
||||
return Expression(left=self.field, op=Operators.GE, right=other)
|
||||
|
||||
|
||||
def __dataclass_transform__(
|
||||
|
@ -88,13 +228,13 @@ def __dataclass_transform__(
|
|||
class FieldInfo(PydanticFieldInfo):
|
||||
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
|
||||
primary_key = kwargs.pop("primary_key", False)
|
||||
nullable = kwargs.pop("nullable", Undefined)
|
||||
sortable = kwargs.pop("sortable", Undefined)
|
||||
foreign_key = kwargs.pop("foreign_key", Undefined)
|
||||
index = kwargs.pop("index", Undefined)
|
||||
unique = kwargs.pop("unique", Undefined)
|
||||
super().__init__(default=default, **kwargs)
|
||||
self.primary_key = primary_key
|
||||
self.nullable = nullable
|
||||
self.sortable = sortable
|
||||
self.foreign_key = foreign_key
|
||||
self.index = index
|
||||
self.unique = unique
|
||||
|
@ -139,7 +279,7 @@ def Field(
|
|||
primary_key: bool = False,
|
||||
unique: bool = False,
|
||||
foreign_key: Optional[Any] = None,
|
||||
nullable: Union[bool, UndefinedType] = Undefined,
|
||||
sortable: Union[bool, UndefinedType] = Undefined,
|
||||
index: Union[bool, UndefinedType] = Undefined,
|
||||
schema_extra: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
|
@ -167,7 +307,7 @@ def Field(
|
|||
primary_key=primary_key,
|
||||
unique=unique,
|
||||
foreign_key=foreign_key,
|
||||
nullable=nullable,
|
||||
sortable=sortable,
|
||||
index=index,
|
||||
**current_schema_extra,
|
||||
)
|
||||
|
@ -188,48 +328,83 @@ class DefaultMeta:
|
|||
database: Optional[redis.Redis] = None
|
||||
primary_key: Optional[PrimaryKey] = None
|
||||
primary_key_creator_cls: Type[PrimaryKeyCreator] = None
|
||||
index_name: str = None
|
||||
abstract: bool = False
|
||||
|
||||
|
||||
class RedisModel(BaseModel):
|
||||
class ModelMeta(ModelMetaclass):
|
||||
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
|
||||
meta = attrs.pop('Meta', None)
|
||||
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
|
||||
meta = meta or getattr(new_class, 'Meta', None)
|
||||
base_meta = getattr(new_class, '_meta', None)
|
||||
|
||||
if meta and meta is not DefaultMeta:
|
||||
new_class.Meta = meta
|
||||
new_class._meta = meta
|
||||
elif base_meta:
|
||||
new_class._meta = copy(base_meta)
|
||||
new_class.Meta = new_class._meta
|
||||
# Unset inherited values we don't want to reuse (typically based on the model name).
|
||||
new_class._meta.abstract = False
|
||||
new_class._meta.model_key_prefix = None
|
||||
new_class._meta.index_name = None
|
||||
else:
|
||||
new_class._meta = copy(DefaultMeta)
|
||||
new_class.Meta = new_class._meta
|
||||
|
||||
# Not an abstract model class
|
||||
if abc.ABC not in bases:
|
||||
key = f"{new_class.__module__}.{new_class.__qualname__}"
|
||||
model_registry[key] = new_class
|
||||
|
||||
# Create proxies for each model field so that we can use the field
|
||||
# in queries, like Model.get(Model.field_name == 1)
|
||||
for name, field in new_class.__fields__.items():
|
||||
setattr(new_class, name, ExpressionProxy(field))
|
||||
# Check if this is our FieldInfo version with extended ORM metadata.
|
||||
if isinstance(field.field_info, FieldInfo):
|
||||
if field.field_info.primary_key:
|
||||
new_class._meta.primary_key = PrimaryKey(name=name, field=field)
|
||||
|
||||
# TODO: Raise exception here, global key prefix required?
|
||||
if not getattr(new_class._meta, 'global_key_prefix', None):
|
||||
new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "")
|
||||
if not getattr(new_class._meta, 'model_key_prefix', None):
|
||||
# Don't look at the base class for this.
|
||||
new_class._meta.model_key_prefix = f"{new_class.__name__.lower()}"
|
||||
if not getattr(new_class._meta, 'primary_key_pattern', None):
|
||||
new_class._meta.primary_key_pattern = getattr(base_meta, "primary_key_pattern",
|
||||
"{pk}")
|
||||
if not getattr(new_class._meta, 'database', None):
|
||||
new_class._meta.database = getattr(base_meta, "database",
|
||||
redis.Redis(decode_responses=True))
|
||||
if not getattr(new_class._meta, 'primary_key_creator_cls', None):
|
||||
new_class._meta.primary_key_creator_cls = getattr(base_meta, "primary_key_creator_cls",
|
||||
Uuid4PrimaryKey)
|
||||
if not getattr(new_class._meta, 'index_name', None):
|
||||
new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \
|
||||
f"{new_class._meta.model_key_prefix}:index"
|
||||
|
||||
return new_class
|
||||
|
||||
|
||||
class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
|
||||
"""
|
||||
TODO: Convert expressions to Redis commands, execute
|
||||
TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern)
|
||||
TODO: Default key prefix is model name lowercase
|
||||
TODO: Build primary key pattern from PK field name, model prefix
|
||||
TODO: Default PK pattern is model name:pk field
|
||||
TODO: Generate RediSearch schema
|
||||
"""
|
||||
pk: Optional[str] = Field(default=None, primary_key=True)
|
||||
|
||||
Meta = DefaultMeta
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
arbitrary_types_allowed = True
|
||||
extra = 'allow'
|
||||
|
||||
Meta = DefaultMeta
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
# Create proxies for each model field so that we can use the field
|
||||
# in queries, like Model.get(Model.field_name == 1)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
for name, field in cls.__fields__.items():
|
||||
setattr(cls, name, ExpressionProxy(field))
|
||||
# Check if this is our FieldInfo version with extended ORM metadata.
|
||||
if isinstance(field.field_info, FieldInfo):
|
||||
if field.field_info.primary_key:
|
||||
cls.Meta.primary_key = PrimaryKey(name=name, field=field)
|
||||
# TODO: Raise exception here, global key prefix required?
|
||||
if not getattr(cls.Meta, 'global_key_prefix'):
|
||||
cls.Meta.global_key_prefix = ""
|
||||
if not getattr(cls.Meta, 'model_key_prefix'):
|
||||
cls.Meta.model_key_prefix = f"{cls.__name__.lower()}"
|
||||
if not getattr(cls.Meta, 'primary_key_pattern'):
|
||||
cls.Meta.primary_key_pattern = "{pk}"
|
||||
if not getattr(cls.Meta, 'database'):
|
||||
cls.Meta.database = redis.Redis(decode_responses=True)
|
||||
if not getattr(cls.Meta, 'primary_key_creator_cls'):
|
||||
cls.Meta.primary_key_creator_cls = Uuid4PrimaryKey
|
||||
|
||||
def __init__(__pydantic_self__, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
__pydantic_self__.validate_primary_key()
|
||||
|
@ -237,7 +412,7 @@ class RedisModel(BaseModel):
|
|||
@validator("pk", always=True)
|
||||
def validate_pk(cls, v):
|
||||
if not v:
|
||||
v = cls.Meta.primary_key_creator_cls().create_pk()
|
||||
v = cls._meta.primary_key_creator_cls().create_pk()
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
|
@ -254,30 +429,66 @@ class RedisModel(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def make_key(cls, part: str):
|
||||
global_prefix = getattr(cls.Meta, 'global_key_prefix', '').strip(":")
|
||||
model_prefix = getattr(cls.Meta, 'model_key_prefix', '').strip(":")
|
||||
global_prefix = getattr(cls._meta, 'global_key_prefix', '').strip(":")
|
||||
model_prefix = getattr(cls._meta, 'model_key_prefix', '').strip(":")
|
||||
return f"{global_prefix}:{model_prefix}:{part}"
|
||||
|
||||
@classmethod
|
||||
def make_primary_key(cls, pk: Any):
|
||||
"""Return the Redis key for this model."""
|
||||
return cls.make_key(cls.Meta.primary_key_pattern.format(pk=pk))
|
||||
return cls.make_key(cls._meta.primary_key_pattern.format(pk=pk))
|
||||
|
||||
def key(self):
|
||||
"""Return the Redis key for this model."""
|
||||
pk = getattr(self, self.Meta.primary_key.field.name)
|
||||
pk = getattr(self, self._meta.primary_key.field.name)
|
||||
return self.make_primary_key(pk)
|
||||
|
||||
@classmethod
|
||||
def db(cls):
|
||||
return cls.Meta.database
|
||||
return cls._meta.database
|
||||
|
||||
@classmethod
|
||||
def filter(cls, *expressions: Sequence[Expression]):
|
||||
return cls
|
||||
def from_redis(cls, res: Any):
|
||||
import six
|
||||
from six.moves import xrange, zip as izip
|
||||
|
||||
def to_string(s):
|
||||
if isinstance(s, six.string_types):
|
||||
return s
|
||||
elif isinstance(s, six.binary_type):
|
||||
return s.decode('utf-8','ignore')
|
||||
else:
|
||||
return s # Not a string we care about
|
||||
|
||||
docs = []
|
||||
step = 2 # Because the result has content
|
||||
offset = 1
|
||||
|
||||
for i in xrange(1, len(res), step):
|
||||
fields_offset = offset
|
||||
|
||||
fields = dict(
|
||||
dict(izip(map(to_string, res[i + fields_offset][::2]),
|
||||
map(to_string, res[i + fields_offset][1::2])))
|
||||
)
|
||||
|
||||
try:
|
||||
del fields['id']
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
doc = cls(**fields)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
def exclude(cls, *expressions: Sequence[Expression]):
|
||||
def find(cls, *expressions: Expression):
|
||||
query = FindQuery(expressions=expressions, model=cls)
|
||||
raw_result = query.find()
|
||||
return cls.from_redis(raw_result)
|
||||
|
||||
@classmethod
|
||||
def find_one(cls, *expressions: Sequence[Expression]):
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
|
@ -296,16 +507,15 @@ class RedisModel(BaseModel):
|
|||
def delete(self):
|
||||
return self.db().delete(self.key())
|
||||
|
||||
# TODO: Protocol
|
||||
@classmethod
|
||||
def get(cls, pk: Any):
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, *args, **kwargs) -> 'RedisModel':
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def schema(cls):
|
||||
raise NotImplementedError
|
||||
|
||||
class HashModel(RedisModel):
|
||||
|
||||
class HashModel(RedisModel, abc.ABC):
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
|
@ -346,8 +556,34 @@ class HashModel(RedisModel):
|
|||
return ""
|
||||
return val
|
||||
|
||||
@classmethod
|
||||
def schema_for_type(cls, name, typ: Type):
|
||||
if any(issubclass(typ, t) for t in NUMERIC_TYPES):
|
||||
return f"{name} NUMERIC"
|
||||
else:
|
||||
return f"{name} TEXT"
|
||||
|
||||
class JsonModel(RedisModel):
|
||||
@classmethod
|
||||
def schema(cls):
|
||||
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
|
||||
schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
|
||||
schema_parts = [schema_prefix]
|
||||
for name, field in cls.__fields__.items():
|
||||
_type = field.outer_type_
|
||||
if getattr(field.field_info, 'primary_key', None):
|
||||
if issubclass(_type, str):
|
||||
redisearch_field = f"{name} TAG"
|
||||
else:
|
||||
redisearch_field = cls.schema_for_type(name, _type)
|
||||
schema_parts.append(redisearch_field)
|
||||
else:
|
||||
schema_parts.append(cls.schema_for_type(name, _type))
|
||||
if getattr(field.field_info, 'sortable', False):
|
||||
schema_parts.append("SORTABLE")
|
||||
return " ".join(schema_parts)
|
||||
|
||||
|
||||
class JsonModel(RedisModel, abc.ABC):
|
||||
def save(self, *args, **kwargs) -> 'JsonModel':
|
||||
success = self.db().execute_command('JSON.SET', self.key(), ".", self.json())
|
||||
return success
|
||||
|
|
30
redis_developer/orm/models.py
Normal file
30
redis_developer/orm/models.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import abc
|
||||
from typing import Optional
|
||||
|
||||
from redis_developer.orm.model import JsonModel, HashModel
|
||||
|
||||
|
||||
class BaseJsonModel(JsonModel, abc.ABC):
|
||||
class Meta:
|
||||
global_key_prefix = "redis-developer"
|
||||
|
||||
|
||||
class BaseHashModel(HashModel, abc.ABC):
|
||||
class Meta:
|
||||
global_key_prefix = "redis-developer"
|
||||
|
||||
|
||||
# class AddressJson(BaseJsonModel):
|
||||
# address_line_1: str
|
||||
# address_line_2: Optional[str]
|
||||
# city: str
|
||||
# country: str
|
||||
# postal_code: str
|
||||
#
|
||||
|
||||
class AddressHash(BaseHashModel):
|
||||
address_line_1: str
|
||||
address_line_2: Optional[str]
|
||||
city: str
|
||||
country: str
|
||||
postal_code: str
|
55
redis_developer/orm/query_iterator.py
Normal file
55
redis_developer/orm/query_iterator.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
from redis_developer.orm.model import Expression
|
||||
|
||||
|
||||
class QueryIterator:
|
||||
"""
|
||||
A lazy iterator that yields results from a RediSearch query.
|
||||
|
||||
Examples:
|
||||
|
||||
results = Model.filter(email == "a@example.com")
|
||||
|
||||
# Consume all results.
|
||||
for r in results:
|
||||
print(r)
|
||||
|
||||
# Consume an item at an index.
|
||||
print(results[100])
|
||||
|
||||
# Consume a slice.
|
||||
print(results[0:100])
|
||||
|
||||
# Alternative notation to consume all items.
|
||||
print(results[0:-1])
|
||||
|
||||
# Specify the batch size:
|
||||
results = Model.filter(email == "a@example.com", batch_size=1000)
|
||||
...
|
||||
"""
|
||||
def __init__(self, client, query, batch_size=100):
|
||||
self.client = client
|
||||
self.query = query
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __iter__(self):
|
||||
pass
|
||||
|
||||
def __getattr__(self, item):
|
||||
"""Support getting a single value or a slice."""
|
||||
|
||||
# TODO: Query mixin?
|
||||
|
||||
def filter(self, *expressions: Expression):
|
||||
pass
|
||||
|
||||
def exclude(self, *expressions: Expression):
|
||||
pass
|
||||
|
||||
def and_(self, *expressions: Expression):
|
||||
pass
|
||||
|
||||
def or_(self, *expressions: Expression):
|
||||
pass
|
||||
|
||||
def not_(self, *expressions: Expression):
|
||||
pass
|
106
redis_developer/orm/query_resolver.py
Normal file
106
redis_developer/orm/query_resolver.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
from collections import Sequence
|
||||
from typing import Any, Dict, Mapping, Union, List
|
||||
|
||||
from redis_developer.orm.model import Expression
|
||||
|
||||
|
||||
class LogicalOperatorForListOfExpressions(Expression):
|
||||
operator: str = ""
|
||||
|
||||
def __init__(self, *expressions: Expression):
|
||||
self.expressions = list(expressions)
|
||||
|
||||
@property
|
||||
def query(self) -> Mapping[str, List[Expression]]:
|
||||
if not self.expressions:
|
||||
raise AttributeError("At least one expression must be provided")
|
||||
# TODO: This needs to return a RediSearch string.
|
||||
# Use the values in each expression object to build the string.
|
||||
# Determine the type of query based on the field (numeric range,
|
||||
# tag field, etc.).
|
||||
return {self.operator: self.expressions}
|
||||
|
||||
|
||||
class Or(LogicalOperatorForListOfExpressions):
|
||||
"""
|
||||
Logical OR query operator
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
class Product(JsonModel):
|
||||
price: float
|
||||
category: str
|
||||
|
||||
Or(Product.price < 10, Product.category == "Sweets")
|
||||
```
|
||||
|
||||
Will return RediSearch query string like:
|
||||
|
||||
```
|
||||
(@price:[-inf 10]) | (@category:{Sweets})
|
||||
```
|
||||
"""
|
||||
|
||||
operator = "|"
|
||||
|
||||
|
||||
class And(LogicalOperatorForListOfExpressions):
|
||||
"""
|
||||
Logical AND query operator
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
class Product(Document):
|
||||
price: float
|
||||
category: str
|
||||
|
||||
And(Product.price < 10, Product.category == "Sweets")
|
||||
```
|
||||
|
||||
Will return a query string like:
|
||||
|
||||
```
|
||||
(@price:[-inf 10]) (@category:{Sweets})
|
||||
```
|
||||
|
||||
Note that in RediSearch, AND is implied with multiple terms.
|
||||
"""
|
||||
|
||||
operator = " "
|
||||
|
||||
|
||||
class Not(LogicalOperatorForListOfExpressions):
|
||||
"""
|
||||
Logical NOT query operator
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
class Product(Document):
|
||||
price: float
|
||||
category: str
|
||||
|
||||
Not(Product.price<10, Product.category=="Sweets")
|
||||
```
|
||||
|
||||
Will return a query string like:
|
||||
|
||||
```
|
||||
-(@price:[-inf 10]) -(@category:{Sweets})
|
||||
```
|
||||
"""
|
||||
@property
|
||||
def query(self):
|
||||
return "-(expression1) -(expression2)"
|
||||
|
||||
|
||||
class QueryResolver:
|
||||
def __init__(self, *expressions: Expression):
|
||||
self.expressions = expressions
|
||||
|
||||
def resolve(self) -> str:
|
||||
"""Resolve expressions to a RediSearch query string."""
|
||||
|
||||
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
31
tests/conftest.py
Normal file
31
tests/conftest.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
import pytest
|
||||
from redis import Redis
|
||||
|
||||
from redis_developer.connections import get_redis_connection
|
||||
from redis_developer.orm.migrations.migrator import Migrator
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def migrations():
|
||||
Migrator().run()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def redis():
|
||||
yield get_redis_connection()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def key_prefix():
|
||||
# TODO
|
||||
yield "redis-developer"
|
||||
|
||||
|
||||
def _delete_test_keys(prefix: str, conn: Redis):
|
||||
for key in conn.scan_iter(f"{prefix}:*"):
|
||||
conn.delete(key)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def delete_test_keys(redis, request, key_prefix):
|
||||
_delete_test_keys(key_prefix, redis)
|
|
@ -1,6 +1,7 @@
|
|||
import abc
|
||||
import decimal
|
||||
import datetime
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
|
@ -16,19 +17,11 @@ r = redis.Redis()
|
|||
today = datetime.date.today()
|
||||
|
||||
|
||||
class BaseHashModel(HashModel):
|
||||
class Meta(HashModel.Meta):
|
||||
class BaseHashModel(HashModel, abc.ABC):
|
||||
class Meta:
|
||||
global_key_prefix = "redis-developer"
|
||||
|
||||
|
||||
class Address(BaseHashModel):
|
||||
address_line_1: str
|
||||
address_line_2: Optional[str]
|
||||
city: str
|
||||
country: str
|
||||
postal_code: str
|
||||
|
||||
|
||||
class Order(BaseHashModel):
|
||||
total: decimal.Decimal
|
||||
currency: str
|
||||
|
@ -41,7 +34,7 @@ class Member(BaseHashModel):
|
|||
email: str = Field(unique=True, index=True)
|
||||
join_date: datetime.date
|
||||
|
||||
class Meta(BaseHashModel.Meta):
|
||||
class Meta:
|
||||
model_key_prefix = "member"
|
||||
primary_key_pattern = ""
|
||||
|
||||
|
@ -68,13 +61,6 @@ def test_validates_field():
|
|||
|
||||
# Passes validation
|
||||
def test_validation_passes():
|
||||
address = Address(
|
||||
address_line_1="1 Main St.",
|
||||
city="Happy Town",
|
||||
state="WY",
|
||||
postal_code=11111,
|
||||
country="USA"
|
||||
)
|
||||
member = Member(
|
||||
first_name="Andrew",
|
||||
last_name="Brookins",
|
||||
|
@ -99,6 +85,13 @@ def test_saves_model_and_creates_pk():
|
|||
|
||||
|
||||
def test_raises_error_with_embedded_models():
|
||||
class Address(BaseHashModel):
|
||||
address_line_1: str
|
||||
address_line_2: Optional[str]
|
||||
city: str
|
||||
country: str
|
||||
postal_code: str
|
||||
|
||||
with pytest.raises(RedisModelError):
|
||||
class InvalidMember(BaseHashModel):
|
||||
address: Address
|
||||
|
@ -142,3 +135,51 @@ def test_updates_a_model():
|
|||
|
||||
# Or, affecting multiple model instances with an implicit save:
|
||||
Member.filter(Member.last_name == "Brookins").update(last_name="Sam-Bodden")
|
||||
|
||||
|
||||
def test_exact_match_queries():
|
||||
member1 = Member(
|
||||
first_name="Andrew",
|
||||
last_name="Brookins",
|
||||
email="a@example.com",
|
||||
join_date=today
|
||||
)
|
||||
|
||||
member2 = Member(
|
||||
first_name="Kim",
|
||||
last_name="Brookins",
|
||||
email="k@example.com",
|
||||
join_date=today
|
||||
)
|
||||
member1.save()
|
||||
member2.save()
|
||||
|
||||
actual = Member.find(Member.last_name == "Brookins")
|
||||
assert actual == [member2, member1]
|
||||
|
||||
|
||||
# actual = Member.find(
|
||||
# (Member.last_name == "Brookins") & (~Member.first_name == "Andrew"))
|
||||
# assert actual == [member2]
|
||||
|
||||
# actual = Member.find(~Member.last_name == "Brookins")
|
||||
# assert actual == []
|
||||
|
||||
# actual = Member.find(
|
||||
# (Member.last_name == "Brookins") & (Member.first_name == "Andrew")
|
||||
# | (Member.first_name == "Kim")
|
||||
# )
|
||||
# assert actual == [member1, member2]
|
||||
|
||||
# actual = Member.find_one(Member.last_name == "Brookins")
|
||||
# assert actual == member1
|
||||
|
||||
|
||||
def test_schema():
|
||||
class Address(BaseHashModel):
|
||||
a_string: str
|
||||
an_integer: int
|
||||
a_float: float
|
||||
|
||||
assert Address.schema() == "SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \
|
||||
"a_float NUMERIC"
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import abc
|
||||
import decimal
|
||||
import datetime
|
||||
from typing import Optional, List
|
||||
|
@ -15,8 +16,8 @@ r = redis.Redis()
|
|||
today = datetime.datetime.today()
|
||||
|
||||
|
||||
class BaseJsonModel(JsonModel):
|
||||
class Meta(JsonModel.Meta):
|
||||
class BaseJsonModel(JsonModel, abc.ABC):
|
||||
class Meta:
|
||||
global_key_prefix = "redis-developer"
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue