Rename to redis_om

This commit is contained in:
Andrew Brookins 2021-10-22 06:36:15 -07:00
parent 9b18dae2eb
commit c9967b0d40
19 changed files with 29 additions and 29 deletions

0
redis_om/__init__.py Normal file
View file

22
redis_om/connections.py Normal file
View file

@@ -0,0 +1,22 @@
import os
import dotenv
import redis
dotenv.load_dotenv()
URL = os.environ.get("REDIS_OM_URL", None)
def get_redis_connection(**kwargs) -> redis.Redis:
# If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
# environment variable, we'll create the Redis client from the URL.
url = kwargs.pop("url", URL)
if url:
return redis.from_url(url, **kwargs)
# Decode from UTF-8 by default
if "decode_responses" not in kwargs:
kwargs["decode_responses"] = True
return redis.Redis(**kwargs)

View file

@@ -0,0 +1 @@
from .model import EmbeddedJsonModel, Field, HashModel, JsonModel, RedisModel

View file

View file

@@ -0,0 +1,17 @@
import click
from redis_om.model.migrations.migrator import Migrator
@click.command()
@click.option("--module", default="redis_om")
def migrate(module):
migrator = Migrator(module)
if migrator.migrations:
print("Pending migrations:")
for migration in migrator.migrations:
print(migration)
if input("Run migrations? (y/n) ") == "y":
migrator.run()

180
redis_om/model/encoders.py Normal file
View file

