361 lines
10 KiB
Python
361 lines
10 KiB
Python
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import (
|
|
AbstractSet,
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Mapping,
|
|
Optional,
|
|
Set,
|
|
Tuple,
|
|
TypeVar,
|
|
Union,
|
|
Sequence,
|
|
no_type_check,
|
|
Protocol,
|
|
List, Type
|
|
)
|
|
import uuid
|
|
|
|
import redis
|
|
from pydantic import BaseModel, validator
|
|
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
|
from pydantic.fields import ModelField, Undefined, UndefinedType
|
|
from pydantic.typing import NoArgAnyCallable
|
|
from pydantic.utils import Representation
|
|
|
|
from .encoders import jsonable_encoder
|
|
|
|
_T = TypeVar("_T")
|
|
|
|
|
|
class RedisModelError(Exception):
    """Signals an invalid model definition.

    Raised in this module when a model has no primary key or more than one,
    and when a hash model declares a field type it cannot store.
    """
|
|
|
|
|
|
class NotFoundError(Exception):
    """Signals that no document is stored under the requested primary key."""
|
|
|
|
|
|
class Operations(Enum):
    """Comparison operators a query expression can carry."""

    EQ = 1  # built by ``==`` on an ExpressionProxy
    LT = 2  # built by ``<`` on an ExpressionProxy
    GT = 3  # built by ``>`` on an ExpressionProxy
|
|
|
|
|
|
@dataclass
class Expression:
    """One comparison between a model field and a value."""

    # Pydantic field that appears on the left-hand side of the comparison.
    field: ModelField
    # Which comparison operator was used.
    op: Operations
    # Value the field is compared against.
    right_value: Any
|
|
|
|
|
|
class PrimaryKeyCreator(Protocol):
    """Structural interface for objects able to produce primary keys."""

    def create_pk(self, *args, **kwargs) -> str:
        """Return a newly created primary key."""
        ...
|
|
|
|
|
|
class Uuid4PrimaryKey:
    """Primary key creator that produces random UUID4 strings."""

    def create_pk(self, *args, **kwargs) -> str:
        """Return a new random (version 4) UUID in its string form.

        Positional and keyword arguments are accepted only so that the
        signature matches the ``PrimaryKeyCreator`` protocol; they are ignored.
        """
        return str(uuid.uuid4())
|
|
|
|
|
|
class ExpressionProxy:
    """Stand-in for a model field that turns comparisons into expressions.

    Comparing a proxy with ``==``, ``<`` or ``>`` does not produce a boolean;
    it produces an ``Expression`` recording the field, the operator and the
    right-hand operand.
    """

    def __init__(self, field: ModelField):
        self.field = field

    def _compare(self, op: Operations, other: Any) -> Expression:
        """Build the expression ``<self.field> <op> <other>``."""
        return Expression(field=self.field, op=op, right_value=other)

    def __eq__(self, other: Any) -> Expression:
        return self._compare(Operations.EQ, other)

    def __lt__(self, other: Any) -> Expression:
        return self._compare(Operations.LT, other)

    def __gt__(self, other: Any) -> Expression:
        return self._compare(Operations.GT, other)
|
|
|
|
|
|
def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (),
) -> Callable[[_T], _T]:
    """Return a decorator that hands its argument back unchanged.

    None of the keyword options are read at runtime.
    NOTE(review): presumably this exists only so static type checkers that
    recognise the dataclass-transform convention treat decorated classes like
    dataclasses -- confirm against the type checker in use.
    """

    def _identity(a):
        return a

    return _identity
|
|
|
|
|
|
class FieldInfo(PydanticFieldInfo):
    """Pydantic ``FieldInfo`` that additionally carries ORM-style metadata.

    The extra attributes are ``primary_key``, ``nullable``, ``foreign_key``,
    ``index`` and ``unique``.
    """

    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
        # Remove the ORM-only options from kwargs first so that the parent
        # initializer never sees them.
        orm_options = {
            "primary_key": kwargs.pop("primary_key", False),
            "nullable": kwargs.pop("nullable", Undefined),
            "foreign_key": kwargs.pop("foreign_key", Undefined),
            "index": kwargs.pop("index", Undefined),
            "unique": kwargs.pop("unique", Undefined),
        }
        super().__init__(default=default, **kwargs)
        for attribute, value in orm_options.items():
            setattr(self, attribute, value)
