# redis-om-python/redis_developer/orm/model.py

import uuid
from dataclasses import dataclass
from enum import Enum
from typing import (
    AbstractSet,
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Protocol,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    get_origin,
    no_type_check,
)

import redis
from pydantic import BaseModel, validator
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.typing import NoArgAnyCallable
from pydantic.utils import Representation

from .encoders import jsonable_encoder
# Generic type variable; used by __dataclass_transform__ to type the
# decorator it returns as Callable[[_T], _T].
_T = TypeVar("_T")
class RedisModelError(Exception):
    """Raised when a model definition is invalid.

    Examples in this module: a model with no primary key, more than one
    primary key, or a hash model declaring an unsupported field type.
    """
class NotFoundError(Exception):
    """Raised by ``get()`` when nothing is stored under the requested key."""
class Operations(Enum):
    """Comparison operators that a query expression can carry."""

    EQ = 1  # produced by ``==``
    LT = 2  # produced by ``<``
    GT = 3  # produced by ``>``
@dataclass
class Expression:
    """One comparison captured from a query such as ``Model.field == 1``."""

    # Model field on the left-hand side of the comparison.
    field: ModelField
    # Which comparison was requested.
    op: Operations
    # Value the field is being compared against.
    right_value: Any
class PrimaryKeyCreator(Protocol):
    """Structural interface for objects able to generate primary keys."""

    def create_pk(self, *args, **kwargs) -> str:
        """Return a newly generated primary key."""
        ...
class Uuid4PrimaryKey:
    """Primary key creator that produces random (version 4) UUID strings."""

    def create_pk(self, *args, **kwargs) -> str:
        """Return a fresh UUID4 rendered as a string.

        Positional and keyword arguments are accepted and ignored so that the
        signature matches the ``PrimaryKeyCreator`` protocol; callers going
        through the protocol may pass arguments without a ``TypeError``.
        """
        return str(uuid.uuid4())
class ExpressionProxy:
    """Stand-in for a model field that turns comparisons into expressions.

    ``RedisModel.__init_subclass__`` installs one of these on the model class
    for every field, so ``Model.field == 1`` yields an ``Expression`` rather
    than a boolean.
    """

    def __init__(self, field: ModelField):
        self.field = field

    def _compare(self, op: Operations, other: Any) -> Expression:
        # Single place where the wrapped field is paired with an operator.
        return Expression(field=self.field, op=op, right_value=other)

    def __eq__(self, other: Any) -> Expression:
        return self._compare(Operations.EQ, other)

    def __lt__(self, other: Any) -> Expression:
        return self._compare(Operations.LT, other)

    def __gt__(self, other: Any) -> Expression:
        return self._compare(Operations.GT, other)
def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]:
    """Return a decorator that hands back its argument unchanged.

    None of the keyword arguments are read by the function body.
    NOTE(review): presumably they exist only as hints for static type
    checkers (dataclass-transform convention) -- confirm.
    """

    def _passthrough(a: _T) -> _T:
        return a

    return _passthrough
class FieldInfo(PydanticFieldInfo):
    """Pydantic field info extended with ORM-specific options.

    The extra options (``primary_key``, ``nullable``, ``foreign_key``,
    ``index``, ``unique``) are stored as attributes on the instance.
    """

    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
        # Remove the ORM-only options first so that the parent initialiser
        # only receives the keyword arguments that were left over.
        orm_options = {
            "primary_key": kwargs.pop("primary_key", False),
            "nullable": kwargs.pop("nullable", Undefined),
            "foreign_key": kwargs.pop("foreign_key", Undefined),
            "index": kwargs.pop("index", Undefined),
            "unique": kwargs.pop("unique", Undefined),
        }
        super().__init__(default=default, **kwargs)
        for option_name, option_value in orm_options.items():
            setattr(self, option_name, option_value)
class RelationshipInfo(Representation):
    """Holds the two options describing a relationship between models."""

    def __init__(
        self,
        *,
        back_populates: Optional[str] = None,
        link_model: Optional[Any] = None,
    ) -> None:
        # Both values are stored exactly as given; nothing is validated here.
        self.back_populates, self.link_model = back_populates, link_model
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: str = None,
    title: str = None,
    description: str = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: bool = None,
    gt: float = None,
    ge: float = None,
    lt: float = None,
    le: float = None,
    multiple_of: float = None,
    min_items: int = None,
    max_items: int = None,
    min_length: int = None,
    max_length: int = None,
    allow_mutation: bool = True,
    regex: str = None,
    primary_key: bool = False,
    unique: bool = False,
    foreign_key: Optional[Any] = None,
    nullable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
    """Build, validate and return a ``FieldInfo`` for a model attribute.

    Every named argument is forwarded to ``FieldInfo`` as a keyword argument.
    The entries of ``schema_extra`` (if any) are forwarded as additional
    keyword arguments, so a key that repeats one of the named arguments
    raises ``TypeError``. ``FieldInfo._validate()`` is called on the result
    before it is returned.
    """
    # Grouped only for readability; key order matches the original call.
    validation_options = dict(
        default_factory=default_factory,
        alias=alias,
        title=title,
        description=description,
        exclude=exclude,
        include=include,
        const=const,
        gt=gt,
        ge=ge,
        lt=lt,
        le=le,
        multiple_of=multiple_of,
        min_items=min_items,
        max_items=max_items,
        min_length=min_length,
        max_length=max_length,
        allow_mutation=allow_mutation,
        regex=regex,
    )
    orm_options = dict(
        primary_key=primary_key,
        unique=unique,
        foreign_key=foreign_key,
        nullable=nullable,
        index=index,
    )
    extra_options = schema_extra if schema_extra else {}

    info = FieldInfo(
        default,
        **validation_options,
        **orm_options,
        **extra_options,
    )
    info._validate()
    return info