@@ -0,0 +1,180 @@
"""
This file adapted from FastAPI's encoders.
Licensed under the MIT License (MIT).
Copyright (c) 2018 Sebastián Ramírez
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import dataclasses
from collections import defaultdict
from enum import Enum
from pathlib import PurePath
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from pydantic import BaseModel
from pydantic.json import ENCODERS_BY_TYPE
SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]
def generate_encoders_by_class_tuples(
type_encoder_map: Dict[Any, Callable[[Any], Any]]
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
tuple
)
for type_, encoder in type_encoder_map.items():
encoders_by_class_tuples[encoder] += (type_,)
return encoders_by_class_tuples
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
def jsonable_encoder(
obj: Any,
include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
sqlalchemy_safe: bool = True,
) -> Any:
if include is not None and not isinstance(include, (set, dict)):
include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)):
exclude = set(exclude)
if isinstance(obj, BaseModel):
encoder = getattr(obj.__config__, "json_encoders", {})
if custom_encoder:
encoder.update(custom_encoder)
obj_dict = obj.dict(
include=include, # type: ignore # in Pydantic
exclude=exclude, # type: ignore # in Pydantic
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
custom_encoder=encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, dict):
encoded_dict = {}
for key, value in obj.items():
if (
(
not sqlalchemy_safe
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and ((include and key in include) or not exclude or key not in exclude)
):
encoded_key = jsonable_encoder(
key,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
encoded_list = []
for item in obj:
encoded_list.append(
jsonable_encoder(
item,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
if custom_encoder:
if type(obj) in custom_encoder:
return custom_encoder[type(obj)](obj)
else:
for encoder_type, encoder in custom_encoder.items():
if isinstance(obj, encoder_type):
return encoder(obj)
if type(obj) in ENCODERS_BY_TYPE:
return ENCODERS_BY_TYPE[type(obj)](obj)
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
errors: List[Exception] = []
try:
data = dict(obj)
except Exception as e:
errors.append(e)
try:
data = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors)
return jsonable_encoder(
data,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)

View file

View file

@@ -0,0 +1,150 @@
import hashlib
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from redis import ResponseError
from redis_om.connections import get_redis_connection
from redis_om.model.model import model_registry
redis = get_redis_connection()
log = logging.getLogger(__name__)
import importlib # noqa: E402
import pkgutil # noqa: E402
class MigrationError(Exception):
pass
def import_submodules(root_module_name: str):
"""Import all submodules of a module, recursively."""
# TODO: Call this without specifying a module name, to import everything?
root_module = importlib.import_module(root_module_name)
if not hasattr(root_module, "__path__"):
raise MigrationError(
"The root module must be a Python package. "
f"You specified: {root_module_name}"
)
for loader, module_name, is_pkg in pkgutil.walk_packages(
root_module.__path__, root_module.__name__ + "." # type: ignore
):
importlib.import_module(module_name)
def schema_hash_key(index_name):
return f"{index_name}:hash"
def create_index(index_name, schema, current_hash):
try:
redis.execute_command(f"ft.info {index_name}")
except ResponseError:
redis.execute_command(f"ft.create {index_name} {schema}")
redis.set(schema_hash_key(index_name), current_hash)
else:
log.info("Index already exists, skipping. Index hash: %s", index_name)
class MigrationAction(Enum):
CREATE = 2
DROP = 1
@dataclass
class IndexMigration:
model_name: str
index_name: str
schema: str
hash: str
action: MigrationAction
previous_hash: Optional[str] = None
def run(self):
if self.action is MigrationAction.CREATE:
self.create()
elif self.action is MigrationAction.DROP:
self.drop()
def create(self):
try:
return create_index(self.index_name, self.schema, self.hash)
except ResponseError:
log.info("Index already exists: %s", self.index_name)
def drop(self):
try:
redis.execute_command(f"FT.DROPINDEX {self.index_name}")
except ResponseError:
log.info("Index does not exist: %s", self.index_name)
class Migrator:
def __init__(self, module=None):
# Try to load any modules found under the given path or module name.
if module:
import_submodules(module)
self.migrations = []
for name, cls in model_registry.items():
hash_key = schema_hash_key(cls.Meta.index_name)
try:
schema = cls.redisearch_schema()
except NotImplementedError:
log.info("Skipping migrations for %s", name)
continue
current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest() # nosec
try:
redis.execute_command("ft.info", cls.Meta.index_name)
except ResponseError:
self.migrations.append(
IndexMigration(
name,
cls.Meta.index_name,
schema,
current_hash,
MigrationAction.CREATE,
)
)
continue
stored_hash = redis.get(hash_key)
schema_out_of_date = current_hash != stored_hash
if schema_out_of_date:
# TODO: Switch out schema with an alias to avoid downtime -- separate migration?
self.migrations.append(
IndexMigration(
name,
cls.Meta.index_name,
schema,
current_hash,
MigrationAction.DROP,
stored_hash,
)
)
self.migrations.append(
IndexMigration(
name,
cls.Meta.index_name,
schema,
current_hash,
MigrationAction.CREATE,
stored_hash,
)
)
def run(self):
# TODO: Migration history
# TODO: Dry run with output
for migration in self.migrations:
migration.run()

1502
redis_om/model/model.py Normal file

File diff suppressed because it is too large Load diff

31
redis_om/model/models.py Normal file
View file

@@ -0,0 +1,31 @@
import abc
from typing import Optional
from redis_om.model.model import HashModel, JsonModel
class BaseJsonModel(JsonModel, abc.ABC):
class Meta:
global_key_prefix = "redis-om"
class BaseHashModel(HashModel, abc.ABC):
class Meta:
global_key_prefix = "redis-om"
# class AddressJson(BaseJsonModel):
# address_line_1: str
# address_line_2: Optional[str]
# city: str
# country: str
# postal_code: str
#
class AddressHash(BaseHashModel):
address_line_1: str
address_line_2: Optional[str]
city: str
country: str
postal_code: str

View file

@@ -0,0 +1,105 @@
from collections.abc import Sequence
from typing import Any, Dict, List, Mapping, Union

from redis_om.model.model import Expression
class LogicalOperatorForListOfExpressions(Expression):
operator: str = ""
def __init__(self, *expressions: Expression):
self.expressions = list(expressions)
@property
def query(self) -> Mapping[str, List[Expression]]:
if not self.expressions:
raise AttributeError("At least one expression must be provided")
# TODO: This needs to return a RediSearch string.
# Use the values in each expression object to build the string.
# Determine the type of query based on the field (numeric range,
# tag field, etc.).
return {self.operator: self.expressions}
class Or(LogicalOperatorForListOfExpressions):
"""
Logical OR query operator
Example:
```python
class Product(JsonModel):
price: float
category: str
Or(Product.price < 10, Product.category == "Sweets")
```
Will return RediSearch query string like:
```
(@price:[-inf 10]) | (@category:{Sweets})
```
"""
operator = "|"
class And(LogicalOperatorForListOfExpressions):
"""
Logical AND query operator
Example:
```python
class Product(Document):
price: float
category: str
And(Product.price < 10, Product.category == "Sweets")
```
Will return a query string like:
```
(@price:[-inf 10]) (@category:{Sweets})
```
Note that in RediSearch, AND is implied with multiple terms.
"""
operator = " "
class Not(LogicalOperatorForListOfExpressions):
"""
Logical NOT query operator
Example:
```python
class Product(Document):
price: float
category: str
Not(Product.price<10, Product.category=="Sweets")
```
Will return a query string like:
```
-(@price:[-inf 10]) -(@category:{Sweets})
```
"""
@property
def query(self):
return "-(expression1) -(expression2)"
class QueryResolver:
def __init__(self, *expressions: Expression):
self.expressions = expressions
def resolve(self) -> str:
"""Resolve expressions to a RediSearch query string."""

View file

@@ -0,0 +1,75 @@
"""
This code adapted from the library "pptree," Copyright (c) 2017 Clément Michard
and released under the MIT license: https://github.com/clemtoy/pptree
"""
import io
def render_tree(
current_node,
nameattr="name",
left_child="left",
right_child="right",
indent="",
last="updown",
buffer=None,
):
"""Print a tree-like structure, `current_node`.
This is a mostly-direct-copy of the print_tree() function from the ppbtree
module of the pptree library, but instead of printing to standard out, we
write to a StringIO buffer, then use that buffer to accumulate written lines
during recursive calls to render_tree().
"""
if buffer is None:
buffer = io.StringIO()
if hasattr(current_node, nameattr):
name = lambda node: getattr(node, nameattr) # noqa: E731
else:
name = lambda node: str(node) # noqa: E731
up = getattr(current_node, left_child, None)
down = getattr(current_node, right_child, None)
if up is not None:
next_last = "up"
next_indent = "{0}{1}{2}".format(
indent, " " if "up" in last else "|", " " * len(str(name(current_node)))
)
render_tree(
up, nameattr, left_child, right_child, next_indent, next_last, buffer
)
if last == "up":
start_shape = ""
elif last == "down":
start_shape = ""
elif last == "updown":
start_shape = " "
else:
start_shape = ""
if up is not None and down is not None:
end_shape = ""
elif up:
end_shape = ""
elif down:
end_shape = ""
else:
end_shape = ""
print(
"{0}{1}{2}{3}".format(indent, start_shape, name(current_node), end_shape),
file=buffer,
)
if down is not None:
next_last = "down"
next_indent = "{0}{1}{2}".format(
indent, " " if "down" in last else "|", " " * len(str(name(current_node)))
)
render_tree(
down, nameattr, left_child, right_child, next_indent, next_last, buffer
)
return f"\n{buffer.getvalue()}"

View file

@@ -0,0 +1,25 @@
import re
from typing import Optional, Pattern
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://oss.redis.com/redisearch/Escaping/#the_rules_of_text_field_tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, value: str) -> str:
def escape_symbol(match):
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, value)