|
|
|
|
|
|
class RelationshipInfo(Representation):
    """Holds the options given when declaring a relationship.

    The values are only stored here; nothing in this class interprets them.
    """

    def __init__(
        self,
        *,
        back_populates: Optional[str] = None,
        link_model: Optional[Any] = None,
    ) -> None:
        self.back_populates, self.link_model = back_populates, link_model
|
|
|
|
|
|
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: str = None,
    title: str = None,
    description: str = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: bool = None,
    gt: float = None,
    ge: float = None,
    lt: float = None,
    le: float = None,
    multiple_of: float = None,
    min_items: int = None,
    max_items: int = None,
    min_length: int = None,
    max_length: int = None,
    allow_mutation: bool = True,
    regex: str = None,
    primary_key: bool = False,
    unique: bool = False,
    foreign_key: Optional[Any] = None,
    nullable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
    """Declare a model field, returning a validated ``FieldInfo``.

    All named arguments are forwarded unchanged to ``FieldInfo``.
    ``primary_key``, ``unique``, ``foreign_key``, ``nullable`` and ``index``
    end up as attributes on the returned object. The entries of
    ``schema_extra`` are passed along as additional keyword arguments; a key
    that repeats one of the named arguments makes the call fail with
    ``TypeError``.
    """
    extra_kwargs = schema_extra if schema_extra else {}
    info = FieldInfo(
        default,
        default_factory=default_factory,
        alias=alias,
        title=title,
        description=description,
        exclude=exclude,
        include=include,
        const=const,
        gt=gt,
        ge=ge,
        lt=lt,
        le=le,
        multiple_of=multiple_of,
        min_items=min_items,
        max_items=max_items,
        min_length=min_length,
        max_length=max_length,
        allow_mutation=allow_mutation,
        regex=regex,
        primary_key=primary_key,
        unique=unique,
        foreign_key=foreign_key,
        nullable=nullable,
        index=index,
        **extra_kwargs,
    )
    # Let the FieldInfo check its own argument combination before it is used.
    info._validate()
    return info
|
|
|
|
|
|
@dataclass
class PrimaryKey:
    """Identifies the field of a model that acts as its primary key."""

    # Attribute name of the primary key field on the model.
    name: str
    # The pydantic field object behind that attribute.
    field: ModelField
|
|
|
|
|
|
class DefaultMeta:
    """Fallback model settings used when a model declares no ``Meta``.

    Every setting starts out as ``None``. ``RedisModel.__init_subclass__``
    replaces unset values when a model class is created.
    """

    # Prefix placed in front of every key; set to "" when left unset.
    global_key_prefix: Optional[str] = None
    # Per-model key segment; set to the lower-cased class name when unset.
    model_key_prefix: Optional[str] = None
    # ``str.format`` template receiving ``pk``; set to "{pk}" when unset.
    primary_key_pattern: Optional[str] = None
    # Redis client used by the model; a default client is created when unset.
    database: Optional[redis.Redis] = None
    # Filled in from the model field declared with ``primary_key=True``.
    primary_key: Optional[PrimaryKey] = None
    # Class whose instances create new primary keys; Uuid4PrimaryKey when
    # unset. Optional because the declared default is None.
    primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
