redis-om-python/redis_developer/orm/model.py

391 lines
12 KiB
Python
Raw Normal View History

import datetime
from dataclasses import dataclass
from enum import Enum
from typing import (
AbstractSet,
Any,
Callable,
Dict,
Mapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
Sequence, ClassVar, TYPE_CHECKING, no_type_check,
)
import redis
from pydantic import BaseModel
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic.fields import ModelField, Undefined, UndefinedType
from pydantic.main import BaseConfig, ModelMetaclass, validate_model
from pydantic.typing import NoArgAnyCallable, resolve_annotations
from pydantic.utils import Representation
from .encoders import jsonable_encoder
from .util import uuid_from_time
# Type variable used in the identity-decorator signature of __dataclass_transform__.
_T = TypeVar("_T")
class RedisModelError(Exception):
    """Raised when a model is misconfigured, e.g. it defines no primary key or more than one."""
class NotFoundError(Exception):
    """Raised when a requested model instance does not exist in Redis."""
class Operations(Enum):
    """Comparison operators that an ``Expression`` can carry."""

    EQ = 1  # equality, produced by ``proxy == value``
    LT = 2  # less-than, produced by ``proxy < value``
    GT = 3  # greater-than, produced by ``proxy > value``
@dataclass()
class Expression:
    """One ``field <op> value`` comparison, as built by ``ExpressionProxy``."""

    field: ModelField  # model field on the left-hand side of the comparison
    op: Operations  # which comparison is being made
    right_value: Any  # value the field is compared against
class ExpressionProxy:
    """Stand-in for a model field that turns comparisons into ``Expression`` objects.

    ``proxy == x``, ``proxy < x`` and ``proxy > x`` each return an
    ``Expression`` recording this proxy's field, the operator and ``x``,
    instead of a boolean. Because ``__eq__`` is overridden without
    ``__hash__``, instances are unhashable.
    """

    def __init__(self, field: ModelField):
        self.field = field

    def _build(self, operator: Operations, operand: Any) -> Expression:
        """Return an ``Expression`` comparing this proxy's field to ``operand``."""
        return Expression(field=self.field, op=operator, right_value=operand)

    def __eq__(self, other: Any) -> Expression:
        return self._build(Operations.EQ, other)

    def __lt__(self, other: Any) -> Expression:
        return self._build(Operations.LT, other)

    def __gt__(self, other: Any) -> Expression:
        return self._build(Operations.GT, other)
def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]:
    """Return a decorator that hands back whatever it decorates, unchanged.

    All keyword arguments are ignored at runtime. They exist only at the
    decoration site, presumably for static type checkers that recognise
    this dataclass-transform convention.
    """

    def identity(decorated: _T) -> _T:
        return decorated

    return identity
class FieldInfo(PydanticFieldInfo):
    """Pydantic ``FieldInfo`` extended with ORM metadata.

    The ORM-only options ``primary_key``, ``nullable``, ``foreign_key``,
    ``index`` and ``unique`` are taken out of the keyword arguments before the
    Pydantic initializer runs. They are then stored as plain attributes on
    the instance. An option that is not supplied falls back to ``False``
    (``primary_key``) or ``Undefined`` (all others).
    """

    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
        # Extract the ORM options first so they are not forwarded to Pydantic.
        orm_options = {
            option: kwargs.pop(option, fallback)
            for option, fallback in (
                ("primary_key", False),
                ("nullable", Undefined),
                ("foreign_key", Undefined),
                ("index", Undefined),
                ("unique", Undefined),
            )
        }
        super().__init__(default=default, **kwargs)
        for option, value in orm_options.items():
            setattr(self, option, value)
class RelationshipInfo(Representation):
    """Options for a relationship attribute, as created by ``Relationship()``.

    ``RedisModelMetaclass`` recognises class attributes holding this object
    and keeps them out of the Pydantic field set.
    """

    def __init__(
        self,
        *,
        back_populates: Optional[str] = None,
        link_model: Optional[Any] = None,
    ) -> None:
        self.back_populates, self.link_model = back_populates, link_model