@dataclass
class PrimaryKey:
    """Pairs a primary-key attribute name with its pydantic field."""

    # Attribute name of the primary-key field on the model.
    name: str
    # The pydantic field object for that attribute.
    field: ModelField
class DefaultMeta:
    """Default model settings; ``RedisModel`` uses this class as its ``Meta``.

    Every setting starts out as ``None``. ``RedisModel.__init_subclass__``
    fills in defaults for settings that are still unset when a model class is
    created.
    """

    # First segment of every key built by RedisModel.make_key().
    global_key_prefix: Optional[str] = None
    # Second segment of every key built by RedisModel.make_key().
    model_key_prefix: Optional[str] = None
    # Format string, formatted with ``pk=...``, for the last key segment.
    primary_key_pattern: Optional[str] = None
    # Redis client returned by RedisModel.db().
    database: Optional[redis.Redis] = None
    # Name and field of the model's primary key.
    primary_key: Optional[PrimaryKey] = None
    # Class instantiated to create a primary key when none is supplied.
    # The default is None, so the annotation must be Optional.
    primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
class RedisModel(BaseModel):
    """
    Base class for models stored in Redis.

    Subclasses implement ``save()`` and ``get()``. Per-model settings live on
    the nested ``Meta`` class; each subclass gets a ``Meta`` of its own so
    that defaults filled in for one model never leak into another.

    TODO: Convert expressions to Redis commands, execute
    TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern)
    TODO: Default key prefix is model name lowercase
    TODO: Build primary key pattern from PK field name, model prefix
    TODO: Default PK pattern is model name:pk field
    """
    pk: Optional[str] = Field(default=None, primary_key=True)

    class Config:
        orm_mode = True
        arbitrary_types_allowed = True
        extra = 'allow'

    Meta = DefaultMeta

    def __init_subclass__(cls, **kwargs):
        """Install query proxies and resolve ``Meta`` defaults for ``cls``."""
        super().__init_subclass__(**kwargs)

        # Writing defaults onto an inherited Meta would mutate the parent's
        # (ultimately the shared DefaultMeta) settings, making every model
        # reuse the first model's key prefix and primary key. Derive a
        # private Meta for models that do not declare their own.
        meta = cls.__dict__.get("Meta")
        if meta is None:
            meta = type("Meta", (cls.Meta,), {})
            cls.Meta = meta

        # Create proxies for each model field so that we can use the field
        # in queries, like Model.get(Model.field_name == 1)
        for name, field in cls.__fields__.items():
            setattr(cls, name, ExpressionProxy(field))
            # Check if this is our FieldInfo version with extended ORM metadata.
            info = field.field_info
            if isinstance(info, FieldInfo) and info.primary_key:
                meta.primary_key = PrimaryKey(name=name, field=field)

        # A user-declared Meta may omit settings entirely, so always look
        # them up with a default instead of risking AttributeError.
        # TODO: Raise exception here, global key prefix required?
        if not getattr(meta, 'global_key_prefix', None):
            meta.global_key_prefix = ""
        # The model prefix must come from this model's own Meta: an inherited
        # value would make parent and child models share the same keys.
        if not vars(meta).get('model_key_prefix'):
            meta.model_key_prefix = f"{cls.__name__.lower()}"
        if not getattr(meta, 'primary_key_pattern', None):
            meta.primary_key_pattern = "{pk}"
        if getattr(meta, 'database', None) is None:
            meta.database = redis.Redis(decode_responses=True)
        if getattr(meta, 'primary_key_creator_cls', None) is None:
            meta.primary_key_creator_cls = Uuid4PrimaryKey

    def __init__(__pydantic_self__, **data: Any) -> None:
        super().__init__(**data)
        __pydantic_self__.validate_primary_key()

    @validator("pk", always=True)
    def validate_pk(cls, v):
        """Generate a primary key when none (or an empty one) was supplied."""
        if not v:
            v = cls.Meta.primary_key_creator_cls().create_pk()
        return v

    @classmethod
    def validate_primary_key(cls):
        """Check for a primary key. We need one (and only one).

        Raises:
            RedisModelError: if the model declares zero or several
                primary-key fields.
        """
        primary_keys = sum(
            1
            for field in cls.__fields__.values()
            if getattr(field.field_info, 'primary_key', None)
        )
        if primary_keys == 0:
            raise RedisModelError("You must define a primary key for the model")
        if primary_keys > 1:
            raise RedisModelError("You must define only one primary key for a model")

    @classmethod
    def make_key(cls, part: str) -> str:
        """Return ``<global prefix>:<model prefix>:<part>``.

        Surrounding colons are stripped from both prefixes. An unset
        (``None``) prefix is treated as an empty string.
        """
        global_prefix = (getattr(cls.Meta, 'global_key_prefix', None) or '').strip(":")
        model_prefix = (getattr(cls.Meta, 'model_key_prefix', None) or '').strip(":")
        return f"{global_prefix}:{model_prefix}:{part}"

    @classmethod
    def make_primary_key(cls, pk: Any):
        """Return the Redis key for the given primary key value."""
        return cls.make_key(cls.Meta.primary_key_pattern.format(pk=pk))

    def key(self):
        """Return the Redis key for this model."""
        pk = getattr(self, self.Meta.primary_key.field.name)
        return self.make_primary_key(pk)

    @classmethod
    def db(cls):
        """Return the Redis client configured on ``Meta.database``."""
        return cls.Meta.database

    @classmethod
    def filter(cls, *expressions: Expression):
        """Placeholder: currently returns the model class unchanged."""
        return cls

    @classmethod
    def exclude(cls, *expressions: Expression):
        """Placeholder: currently returns the model class unchanged."""
        return cls

    @classmethod
    def add(cls, models: Sequence['RedisModel']) -> Sequence['RedisModel']:
        """Save every model in ``models`` and return the ``save()`` results."""
        return [model.save() for model in models]

    @classmethod
    def update(cls, **field_values):
        """Placeholder: currently returns the model class unchanged."""
        return cls

    @classmethod
    def values(cls):
        """Return raw values from Redis instead of model instances."""
        return cls

    def delete(self):
        """Delete this model's key and return the client's reply."""
        return self.db().delete(self.key())

    # TODO: Protocol
    @classmethod
    def get(cls, pk: Any):
        """Load the model stored under ``pk``; implemented by subclasses."""
        raise NotImplementedError

    def save(self, *args, **kwargs) -> 'RedisModel':
        """Persist this model; implemented by subclasses."""
        raise NotImplementedError