|
|
|
|
|
|
class RedisModel(BaseModel):
    """
    Base pydantic model for objects stored in Redis under a generated key.

    Each subclass gets an ``ExpressionProxy`` class attribute per field (for
    building query expressions) and a ``Meta`` object holding its key prefixes,
    primary key description and database connection. Persistence itself
    (``get``/``save``) is left to concrete subclasses.

    TODO: Convert expressions to Redis commands, execute
    TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern)
    TODO: Default key prefix is model name lowercase
    TODO: Build primary key pattern from PK field name, model prefix
    TODO: Default PK pattern is model name:pk field
    """
    pk: Optional[str] = Field(default=None, primary_key=True)

    class Config:
        orm_mode = True
        arbitrary_types_allowed = True
        extra = 'allow'

    Meta = DefaultMeta

    def __init_subclass__(cls, **kwargs):
        """Attach field proxies and resolve the ``Meta`` settings of ``cls``."""
        super().__init_subclass__(**kwargs)

        # A subclass that does not declare its own ``Meta`` must not write
        # into the one it inherits: that object is shared with its parent and
        # siblings, so the first model created would fix the key prefix and
        # primary key for all of them. Derive a private Meta instead; settings
        # such as the database are still inherited through normal attribute
        # lookup.
        if "Meta" not in vars(cls):
            inherited_meta = cls.Meta
            overrides = {"primary_key": None}
            # Only drop a prefix that was generated from the parent's class
            # name; an explicitly configured prefix keeps being inherited.
            if getattr(inherited_meta, "_model_key_prefix_is_default", False):
                overrides["model_key_prefix"] = None
            cls.Meta = type("Meta", (inherited_meta,), overrides)
        meta = cls.Meta

        # Create proxies for each model field so that we can use the field
        # in queries, like Model.get(Model.field_name == 1)
        for name, field in cls.__fields__.items():
            setattr(cls, name, ExpressionProxy(field))
            # Only our FieldInfo version carries the extended ORM metadata.
            info = field.field_info
            if isinstance(info, FieldInfo) and info.primary_key:
                meta.primary_key = PrimaryKey(name=name, field=field)

        # getattr() defaults keep a partially specified user Meta from raising
        # AttributeError for the settings it leaves out.
        # TODO: Raise exception here, global key prefix required?
        if not getattr(meta, 'global_key_prefix', None):
            meta.global_key_prefix = ""
        if not getattr(meta, 'model_key_prefix', None):
            meta.model_key_prefix = cls.__name__.lower()
            meta._model_key_prefix_is_default = True
        else:
            meta._model_key_prefix_is_default = False
        if not getattr(meta, 'primary_key_pattern', None):
            meta.primary_key_pattern = "{pk}"
        if not getattr(meta, 'database', None):
            meta.database = redis.Redis(decode_responses=True)
        if not getattr(meta, 'primary_key_creator_cls', None):
            meta.primary_key_creator_cls = Uuid4PrimaryKey

    def __init__(__pydantic_self__, **data: Any) -> None:
        """Validate ``data`` into the model, then check its primary key setup.

        Raises:
            RedisModelError: if the model does not have exactly one primary key.
        """
        super().__init__(**data)
        __pydantic_self__.validate_primary_key()

    @validator("pk", always=True)
    def validate_pk(cls, v):
        """Generate a primary key when none (or an empty one) was supplied."""
        if not v:
            v = cls.Meta.primary_key_creator_cls().create_pk()
        return v

    @classmethod
    def validate_primary_key(cls):
        """Check for a primary key. We need one (and only one).

        Raises:
            RedisModelError: if no field, or more than one field, is marked as
                the primary key.
        """
        primary_keys = sum(
            1
            for field in cls.__fields__.values()
            if getattr(field.field_info, 'primary_key', None)
        )
        if primary_keys == 0:
            raise RedisModelError("You must define a primary key for the model")
        if primary_keys > 1:
            raise RedisModelError("You must define only one primary key for a model")

    @classmethod
    def make_key(cls, part: str):
        """Return ``"<global prefix>:<model prefix>:<part>"`` for this model.

        Leading and trailing colons are stripped from both prefixes. A prefix
        that is missing or still ``None`` is treated as empty.
        """
        global_prefix = (getattr(cls.Meta, 'global_key_prefix', None) or '').strip(":")
        model_prefix = (getattr(cls.Meta, 'model_key_prefix', None) or '').strip(":")
        return f"{global_prefix}:{model_prefix}:{part}"

    @classmethod
    def make_primary_key(cls, pk: Any):
        """Return the Redis key for the model instance identified by ``pk``."""
        return cls.make_key(cls.Meta.primary_key_pattern.format(pk=pk))

    def key(self):
        """Return the Redis key for this model."""
        pk = getattr(self, self.Meta.primary_key.field.name)
        return self.make_primary_key(pk)

    @classmethod
    def db(cls):
        """Return the database object configured on ``Meta``."""
        return cls.Meta.database

    @classmethod
    def filter(cls, *expressions: Expression):
        """Placeholder: ignores ``expressions`` and returns the class."""
        return cls

    @classmethod
    def exclude(cls, *expressions: Expression):
        """Placeholder: ignores ``expressions`` and returns the class."""
        return cls

    @classmethod
    def add(cls, models: Sequence['RedisModel']) -> Sequence['RedisModel']:
        """Save every model in ``models`` and return the ``save()`` results."""
        return [model.save() for model in models]

    @classmethod
    def update(cls, **field_values):
        """Placeholder: ignores ``field_values`` and returns the class."""
        return cls

    @classmethod
    def values(cls):
        """Return raw values from Redis instead of model instances.

        Placeholder: currently just returns the class.
        """
        return cls

    def delete(self):
        """Delete this model's key and return the database's reply."""
        return self.db().delete(self.key())

    # TODO: Protocol
    @classmethod
    def get(cls, pk: Any):
        """Load the model stored under ``pk``; implemented by subclasses."""
        raise NotImplementedError

    def save(self, *args, **kwargs) -> 'RedisModel':
        """Persist this model; implemented by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
class HashModel(RedisModel):
    """Model persisted as a Redis hash via ``hset`` / ``hgetall``."""

    def __init_subclass__(cls, **kwargs):
        """Reject field types that this model refuses to store.

        Raises:
            RedisModelError: if a field is an embedded ``RedisModel`` or a
                set, list or mapping.
        """
        super().__init_subclass__(**kwargs)

        for name, field in cls.__fields__.items():
            outer_type = field.outer_type_
            # Parametrized generics such as List[str] are not classes, and
            # issubclass() raises TypeError for them. Check the runtime class
            # they wrap instead, e.g. ``list`` for List[str].
            candidate = getattr(outer_type, "__origin__", None) or outer_type
            if not isinstance(candidate, type):
                # Not a class at all (e.g. a Union), so nothing to reject.
                continue

            if issubclass(candidate, RedisModel):
                raise RedisModelError(f"HashModels cannot have embedded model "
                                      f"fields. Field: {name}")

            for typ in (Set, Mapping, List):
                if issubclass(candidate, typ):
                    raise RedisModelError(f"HashModels cannot have set, list,"
                                          f" or mapping fields. Field: {name}")

    def save(self, *args, **kwargs) -> 'HashModel':
        """Write every field of this model to its hash and return the model.

        The model is returned, as the annotation declares. The ``hset`` reply
        is not, because it counts only newly created hash fields and is 0 when
        an existing hash is overwritten.
        """
        document = jsonable_encoder(self.dict())
        self.db().hset(self.key(), mapping=document)
        return self

    @classmethod
    def get(cls, pk: Any) -> 'HashModel':
        """Load the model stored under ``pk``.

        Raises:
            NotFoundError: if the hash for ``pk`` is empty or missing.
        """
        document = cls.db().hgetall(cls.make_primary_key(pk))
        if not document:
            raise NotFoundError
        return cls.parse_obj(document)

    @classmethod
    @no_type_check
    def _get_value(cls, *args, **kwargs) -> Any:
        """
        Always send None as an empty string.

        TODO: We do this because redis-py's hset() method requires non-null
        values. Is there a better way?
        """
        val = super()._get_value(*args, **kwargs)
        return "" if val is None else val
|
|
|
|
|
|
class JsonModel(RedisModel):
    """Model persisted as a JSON document via ``JSON.SET`` / ``JSON.GET``."""

    def save(self, *args, **kwargs) -> 'JsonModel':
        """Store this model's JSON at the document root and return the model.

        The model is returned, as the annotation declares, rather than the raw
        command reply. This matches what ``RedisModel.add`` says it returns.
        """
        self.db().execute_command('JSON.SET', self.key(), ".", self.json())
        return self

    @classmethod
    def get(cls, pk: Any) -> 'JsonModel':
        """Load the model stored under ``pk``.

        Raises:
            NotFoundError: if ``JSON.GET`` returns nothing for the key.
        """
        document = cls.db().execute_command("JSON.GET", cls.make_primary_key(pk))
        if not document:
            raise NotFoundError
        return cls.parse_raw(document)
|