Merge branch 'main' into asyncio
commit ca6ae7d6e9
47 changed files with 3285 additions and 760 deletions
aredis_om/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from .checks import has_redis_json, has_redisearch
from .connections import get_redis_connection
from .model.migrations.migrator import MigrationError, Migrator
from .model.model import (
    EmbeddedJsonModel,
    Field,
    FindQuery,
    HashModel,
    JsonModel,
    NotFoundError,
    QueryNotSupportedError,
    QuerySyntaxError,
    RedisModel,
    RedisModelError,
)
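Not part of the diff: a minimal sketch of the public API this `__init__.py` exposes, assuming the package is installed and importable.

```python
# Sketch only (not in the diff): the package re-exports its public API
# at the top level, so callers can import everything from aredis_om.
from aredis_om import Field, HashModel, JsonModel, Migrator, get_redis_connection
```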
aredis_om/checks.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from functools import lru_cache
from typing import List

from aredis_om.connections import get_redis_connection


@lru_cache(maxsize=None)
async def get_modules(conn) -> List[str]:
    modules = await conn.execute_command("module", "list")
    return [m[1] for m in modules]


@lru_cache(maxsize=None)
async def has_redis_json(conn=None):
    if conn is None:
        conn = get_redis_connection()
    names = await get_modules(conn)
    return b"ReJSON" in names or "ReJSON" in names


@lru_cache(maxsize=None)
async def has_redisearch(conn=None):
    if conn is None:
        conn = get_redis_connection()
    if await has_redis_json(conn):
        return True
    names = await get_modules(conn)
    return b"search" in names or "search" in names
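Not part of the diff: a minimal usage sketch for these checks, assuming a reachable Redis server (via REDIS_OM_URL or localhost defaults).

```python
# Sketch only (not in the diff): guard startup on the required Redis modules.
import asyncio

from aredis_om.checks import has_redisearch


async def main():
    if not await has_redisearch():
        raise RuntimeError("RediSearch module is required but not loaded")


asyncio.run(main())
```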
aredis_om/connections.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import os

import aioredis
import dotenv


dotenv.load_dotenv()

URL = os.environ.get("REDIS_OM_URL", None)


def get_redis_connection(**kwargs) -> aioredis.Redis:
    # If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
    # environment variable, we'll create the Redis client from the URL.
    url = kwargs.pop("url", URL)
    if url:
        return aioredis.Redis.from_url(url, **kwargs)

    # Decode from UTF-8 by default
    if "decode_responses" not in kwargs:
        kwargs["decode_responses"] = True
    return aioredis.Redis(**kwargs)
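Not part of the diff: a sketch of the three ways this factory can be configured. The URL values are illustrative, not from the commit.

```python
# Sketch only (not in the diff): connection configuration options.
from aredis_om.connections import get_redis_connection

# 1. REDIS_OM_URL environment variable (or a .env file via python-dotenv), if set.
redis = get_redis_connection()

# 2. An explicit URL, which takes precedence over other keyword arguments.
redis = get_redis_connection(url="redis://localhost:6379/0")

# 3. Plain aioredis keyword arguments; decode_responses defaults to True here.
redis = get_redis_connection(host="localhost", port=6379, db=0)
```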
aredis_om/model/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .migrations.migrator import MigrationError, Migrator
from .model import EmbeddedJsonModel, Field, HashModel, JsonModel, RedisModel
aredis_om/model/cli/__init__.py (new file, empty)
aredis_om/model/cli/migrate.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import click

from aredis_om.model.migrations.migrator import Migrator


@click.command()
@click.option("--module", default="aredis_om")
def migrate(module):
    migrator = Migrator(module)

    if migrator.migrations:
        print("Pending migrations:")
        for migration in migrator.migrations:
            print(migration)

        if input("Run migrations? (y/n) ") == "y":
            migrator.run()
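Not part of the diff: a sketch of driving the command in-process with Click's test runner. The module name is hypothetical, and note that the command as committed still uses the pre-async `Migrator(module)` call signature.

```python
# Sketch only (not in the diff): invoke the CLI with Click's test runner.
# "myapp.models" is a hypothetical module containing model classes.
from click.testing import CliRunner

from aredis_om.model.cli.migrate import migrate

runner = CliRunner()
result = runner.invoke(migrate, ["--module", "myapp.models"])
print(result.exit_code, result.output)
```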
aredis_om/model/encoders.py (new file, 180 lines)
@@ -0,0 +1,180 @@
"""
This file adapted from FastAPI's encoders.

Licensed under the MIT License (MIT).

Copyright (c) 2018 Sebastián Ramírez

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import dataclasses
from collections import defaultdict
from enum import Enum
from pathlib import PurePath
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from pydantic import BaseModel
from pydantic.json import ENCODERS_BY_TYPE


SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]


def generate_encoders_by_class_tuples(
    type_encoder_map: Dict[Any, Callable[[Any], Any]]
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
    encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
        tuple
    )
    for type_, encoder in type_encoder_map.items():
        encoders_by_class_tuples[encoder] += (type_,)
    return encoders_by_class_tuples


encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)


def jsonable_encoder(
    obj: Any,
    include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
    by_alias: bool = True,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
    sqlalchemy_safe: bool = True,
) -> Any:
    if include is not None and not isinstance(include, (set, dict)):
        include = set(include)
    if exclude is not None and not isinstance(exclude, (set, dict)):
        exclude = set(exclude)
    if isinstance(obj, BaseModel):
        encoder = getattr(obj.__config__, "json_encoders", {})
        if custom_encoder:
            encoder.update(custom_encoder)
        obj_dict = obj.dict(
            include=include,  # type: ignore # in Pydantic
            exclude=exclude,  # type: ignore # in Pydantic
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_none=exclude_none,
            exclude_defaults=exclude_defaults,
        )
        if "__root__" in obj_dict:
            obj_dict = obj_dict["__root__"]
        return jsonable_encoder(
            obj_dict,
            exclude_none=exclude_none,
            exclude_defaults=exclude_defaults,
            custom_encoder=encoder,
            sqlalchemy_safe=sqlalchemy_safe,
        )
    if dataclasses.is_dataclass(obj):
        return dataclasses.asdict(obj)
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, PurePath):
        return str(obj)
    if isinstance(obj, (str, int, float, type(None))):
        return obj
    if isinstance(obj, dict):
        encoded_dict = {}
        for key, value in obj.items():
            if (
                (
                    not sqlalchemy_safe
                    or (not isinstance(key, str))
                    or (not key.startswith("_sa"))
                )
                and (value is not None or not exclude_none)
                and ((include and key in include) or not exclude or key not in exclude)
            ):
                encoded_key = jsonable_encoder(
                    key,
                    by_alias=by_alias,
                    exclude_unset=exclude_unset,
                    exclude_none=exclude_none,
                    custom_encoder=custom_encoder,
                    sqlalchemy_safe=sqlalchemy_safe,
                )
                encoded_value = jsonable_encoder(
                    value,
                    by_alias=by_alias,
                    exclude_unset=exclude_unset,
                    exclude_none=exclude_none,
                    custom_encoder=custom_encoder,
                    sqlalchemy_safe=sqlalchemy_safe,
                )
                encoded_dict[encoded_key] = encoded_value
        return encoded_dict
    if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
        encoded_list = []
        for item in obj:
            encoded_list.append(
                jsonable_encoder(
                    item,
                    include=include,
                    exclude=exclude,
                    by_alias=by_alias,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                    custom_encoder=custom_encoder,
                    sqlalchemy_safe=sqlalchemy_safe,
                )
            )
        return encoded_list

    if custom_encoder:
        if type(obj) in custom_encoder:
            return custom_encoder[type(obj)](obj)
        else:
            for encoder_type, encoder in custom_encoder.items():
                if isinstance(obj, encoder_type):
                    return encoder(obj)

    if type(obj) in ENCODERS_BY_TYPE:
        return ENCODERS_BY_TYPE[type(obj)](obj)
    for encoder, classes_tuple in encoders_by_class_tuples.items():
        if isinstance(obj, classes_tuple):
            return encoder(obj)

    errors: List[Exception] = []
    try:
        data = dict(obj)
    except Exception as e:
        errors.append(e)
        try:
            data = vars(obj)
        except Exception as e:
            errors.append(e)
            raise ValueError(errors)
    return jsonable_encoder(
        data,
        by_alias=by_alias,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
        custom_encoder=custom_encoder,
        sqlalchemy_safe=sqlalchemy_safe,
    )
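Not part of the diff: a sketch of what `jsonable_encoder` does to common Python types, using pydantic v1's default encoders.

```python
# Sketch only (not in the diff): jsonable_encoder flattens rich Python
# values into JSON-compatible primitives.
import datetime
import decimal
from enum import Enum

from aredis_om.model.encoders import jsonable_encoder


class Color(Enum):
    RED = "red"


doc = {
    "when": datetime.datetime(2021, 9, 30, 12, 0),
    "price": decimal.Decimal("9.99"),
    "color": Color.RED,
    "tags": ["new", "sale"],
}
print(jsonable_encoder(doc))
# {'when': '2021-09-30T12:00:00', 'price': 9.99, 'color': 'red', 'tags': ['new', 'sale']}
```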
aredis_om/model/migrations/__init__.py (new file, empty)
aredis_om/model/migrations/migrator.py (new file, 154 lines)
@@ -0,0 +1,154 @@
import hashlib
import logging
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

from aioredis import Redis, ResponseError

from aredis_om.model.model import model_registry


log = logging.getLogger(__name__)


import importlib  # noqa: E402
import pkgutil  # noqa: E402


class MigrationError(Exception):
    pass


def import_submodules(root_module_name: str):
    """Import all submodules of a module, recursively."""
    # TODO: Call this without specifying a module name, to import everything?
    root_module = importlib.import_module(root_module_name)

    if not hasattr(root_module, "__path__"):
        raise MigrationError(
            "The root module must be a Python package. "
            f"You specified: {root_module_name}"
        )

    for loader, module_name, is_pkg in pkgutil.walk_packages(
        root_module.__path__, root_module.__name__ + "."  # type: ignore
    ):
        importlib.import_module(module_name)


def schema_hash_key(index_name):
    return f"{index_name}:hash"


async def create_index(redis: Redis, index_name, schema, current_hash):
    try:
        await redis.execute_command(f"ft.info {index_name}")
    except ResponseError:
        await redis.execute_command(f"ft.create {index_name} {schema}")
        await redis.set(schema_hash_key(index_name), current_hash)
    else:
        log.info("Index already exists, skipping: %s", index_name)


class MigrationAction(Enum):
    CREATE = 2
    DROP = 1


@dataclass
class IndexMigration:
    model_name: str
    index_name: str
    schema: str
    hash: str
    action: MigrationAction
    redis: Redis
    previous_hash: Optional[str] = None

    async def run(self):
        if self.action is MigrationAction.CREATE:
            await self.create()
        elif self.action is MigrationAction.DROP:
            await self.drop()

    async def create(self):
        try:
            await create_index(self.redis, self.index_name, self.schema, self.hash)
        except ResponseError:
            log.info("Index already exists: %s", self.index_name)

    async def drop(self):
        try:
            await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
        except ResponseError:
            log.info("Index does not exist: %s", self.index_name)


class Migrator:
    def __init__(self, redis: Redis, module=None):
        self.module = module
        self.migrations: List[IndexMigration] = []
        self.redis = redis

    async def run(self):
        # Try to load any modules found under the given path or module name.
        if self.module:
            import_submodules(self.module)

        for name, cls in model_registry.items():
            hash_key = schema_hash_key(cls.Meta.index_name)
            try:
                schema = cls.redisearch_schema()
            except NotImplementedError:
                log.info("Skipping migrations for %s", name)
                continue
            current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest()  # nosec

            try:
                await self.redis.execute_command("ft.info", cls.Meta.index_name)
            except ResponseError:
                self.migrations.append(
                    IndexMigration(
                        name,
                        cls.Meta.index_name,
                        schema,
                        current_hash,
                        MigrationAction.CREATE,
                        self.redis,
                    )
                )
                continue

            stored_hash = await self.redis.get(hash_key)
            schema_out_of_date = current_hash != stored_hash

            if schema_out_of_date:
                # TODO: Switch out schema with an alias to avoid downtime -- separate migration?
                self.migrations.append(
                    IndexMigration(
                        name,
                        cls.Meta.index_name,
                        schema,
                        current_hash,
                        MigrationAction.DROP,
                        self.redis,
                        stored_hash,
                    )
                )
                self.migrations.append(
                    IndexMigration(
                        name,
                        cls.Meta.index_name,
                        schema,
                        current_hash,
                        MigrationAction.CREATE,
                        self.redis,
                        stored_hash,
                    )
                )

        # TODO: Migration history
        # TODO: Dry run with output
        for migration in self.migrations:
            await migration.run()
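Not part of the diff: a sketch of running the migrator directly through the async API. The module path is hypothetical.

```python
# Sketch only (not in the diff): detect and apply index migrations for all
# registered models. "myapp.models" is a hypothetical package to scan.
import asyncio

from aredis_om.connections import get_redis_connection
from aredis_om.model.migrations.migrator import Migrator


async def main():
    migrator = Migrator(get_redis_connection(), module="myapp.models")
    await migrator.run()  # detects schema changes, then runs each migration


asyncio.run(main())
```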
aredis_om/model/model.py (new file, 1656 lines)
File diff suppressed because it is too large.
aredis_om/model/query_resolver.py (new file, 104 lines)
@@ -0,0 +1,104 @@
from typing import List, Mapping

from aredis_om.model.model import Expression


class LogicalOperatorForListOfExpressions(Expression):
    operator: str = ""

    def __init__(self, *expressions: Expression):
        self.expressions = list(expressions)

    @property
    def query(self) -> Mapping[str, List[Expression]]:
        if not self.expressions:
            raise AttributeError("At least one expression must be provided")
        # TODO: This needs to return a RediSearch string.
        # Use the values in each expression object to build the string.
        # Determine the type of query based on the field (numeric range,
        # tag field, etc.).
        return {self.operator: self.expressions}


class Or(LogicalOperatorForListOfExpressions):
    """
    Logical OR query operator

    Example:

    ```python
    class Product(JsonModel):
        price: float
        category: str

    Or(Product.price < 10, Product.category == "Sweets")
    ```

    Will return a RediSearch query string like:

    ```
    (@price:[-inf 10]) | (@category:{Sweets})
    ```
    """

    operator = "|"


class And(LogicalOperatorForListOfExpressions):
    """
    Logical AND query operator

    Example:

    ```python
    class Product(Document):
        price: float
        category: str

    And(Product.price < 10, Product.category == "Sweets")
    ```

    Will return a query string like:

    ```
    (@price:[-inf 10]) (@category:{Sweets})
    ```

    Note that in RediSearch, AND is implied with multiple terms.
    """

    operator = " "


class Not(LogicalOperatorForListOfExpressions):
    """
    Logical NOT query operator

    Example:

    ```python
    class Product(Document):
        price: float
        category: str

    Not(Product.price < 10, Product.category == "Sweets")
    ```

    Will return a query string like:

    ```
    -(@price:[-inf 10]) -(@category:{Sweets})
    ```
    """

    @property
    def query(self):
        return "-(expression1) -(expression2)"


class QueryResolver:
    def __init__(self, *expressions: Expression):
        self.expressions = expressions

    def resolve(self) -> str:
        """Resolve expressions to a RediSearch query string."""
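Not part of the diff: a sketch built from the docstring examples above. The operators are still stubs at this point, so the strings shown are the intended output, not what `query` returns yet; `Field(index=True)` is an assumption about how queryable fields are declared.

```python
# Sketch only (not in the diff): composing the logical operators, following
# the docstring examples above.
from aredis_om.model.model import Field, JsonModel
from aredis_om.model.query_resolver import And, Not, Or


class Product(JsonModel):
    price: float = Field(index=True)
    category: str = Field(index=True)


# Intended RediSearch renderings, per the docstrings:
#   Or(...)  ->  (@price:[-inf 10]) | (@category:{Sweets})
#   And(...) ->  (@price:[-inf 10]) (@category:{Sweets})
#   Not(...) ->  -(@price:[-inf 10]) -(@category:{Sweets})
cheap_or_sweet = Or(Product.price < 10, Product.category == "Sweets")
```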
aredis_om/model/render_tree.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""
This code is adapted from the library "pptree," Copyright (c) 2017 Clément Michard,
and released under the MIT license: https://github.com/clemtoy/pptree
"""
import io


def render_tree(
    current_node,
    nameattr="name",
    left_child="left",
    right_child="right",
    indent="",
    last="updown",
    buffer=None,
):
    """Print a tree-like structure, `current_node`.

    This is a mostly direct copy of the print_tree() function from the ppbtree
    module of the pptree library, but instead of printing to standard out, we
    write to a StringIO buffer, then use that buffer to accumulate written lines
    during recursive calls to render_tree().
    """
    if buffer is None:
        buffer = io.StringIO()
    if hasattr(current_node, nameattr):
        name = lambda node: getattr(node, nameattr)  # noqa: E731
    else:
        name = lambda node: str(node)  # noqa: E731

    up = getattr(current_node, left_child, None)
    down = getattr(current_node, right_child, None)

    if up is not None:
        next_last = "up"
        next_indent = "{0}{1}{2}".format(
            indent, " " if "up" in last else "|", " " * len(str(name(current_node)))
        )
        render_tree(
            up, nameattr, left_child, right_child, next_indent, next_last, buffer
        )

    if last == "up":
        start_shape = "┌"
    elif last == "down":
        start_shape = "└"
    elif last == "updown":
        start_shape = " "
    else:
        start_shape = "├"

    if up is not None and down is not None:
        end_shape = "┤"
    elif up:
        end_shape = "┘"
    elif down:
        end_shape = "┐"
    else:
        end_shape = ""

    print(
        "{0}{1}{2}{3}".format(indent, start_shape, name(current_node), end_shape),
        file=buffer,
    )

    if down is not None:
        next_last = "down"
        next_indent = "{0}{1}{2}".format(
            indent, " " if "down" in last else "|", " " * len(str(name(current_node)))
        )
        render_tree(
            down, nameattr, left_child, right_child, next_indent, next_last, buffer
        )

    return f"\n{buffer.getvalue()}"
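Not part of the diff: a sketch rendering a small binary tree; `left`, `right`, and `name` match the function's default attribute names.

```python
# Sketch only (not in the diff): render a three-node binary tree.
from aredis_om.model.render_tree import render_tree


class Node:
    def __init__(self, name, left=None, right=None):
        self.name = name
        self.left = left
        self.right = right


root = Node("root", left=Node("a"), right=Node("b"))
print(render_tree(root))
# Produces box-drawing output along the lines of:
#      ┌a
#  root┤
#      └b
```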
aredis_om/model/token_escaper.py (new file, 25 lines)
@@ -0,0 +1,25 @@
import re
from typing import Optional, Pattern


class TokenEscaper:
    """
    Escape punctuation within an input string.
    """

    # Characters that RediSearch requires us to escape during queries.
    # Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"

    def __init__(self, escape_chars_re: Optional[Pattern] = None):
        if escape_chars_re:
            self.escaped_chars_re = escape_chars_re
        else:
            self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)

    def escape(self, value: str) -> str:
        def escape_symbol(match):
            value = match.group(0)
            return f"\\{value}"

        return self.escaped_chars_re.sub(escape_symbol, value)
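Not part of the diff: a sketch of escaping a value before interpolating it into a RediSearch query.

```python
# Sketch only (not in the diff): escape RediSearch token separators.
from aredis_om.model.token_escaper import TokenEscaper

escaper = TokenEscaper()
print(escaper.escape("user@example.com"))
# user\@example\.com
```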
aredis_om/unasync_util.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""Set of utility functions for unasync that transform into sync counterparts cleanly"""

import inspect


_original_next = next


def is_async_mode():
    """Tests if we're in the async part of the code or not"""

    async def f():
        """Unasync transforms async functions into sync functions"""
        return None

    obj = f()
    if obj is None:
        return False
    else:
        obj.close()  # prevent unawaited coroutine warning
        return True


ASYNC_MODE = is_async_mode()


async def anext(x):
    return await x.__anext__()


async def await_if_coro(x):
    if inspect.iscoroutine(x):
        return await x
    return x


next = _original_next


def return_non_coro(x):
    return x
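Not part of the diff: a sketch of how shared code can stay agnostic about sync vs. async. The pairing of `await_if_coro` with `return_non_coro` (and `anext` with `next`) as unasync token replacements is an assumption based on the names; the mapping itself would live in the build tooling.

```python
# Sketch only (not in the diff): code written once for both variants.
# In the async build this awaits the coroutine; presumably unasync swaps
# await_if_coro for return_non_coro in the generated sync build.
from aredis_om.unasync_util import ASYNC_MODE, await_if_coro


async def ping(redis):
    return await await_if_coro(redis.ping())


print("Running in async mode:", ASYNC_MODE)
```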