Merge branch 'main' into asyncio

This commit is contained in:
Andrew Brookins 2021-11-09 15:59:10 -08:00
commit ca6ae7d6e9
47 changed files with 3285 additions and 760 deletions

15
aredis_om/__init__.py Normal file
View file

@ -0,0 +1,15 @@
from .checks import has_redis_json, has_redisearch
from .connections import get_redis_connection
from .model.migrations.migrator import MigrationError, Migrator
from .model.model import (
EmbeddedJsonModel,
Field,
FindQuery,
HashModel,
JsonModel,
NotFoundError,
QueryNotSupportedError,
QuerySyntaxError,
RedisModel,
RedisModelError,
)

28
aredis_om/checks.py Normal file
View file

@ -0,0 +1,28 @@
from functools import lru_cache
from typing import List
from aredis_om.connections import get_redis_connection
@lru_cache(maxsize=None)
async def get_modules(conn) -> List[str]:
modules = await conn.execute_command("module", "list")
return [m[1] for m in modules]
@lru_cache(maxsize=None)
async def has_redis_json(conn=None):
if conn is None:
conn = get_redis_connection()
names = await get_modules(conn)
return b"ReJSON" in names or "ReJSON" in names
@lru_cache(maxsize=None)
async def has_redisearch(conn=None):
if conn is None:
conn = get_redis_connection()
if has_redis_json(conn):
return True
names = await get_modules(conn)
return b"search" in names or "search" in names

22
aredis_om/connections.py Normal file
View file

@ -0,0 +1,22 @@
import os
import aioredis
import dotenv
dotenv.load_dotenv()
URL = os.environ.get("REDIS_OM_URL", None)
def get_redis_connection(**kwargs) -> aioredis.Redis:
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
# environment variable, we'll create the Redis client from the URL.
url = kwargs.pop("url", URL)
if url:
return aioredis.Redis.from_url(url, **kwargs)
# Decode from UTF-8 by default
if "decode_responses" not in kwargs:
kwargs["decode_responses"] = True
return aioredis.Redis(**kwargs)

View file

@ -0,0 +1,2 @@
from .migrations.migrator import MigrationError, Migrator
from .model import EmbeddedJsonModel, Field, HashModel, JsonModel, RedisModel

View file

View file

@ -0,0 +1,17 @@
import click
from aredis_om.model.migrations.migrator import Migrator
@click.command()
@click.option("--module", default="aredis_om")
def migrate(module):
migrator = Migrator(module)
if migrator.migrations:
print("Pending migrations:")
for migration in migrator.migrations:
print(migration)
if input("Run migrations? (y/n) ") == "y":
migrator.run()

180
aredis_om/model/encoders.py Normal file
View file