class HashModel(RedisModel):
    """Model stored as a Redis hash (``HSET`` / ``HGETALL``).

    Hashes are flat, so embedded models and set, list or mapping fields are
    rejected when the model class is created.
    """

    def __init_subclass__(cls, **kwargs):
        """Reject field types that cannot be stored in a flat hash.

        Raises:
            RedisModelError: if a field is an embedded ``RedisModel`` or a
                set, list or mapping type.
        """
        super().__init_subclass__(**kwargs)

        for name, field in cls.__fields__.items():
            outer_type = field.outer_type_
            # Parameterised generics such as List[str] are not classes, and
            # issubclass() raises TypeError on them. Compare their runtime
            # origin (list, dict, ...) instead, and skip anything that still
            # is not a class (e.g. a Union).
            candidate = get_origin(outer_type) or outer_type
            if not isinstance(candidate, type):
                continue

            if issubclass(candidate, RedisModel):
                raise RedisModelError(f"HashModels cannot have embedded model "
                                      f"fields. Field: {name}")

            for typ in (Set, Mapping, List):
                if issubclass(candidate, typ):
                    raise RedisModelError(f"HashModels cannot have set, list,"
                                          f" or mapping fields. Field: {name}")

    def save(self, *args, **kwargs) -> 'HashModel':
        """Write this model to its hash key and return the model itself.

        The reply of ``HSET`` is not returned: it counts newly added fields,
        so it is 0 whenever an existing hash is overwritten, and the declared
        return type is the model.
        """
        document = jsonable_encoder(self.dict())
        self.db().hset(self.key(), mapping=document)
        return self

    @classmethod
    def get(cls, pk: Any) -> 'HashModel':
        """Load the model stored under ``pk``.

        Raises:
            NotFoundError: if the hash is missing or empty.
        """
        document = cls.db().hgetall(cls.make_primary_key(pk))
        if not document:
            raise NotFoundError
        return cls.parse_obj(document)

    @classmethod
    @no_type_check
    def _get_value(cls, *args, **kwargs) -> Any:
        """
        Always send None as an empty string.

        TODO: We do this because redis-py's hset() method requires non-null
        values. Is there a better way?
        """
        val = super()._get_value(*args, **kwargs)
        if val is None:
            return ""
        return val
class JsonModel(RedisModel):
    """Model stored as a JSON document via ``JSON.SET`` / ``JSON.GET``."""

    def save(self, *args, **kwargs) -> 'JsonModel':
        """Write this model as JSON at its key and return the model itself.

        The raw ``JSON.SET`` reply is not returned; the declared return type
        is the model.
        """
        self.db().execute_command('JSON.SET', self.key(), ".", self.json())
        return self

    @classmethod
    def get(cls, pk: Any) -> 'JsonModel':
        """Load the model stored under ``pk``.

        Raises:
            NotFoundError: if ``JSON.GET`` returns nothing for the key.
        """
        document = cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
        if not document:
            raise NotFoundError
        return cls.parse_raw(document)