def Field(
    default: Any = Undefined,
    *,
    default_factory: Optional[NoArgAnyCallable] = None,
    alias: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    exclude: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    include: Union[
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
    ] = None,
    const: Optional[bool] = None,
    gt: Optional[float] = None,
    ge: Optional[float] = None,
    lt: Optional[float] = None,
    le: Optional[float] = None,
    multiple_of: Optional[float] = None,
    min_items: Optional[int] = None,
    max_items: Optional[int] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    allow_mutation: bool = True,
    regex: Optional[str] = None,
    primary_key: bool = False,
    unique: bool = False,
    foreign_key: Optional[Any] = None,
    nullable: Union[bool, UndefinedType] = Undefined,
    index: Union[bool, UndefinedType] = Undefined,
    schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
    """Declare a model field with Pydantic validation options plus ORM metadata.

    Every Pydantic option is passed through unchanged to this module's
    ``FieldInfo``. That class also records ``primary_key``, ``unique``,
    ``foreign_key``, ``nullable`` and ``index``.

    Args:
        default: Default value; ``Undefined`` means the field is required
            unless ``default_factory`` is given.
        schema_extra: Extra keyword arguments forwarded verbatim to
            ``FieldInfo``. A key that duplicates one of the named parameters
            raises ``TypeError`` (duplicate keyword argument).

    Returns:
        The constructed ``FieldInfo``, after its ``_validate()`` check.
    """
    field_info = FieldInfo(
        default,
        default_factory=default_factory,
        alias=alias,
        title=title,
        description=description,
        exclude=exclude,
        include=include,
        const=const,
        gt=gt,
        ge=ge,
        lt=lt,
        le=le,
        multiple_of=multiple_of,
        min_items=min_items,
        max_items=max_items,
        min_length=min_length,
        max_length=max_length,
        allow_mutation=allow_mutation,
        regex=regex,
        primary_key=primary_key,
        unique=unique,
        foreign_key=foreign_key,
        nullable=nullable,
        index=index,
        **(schema_extra or {}),
    )
    # Pydantic's own consistency check on the assembled options.
    field_info._validate()
    return field_info
def Relationship(
    *,
    back_populates: Optional[str] = None,
    link_model: Optional[Any] = None
) -> Any:
    """Declare a relationship attribute on a model class.

    Returns a ``RelationshipInfo`` carrying the given options. The model
    metaclass later separates attributes holding it from ordinary fields.
    """
    return RelationshipInfo(back_populates=back_populates, link_model=link_model)
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class RedisModelMetaclass(ModelMetaclass):
    """Metaclass for Redis models built on Pydantic's ``ModelMetaclass``.

    It keeps ``Relationship()`` attributes (values that are ``RelationshipInfo``)
    and their annotations away from Pydantic, so they do not become model
    fields. It also forwards only Pydantic config keyword arguments.

    The ``__dataclass_transform__`` decorator returns the class unchanged at
    runtime. Its arguments are presumably there for static type checkers.
    """
    # Maps relationship attribute name -> RelationshipInfo; populated in __new__.
    __redismodel_relationships__: Dict[str, RelationshipInfo]
    __config__: Type[BaseConfig]
    __fields__: Dict[str, ModelField]
    # From Pydantic
    def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
        """Create a model class with relationships split out of the field set.

        Namespace entries whose value is a ``RelationshipInfo`` are collected
        into ``__redismodel_relationships__``, and their annotations are hidden
        from Pydantic while the class is built. Keyword arguments whose names
        are non-dunder attributes of ``BaseConfig`` are forwarded to
        ``ModelMetaclass``. Any other keyword arguments are dropped.
        """
        relationships: Dict[str, RelationshipInfo] = {}
        dict_for_pydantic = {}
        original_annotations = resolve_annotations(
            class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
        )
        pydantic_annotations = {}
        relationship_annotations = {}
        # Split the class namespace: RelationshipInfo values are set aside,
        # everything else is handed to Pydantic.
        for k, v in class_dict.items():
            if isinstance(v, RelationshipInfo):
                relationships[k] = v
            else:
                dict_for_pydantic[k] = v
        # Split the annotations the same way so Pydantic does not build
        # fields for relationship attributes.
        for k, v in original_annotations.items():
            if k in relationships:
                relationship_annotations[k] = v
            else:
                pydantic_annotations[k] = v
        # NOTE(review): the reason for forcing "__weakref__" to None is not
        # evident from this file — confirm it is still needed.
        dict_used = {
            **dict_for_pydantic,
            "__weakref__": None,
            "__redismodel_relationships__": relationships,
            "__annotations__": pydantic_annotations,
        }
        # Duplicate logic from Pydantic to filter config kwargs because if they are
        # passed directly including the registry Pydantic will pass them over to the
        # superclass causing an error
        allowed_config_kwargs: Set[str] = {
            key
            for key in dir(BaseConfig)
            if not (
                key.startswith("__") and key.endswith("__")
            ) # skip dunder methods and attributes
        }
        pydantic_kwargs = kwargs.copy()
        # Move recognised config keys out of pydantic_kwargs; whatever remains
        # in pydantic_kwargs is not used below.
        config_kwargs = {
            key: pydantic_kwargs.pop(key)
            for key in pydantic_kwargs.keys() & allowed_config_kwargs
        }
        new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
        # Restore relationship annotations on the finished class. Later
        # mappings win on key collisions, so the class's own post-build
        # annotations take precedence.
        new_cls.__annotations__ = {
            **relationship_annotations,
            **pydantic_annotations,
            **new_cls.__annotations__,
        }
        return new_cls
@dataclass()
class PrimaryKey:
    """Identifies a model's primary-key field (stored on the model's ``Meta``)."""

    name: str  # attribute name of the primary-key field
    field: ModelField  # Pydantic field object for that attribute
class DefaultMeta:
    """Fallback ``Meta`` settings for Redis models; every option starts as ``None``."""

    # Prefixes RedisModel.key() puts in front of a key part (global, then model).
    global_key_prefix: Optional[str] = None
    model_key_prefix: Optional[str] = None
    # str.format template with a "{pk}" placeholder for the per-instance key part.
    primary_key_pattern: Optional[str] = None
    # Redis client handed out by RedisModel.db().
    database: Optional[redis.Redis] = None
    # Primary-key descriptor recorded when a model subclass is created.
    primary_key: Optional[PrimaryKey] = None
class RedisModel(BaseModel, metaclass=RedisModelMetaclass):
    """
    Base class for Pydantic models stored in Redis as hashes.

    When a subclass is created, it receives its own ``Meta`` class (derived
    from the parent's unless the subclass declares one). Its primary-key
    field and key pattern are recorded there. Every field is also exposed on
    the class as an ``ExpressionProxy`` for building query expressions.

    TODO: Convert expressions to Redis commands, execute
    TODO: Key prefix vs. "key pattern" (that's actually the primary key pattern)
    TODO: Default key prefix is model name lowercase
    TODO: Build primary key pattern from PK field name, model prefix
    TODO: Default PK pattern is model name:pk field
    """
    pk: Optional[str] = Field(default=None, primary_key=True)

    class Config:
        orm_mode = True
        arbitrary_types_allowed = True
        extra = 'allow'

    Meta = DefaultMeta

    def __init_subclass__(cls, **kwargs):
        """Set up query proxies and primary-key metadata for a new model class."""
        super().__init_subclass__(**kwargs)

        # Give each model its own Meta so the assignments below cannot
        # overwrite the shared DefaultMeta (and thereby every other model's
        # settings). Subclassing keeps inherited options such as ``database``
        # visible through normal attribute lookup.
        if "Meta" not in cls.__dict__:
            cls.Meta = type("Meta", (cls.Meta,), {})

        for name, field in cls.__fields__.items():
            # Create proxies for each model field so that we can use the field
            # in queries, like Model.get(Model.field_name == 1)
            setattr(cls, name, ExpressionProxy(field))
            # Only our FieldInfo subclass carries ORM metadata like primary_key.
            if isinstance(field.field_info, FieldInfo) and field.field_info.primary_key:
                cls.Meta.primary_key = PrimaryKey(name=name, field=field)
                # DefaultMeta declares primary_key_pattern = None, so "not
                # configured" means None rather than a missing attribute.
                if getattr(cls.Meta, 'primary_key_pattern', None) is None:
                    cls.Meta.primary_key_pattern = f"{cls.Meta.primary_key.name}:{{pk}}"

    def __init__(__pydantic_self__, **data: Any) -> None:
        """Validate ``data`` with Pydantic, then check the primary-key setup.

        Uses something other than ``self`` as the first argument so that
        "self" remains usable as a settable attribute.

        Raises:
            pydantic.ValidationError: if ``data`` fails field validation.
            RedisModelError: if the model does not define exactly one
                primary key.
        """
        # Delegate to BaseModel.__init__ so __fields_set__ and private
        # attributes are initialised along with __dict__.
        super().__init__(**data)
        __pydantic_self__.validate_primary_key()

    @classmethod
    @no_type_check
    def _get_value(cls, *args, **kwargs) -> Any:
        """
        Always send None as an empty string.
        TODO: How broken is this?
        """
        val = super()._get_value(*args, **kwargs)
        if val is None:
            return ""
        return val

    @classmethod
    def validate_primary_key(cls):
        """Check for a primary key. We need one (and only one).

        Raises:
            RedisModelError: if no field, or more than one field, is marked
                ``primary_key=True``.
        """
        # Plain Pydantic FieldInfo objects have no ``primary_key`` attribute,
        # so read it with a default instead of assuming our subclass.
        primary_keys = sum(
            1
            for field in cls.__fields__.values()
            if getattr(field.field_info, 'primary_key', False)
        )
        # TODO: Automatically create a primary key field instead?
        if primary_keys == 0:
            raise RedisModelError("You must define a primary key for the model")
        elif primary_keys > 1:
            raise RedisModelError("You must define only one primary key for a model")

    @classmethod
    def key(cls, part: str):
        """Return the Redis key for ``part``: global prefix + model prefix + part.

        A prefix that is missing or ``None`` (the ``DefaultMeta`` value)
        contributes nothing, rather than the text "None".
        """
        global_prefix = getattr(cls.Meta, 'global_key_prefix', None) or ''
        model_prefix = getattr(cls.Meta, 'model_key_prefix', None) or ''
        return f"{global_prefix}{model_prefix}{part}"

    @classmethod
    def get(cls, pk: Any):
        """Load the instance stored under primary key ``pk``.

        Raises:
            NotFoundError: if the client returns an empty hash for the key.
        """
        # TODO: Getting related objects
        # NOTE(review): redis-py returns bytes keys/values from HGETALL unless
        # the client uses decode_responses=True — confirm parse_obj copes.
        pk_pattern = cls.Meta.primary_key_pattern.format(pk=str(pk))
        document = cls.db().hgetall(cls.key(pk_pattern))
        if not document:
            raise NotFoundError
        return cls.parse_obj(document)

    def delete(self):
        """Delete this instance's hash and return the client's ``delete`` reply."""
        # TODO: deleting relationships
        # Read the primary-key *value* from the instance; __fields__ maps
        # names to ModelField objects, not to values.
        pk = getattr(self, self.Meta.primary_key.field.name)
        pk_pattern = self.Meta.primary_key_pattern.format(pk=pk)
        return self.db().delete(self.key(pk_pattern))

    @classmethod
    def db(cls):
        """Return the Redis client configured as ``Meta.database``."""
        return cls.Meta.database

    @classmethod
    def filter(cls, *expressions: Expression):
        """Placeholder query API; currently ignores ``expressions`` and returns the class."""
        return cls

    @classmethod
    def exclude(cls, *expressions: Expression):
        """Placeholder query API; currently ignores ``expressions`` and returns the class."""
        return cls

    @classmethod
    def add(cls, models: Sequence['RedisModel']) -> Sequence['RedisModel']:
        """Save each model in order and return the list of saved models."""
        return [model.save() for model in models]

    @classmethod
    def update(cls, **field_values):
        """Placeholder; currently ignores ``field_values`` and returns the class."""
        return cls

    @classmethod
    def values(cls):
        """Return raw values from Redis instead of model instances."""
        return cls

    def save(self) -> 'RedisModel':
        """Write this instance to Redis as a hash and return the instance.

        If the primary-key value is empty, a new key is generated with
        ``uuid_from_time`` from the current local time. It is assigned to the
        instance before writing.
        """
        pk_field = self.Meta.primary_key.field
        document = jsonable_encoder(self.dict())
        pk = document[pk_field.name]
        if not pk:
            pk = str(uuid_from_time(datetime.datetime.now()))
            setattr(self, pk_field.name, pk)
            document[pk_field.name] = pk
        pk_pattern = self.Meta.primary_key_pattern.format(pk=pk)
        self.db().hset(self.key(pk_pattern), mapping=document)
        # Return the model, as annotated and as ``add`` expects, rather than
        # the HSET reply (a count of newly created fields).
        return self