@ -0,0 +1,180 @@
"""
This file adapted from FastAPI's encoders.
Licensed under the MIT License (MIT).
Copyright (c) 2018 Sebastián Ramírez
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import dataclasses
from collections import defaultdict
from enum import Enum
from pathlib import PurePath
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel
from pydantic.json import ENCODERS_BY_TYPE
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
def generate_encoders_by_class_tuples(
type_encoder_map: Dict[Any, Callable[[Any], Any]]
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
tuple
)
for type_, encoder in type_encoder_map.items():
encoders_by_class_tuples[encoder] += (type_,)
return encoders_by_class_tuples
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
def jsonable_encoder(
obj: Any,
include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
sqlalchemy_safe: bool = True,
) -> Any:
if include is not None and not isinstance(include, (set, dict)):
include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)):
exclude = set(exclude)
if isinstance(obj, BaseModel):
encoder = getattr(obj.__config__, "json_encoders", {})
if custom_encoder:
encoder.update(custom_encoder)
obj_dict = obj.dict(
include=include, # type: ignore # in Pydantic
exclude=exclude, # type: ignore # in Pydantic
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
custom_encoder=encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, dict):
encoded_dict = {}
for key, value in obj.items():
if (
(
not sqlalchemy_safe
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and ((include and key in include) or not exclude or key not in exclude)
):
encoded_key = jsonable_encoder(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder(obj)
if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
errors: List[Exception] = []
try:
data = dict(obj)
except Exception as e:
errors.append(e)
try:
data = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors)
return jsonable_encoder(
data,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)

View file

View file

@ -0,0 +1,154 @@
import hashlib
import logging
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
from aioredis import Redis, ResponseError
from aredis_om.model.model import model_registry
log = logging.getLogger(__name__)
import importlib # noqa: E402
import pkgutil # noqa: E402
class MigrationError(Exception):
pass
def import_submodules(root_module_name: str):
"""Import all submodules of a module, recursively."""
# TODO: Call this without specifying a module name, to import everything?
root_module = importlib.import_module(root_module_name)
if not hasattr(root_module, "__path__"):
raise MigrationError(
"The root module must be a Python package. "
f"You specified: {root_module_name}"
)
for loader, module_name, is_pkg in pkgutil.walk_packages(
root_module.__path__, root_module.__name__ + "." # type: ignore
):
importlib.import_module(module_name)
def schema_hash_key(index_name):
return f"{index_name}:hash"
async def create_index(redis: Redis, index_name, schema, current_hash):
try:
await redis.execute_command(f"ft.info {index_name}")
except ResponseError:
await redis.execute_command(f"ft.create {index_name} {schema}")
await redis.set(schema_hash_key(index_name), current_hash)
else:
log.info("Index already exists, skipping. Index hash: %s", index_name)
class MigrationAction(Enum):
CREATE = 2
DROP = 1
@dataclass
class IndexMigration:
model_name: str
index_name: str
schema: str
hash: str
action: MigrationAction
redis: Redis
previous_hash: Optional[str] = None
async def run(self):
if self.action is MigrationAction.CREATE:
await self.create()
elif self.action is MigrationAction.DROP:
await self.drop()
async def create(self):
try:
await create_index(self.redis, self.index_name, self.schema, self.hash)
except ResponseError:
log.info("Index already exists: %s", self.index_name)
async def drop(self):
try:
await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
except ResponseError:
log.info("Index does not exist: %s", self.index_name)
class Migrator:
def __init__(self, redis: Redis, module=None):
self.module = module
self.migrations: List[IndexMigration] = []
self.redis = redis
async def run(self):
# Try to load any modules found under the given path or module name.
if self.module:
import_submodules(self.module)
for name, cls in model_registry.items():
hash_key = schema_hash_key(cls.Meta.index_name)
try:
schema = cls.redisearch_schema()
except NotImplementedError:
log.info("Skipping migrations for %s", name)
continue
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec
try:
await self.redis.execute_command("ft.info", cls.Meta.index_name)
except ResponseError:
self.migrations.append(
IndexMigration(
name,
cls.Meta.index_name,
schema,
current_hash,
MigrationAction.CREATE,
self.redis,
)
)
continue
stored_hash = await self.redis.get(hash_key)
schema_out_of_date = current_hash != stored_hash
if schema_out_of_date:
# TODO: Switch out schema with an alias to avoid downtime -- separate migration?
self.migrations.append(
IndexMigration(
name,
cls.Meta.index_name,
schema,
current_hash,
MigrationAction.DROP,
self.redis,
stored_hash,
)
)
self.migrations.append(
IndexMigration(
name,
cls.Meta.index_name,
schema,
current_hash,
MigrationAction.CREATE,
self.redis,
stored_hash,
)
)
# TODO: Migration history
# TODO: Dry run with output
for migration in self.migrations:
await migration.run()

1656
aredis_om/model/model.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,104 @@
from typing import List, Mapping
from aredis_om.model.model import Expression
class LogicalOperatorForListOfExpressions(Expression):
operator: str = ""
def __init__(self, *expressions: Expression):
self.expressions = list(expressions)
@property
def query(self) -> Mapping[str, List[Expression]]:
if not self.expressions:
raise AttributeError("At least one expression must be provided")
# TODO: This needs to return a RediSearch string.
# Use the values in each expression object to build the string.
# Determine the type of query based on the field (numeric range,
# tag field, etc.).
return {self.operator: self.expressions}
class Or(LogicalOperatorForListOfExpressions):
"""
Logical OR query operator
Example:
```python
class Product(JsonModel):
price: float
category: str
Or(Product.price < 10, Product.category == "Sweets")
```
Will return RediSearch query string like:
```
(@price:[-inf 10]) | (@category:{Sweets})
```
"""
operator = "|"
class And(LogicalOperatorForListOfExpressions):
"""
Logical AND query operator
Example:
```python
class Product(Document):
price: float
category: str
And(Product.price < 10, Product.category == "Sweets")
```
Will return a query string like:
```
(@price:[-inf 10]) (@category:{Sweets})
```
Note that in RediSearch, AND is implied with multiple terms.
"""
operator = " "
class Not(LogicalOperatorForListOfExpressions):
"""
Logical NOT query operator
Example:
```python
class Product(Document):
price: float
category: str
Not(Product.price<10, Product.category=="Sweets")
```
Will return a query string like:
```
-(@price:[-inf 10]) -(@category:{Sweets})
```
"""
@property
def query(self):
return "-(expression1) -(expression2)"
class QueryResolver:
def __init__(self, *expressions: Expression):
self.expressions = expressions
def resolve(self) -> str:
"""Resolve expressions to a RediSearch query string."""

View file

@ -0,0 +1,75 @@
"""
This code adapted from the library "pptree," Copyright (c) 2017 Clément Michard
and released under the MIT license: https://github.com/clemtoy/pptree
"""
import io
def render_tree(
current_node,
nameattr="name",
left_child="left",
right_child="right",
indent="",
last="updown",
buffer=None,
):
"""Print a tree-like structure, `current_node`.
This is a mostly-direct-copy of the print_tree() function from the ppbtree
module of the pptree library, but instead of printing to standard out, we
write to a StringIO buffer, then use that buffer to accumulate written lines
during recursive calls to render_tree().
"""
if buffer is None:
buffer = io.StringIO()
if hasattr(current_node, nameattr):
name = lambda node: getattr(node, nameattr) # noqa: E731
else:
name = lambda node: str(node) # noqa: E731
up = getattr(current_node, left_child, None)
down = getattr(current_node, right_child, None)
if up is not None:
next_last = "up"
next_indent = "{0}{1}{2}".format(
indent, " " if "up" in last else "|", " " * len(str(name(current_node)))
)
render_tree(
up, nameattr, left_child, right_child, next_indent, next_last, buffer
)
if last == "up":
start_shape = ""
elif last == "down":
start_shape = ""
elif last == "updown":
start_shape = " "
else:
start_shape = ""
if up is not None and down is not None:
end_shape = ""
elif up:
end_shape = ""
elif down:
end_shape = ""
else:
end_shape = ""
print(
"{0}{1}{2}{3}".format(indent, start_shape, name(current_node), end_shape),
file=buffer,
)
if down is not None:
next_last = "down"
next_indent = "{0}{1}{2}".format(
indent, " " if "down" in last else "|", " " * len(str(name(current_node)))
)
render_tree(
down, nameattr, left_child, right_child, next_indent, next_last, buffer
)
return f"\n{buffer.getvalue()}"

View file

@ -0,0 +1,25 @@
import re
from typing import Optional, Pattern
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, value: str) -> str:
def escape_symbol(match):
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, value)

41
aredis_om/unasync_util.py Normal file
View file

@ -0,0 +1,41 @@
"""Set of utility functions for unasync that transform into sync counterparts cleanly"""
import inspect
_original_next = next
def is_async_mode():
"""Tests if we're in the async part of the code or not"""
async def f():
"""Unasync transforms async functions in sync functions"""
return None
obj = f()
if obj is None:
return False
else:
obj.close() # prevent unawaited coroutine warning
return True
ASYNC_MODE = is_async_mode()
async def anext(x):
return await x.__anext__()
async def await_if_coro(x):
if inspect.iscoroutine(x):
return await x
return x
next = _original_next
def return_non_coro(x):
return x