Broken schema generation
This commit is contained in:
parent
8f32b359f0
commit
5d05de95f8
5 changed files with 234 additions and 58 deletions
|
@ -94,6 +94,14 @@ class Operators(Enum):
|
|||
ExpressionOrModelField = Union['Expression', 'NegatedExpression', ModelField]
|
||||
|
||||
|
||||
def embedded(cls):
|
||||
"""
|
||||
Mark a model as embedded to avoid creating multiple indexes if the model is
|
||||
only ever used embedded within other models.
|
||||
"""
|
||||
setattr(cls.Meta, 'embedded', True)
|
||||
|
||||
|
||||
class ExpressionProtocol(Protocol):
|
||||
op: Operators
|
||||
left: ExpressionOrModelField
|
||||
|
@ -166,15 +174,16 @@ class Expression:
|
|||
op: Operators
|
||||
left: ExpressionOrModelField
|
||||
right: ExpressionOrModelField
|
||||
parents: List[Tuple[str, 'RedisModel']]
|
||||
|
||||
def __invert__(self):
|
||||
return NegatedExpression(self)
|
||||
|
||||
def __and__(self, other: ExpressionOrModelField):
|
||||
return Expression(left=self, op=Operators.AND, right=other)
|
||||
return Expression(left=self, op=Operators.AND, right=other, parents=self.parents)
|
||||
|
||||
def __or__(self, other: ExpressionOrModelField):
|
||||
return Expression(left=self, op=Operators.OR, right=other)
|
||||
return Expression(left=self, op=Operators.OR, right=other, parents=self.parents)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
@ -189,26 +198,34 @@ ExpressionOrNegated = Union[Expression, NegatedExpression]
|
|||
|
||||
|
||||
class ExpressionProxy:
|
||||
def __init__(self, field: ModelField):
|
||||
def __init__(self, field: ModelField, parents: List[Tuple[str, 'RedisModel']]):
|
||||
self.field = field
|
||||
self.parents = parents
|
||||
|
||||
def __eq__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
return Expression(left=self.field, op=Operators.EQ, right=other)
|
||||
return Expression(left=self.field, op=Operators.EQ, right=other, parents=self.parents)
|
||||
|
||||
def __ne__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
return Expression(left=self.field, op=Operators.NE, right=other)
|
||||
return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)
|
||||
|
||||
def __lt__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
return Expression(left=self.field, op=Operators.LT, right=other)
|
||||
return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
|
||||
|
||||
def __le__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
return Expression(left=self.field, op=Operators.LE, right=other)
|
||||
return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
|
||||
|
||||
def __gt__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
return Expression(left=self.field, op=Operators.GT, right=other)
|
||||
return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
|
||||
|
||||
def __ge__(self, other: Any) -> Expression: # type: ignore[override]
|
||||
return Expression(left=self.field, op=Operators.GE, right=other)
|
||||
return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
|
||||
|
||||
def __getattr__(self, item):
|
||||
attr = getattr(self.field.outer_type_, item)
|
||||
if isinstance(attr, self.__class__):
|
||||
attr.parents.insert(0, (self.field.name, self.field.outer_type_))
|
||||
attr.parents = attr.parents + self.parents
|
||||
return attr
|
||||
|
||||
|
||||
class QueryNotSupportedError(Exception):
|
||||
|
@ -265,7 +282,10 @@ class FindQuery:
|
|||
if self.expressions:
|
||||
self._expression = reduce(operator.and_, self.expressions)
|
||||
else:
|
||||
self._expression = Expression(left=None, right=None, op=Operators.ALL)
|
||||
# TODO: Is there a better way to support the "give me all records" query?
|
||||
# Also -- if we do it this way, we need different type annotations.
|
||||
self._expression = Expression(left=None, right=None, op=Operators.ALL,
|
||||
parents=[])
|
||||
return self._expression
|
||||
|
||||
@property
|
||||
|
@ -316,7 +336,11 @@ class FindQuery:
|
|||
|
||||
@classmethod
|
||||
def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
|
||||
field_info: PydanticFieldInfo, op: Operators, value: Any) -> str:
|
||||
field_info: PydanticFieldInfo, op: Operators, value: Any,
|
||||
parents: List[Tuple[str, 'RedisModel']]) -> str:
|
||||
if parents:
|
||||
prefix = "_".join([p[0] for p in parents])
|
||||
field_name = f"{prefix}_{field_name}"
|
||||
result = ""
|
||||
if field_type is RediSearchFieldTypes.TEXT:
|
||||
result = f"@{field_name}:"
|
||||
|
@ -427,6 +451,9 @@ class FindQuery:
|
|||
field_type = cls.resolve_field_type(expression.left)
|
||||
field_name = expression.left.name
|
||||
field_info = expression.left.field_info
|
||||
if not field_info or not getattr(field_info, "index", None):
|
||||
raise QueryNotSupportedError(f"You tried to query by a field ({field_name}) "
|
||||
f"that isn't indexed. See docs: TODO")
|
||||
else:
|
||||
raise QueryNotSupportedError(f"A query expression should start with either a field "
|
||||
f"or an expression enclosed in parenthesis. See docs: "
|
||||
|
@ -454,7 +481,8 @@ class FindQuery:
|
|||
if isinstance(right, ModelField):
|
||||
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
|
||||
else:
|
||||
result += cls.resolve_value(field_name, field_type, field_info, expression.op, right)
|
||||
result += cls.resolve_value(field_name, field_type, field_info,
|
||||
expression.op, right, expression.parents)
|
||||
|
||||
if encompassing_expression_is_negated:
|
||||
result = f"-({result})"
|
||||
|
@ -705,6 +733,7 @@ class MetaProtocol(Protocol):
|
|||
primary_key_creator_cls: Type[PrimaryKeyCreator]
|
||||
index_name: str
|
||||
abstract: bool
|
||||
embedded: bool
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -722,6 +751,7 @@ class DefaultMeta:
|
|||
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
|
||||
index_name: Optional[str] = None
|
||||
abstract: Optional[bool] = False
|
||||
embedded: Optional[bool] = False
|
||||
|
||||
|
||||
class ModelMeta(ModelMetaclass):
|
||||
|
@ -730,6 +760,11 @@ class ModelMeta(ModelMetaclass):
|
|||
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
|
||||
meta = attrs.pop('Meta', None)
|
||||
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
|
||||
# The fact that there is a Meta field and _meta field is important: a
|
||||
# user may have given us a Meta object with their configuration, while
|
||||
# we might have inherited _meta from a parent class, and should
|
||||
# therefore use some of the inherited fields.
|
||||
meta = meta or getattr(new_class, 'Meta', None)
|
||||
base_meta = getattr(new_class, '_meta', None)
|
||||
|
||||
|
@ -739,8 +774,9 @@ class ModelMeta(ModelMetaclass):
|
|||
elif base_meta:
|
||||
new_class._meta = deepcopy(base_meta)
|
||||
new_class.Meta = new_class._meta
|
||||
# Unset inherited values we don't want to reuse (typically based on the model name).
|
||||
new_class._meta.abstract = False
|
||||
# Unset inherited values we don't want to reuse (typically based on
|
||||
# the model name).
|
||||
new_class._meta.embedded = False
|
||||
new_class._meta.model_key_prefix = None
|
||||
new_class._meta.index_name = None
|
||||
else:
|
||||
|
@ -750,7 +786,7 @@ class ModelMeta(ModelMetaclass):
|
|||
# Create proxies for each model field so that we can use the field
|
||||
# in queries, like Model.get(Model.field_name == 1)
|
||||
for field_name, field in new_class.__fields__.items():
|
||||
setattr(new_class, field_name, ExpressionProxy(field))
|
||||
setattr(new_class, field_name, ExpressionProxy(field, []))
|
||||
# Check if this is our FieldInfo version with extended ORM metadata.
|
||||
if isinstance(field.field_info, FieldInfo):
|
||||
if field.field_info.primary_key:
|
||||
|
@ -774,8 +810,9 @@ class ModelMeta(ModelMetaclass):
|
|||
new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \
|
||||
f"{new_class._meta.model_key_prefix}:index"
|
||||
|
||||
# Not an abstract model class
|
||||
if abc.ABC not in bases:
|
||||
# Not an abstract model class or embedded model, so we should let the
|
||||
# Migrator create indexes for it.
|
||||
if abc.ABC not in bases and not new_class._meta.embedded:
|
||||
key = f"{new_class.__module__}.{new_class.__qualname__}"
|
||||
model_registry[key] = new_class
|
||||
|
||||
|
@ -967,7 +1004,7 @@ class HashModel(RedisModel, abc.ABC):
|
|||
schema_parts = []
|
||||
|
||||
for name, field in cls.__fields__.items():
|
||||
# TODO: Merge this code with schema_for_type()
|
||||
# TODO: Merge this code with schema_for_type()?
|
||||
_type = field.outer_type_
|
||||
if getattr(field.field_info, 'primary_key', None):
|
||||
if issubclass(_type, str):
|
||||
|
@ -1047,6 +1084,8 @@ class JsonModel(RedisModel, abc.ABC):
|
|||
schema_parts = []
|
||||
json_path = "$"
|
||||
|
||||
if cls.__name__ == "Address":
|
||||
import ipdb; ipdb.set_trace()
|
||||
for name, field in cls.__fields__.items():
|
||||
# TODO: Merge this code with schema_for_type()?
|
||||
_type = field.outer_type_
|
||||
|
@ -1070,21 +1109,20 @@ class JsonModel(RedisModel, abc.ABC):
|
|||
log.warning("Model %s defined an empty list field: %s", cls, name)
|
||||
continue
|
||||
embedded_cls = embedded_cls[0]
|
||||
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}[]", name, f"{name}",
|
||||
# TODO: Should this have a name prefix?
|
||||
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}[]", name, name,
|
||||
embedded_cls, field.field_info))
|
||||
elif issubclass(_type, RedisModel):
|
||||
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, f"{name}", _type,
|
||||
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, name, _type,
|
||||
field.field_info))
|
||||
return schema_parts
|
||||
|
||||
@classmethod
|
||||
# TODO: We need both the "name" of the field (address_line_1) as we'll
|
||||
# find it in the JSON document, AND the name of the field as it should
|
||||
# be in the redisearch schema (address_address_line_1). Maybe both "name"
|
||||
# and "name_prefix"?
|
||||
def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Any,
|
||||
field_info: PydanticFieldInfo) -> str:
|
||||
index_field_name = f"{name_prefix}{name}"
|
||||
if name == "description":
|
||||
import ipdb; ipdb.set_trace()
|
||||
index_field_name = f"{name_prefix}_{name}"
|
||||
should_index = getattr(field_info, 'index', False)
|
||||
|
||||
if get_origin(typ) == list:
|
||||
|
@ -1094,15 +1132,14 @@ class JsonModel(RedisModel, abc.ABC):
|
|||
log.warning("Model %s defined an empty list field: %s", cls, name)
|
||||
return ""
|
||||
embedded_cls = embedded_cls[0]
|
||||
# TODO: We need to pass the "JSON Path so far" which should include the
|
||||
# correct syntax for an array.
|
||||
return cls.schema_for_type(f"{json_path}[]", name, f"{name_prefix}{name}", embedded_cls, field_info)
|
||||
return cls.schema_for_type(f"{json_path}[]", name, f"{name_prefix}{name}",
|
||||
embedded_cls, field_info)
|
||||
elif issubclass(typ, RedisModel):
|
||||
sub_fields = []
|
||||
for embedded_name, field in typ.__fields__.items():
|
||||
sub_fields.append(cls.schema_for_type(f"{json_path}.{embedded_name}",
|
||||
embedded_name,
|
||||
f"{name_prefix}_",
|
||||
f"{name_prefix}_{embedded_name}",
|
||||
field.outer_type_,
|
||||
field.field_info))
|
||||
return " ".join(filter(None, sub_fields))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue