Broken schema generation

This commit is contained in:
Andrew Brookins 2021-10-12 14:22:57 -07:00
parent 8f32b359f0
commit 5d05de95f8
5 changed files with 234 additions and 58 deletions

View file

@ -94,6 +94,14 @@ class Operators(Enum):
ExpressionOrModelField = Union['Expression', 'NegatedExpression', ModelField]
def embedded(cls):
    """
    Mark a model as embedded to avoid creating multiple indexes if the model is
    only ever used embedded within other models.

    Sets ``embedded = True`` on ``cls.Meta`` and returns ``cls`` so that the
    function can be called directly or applied as a class decorator (without
    the return value, ``@embedded`` would rebind the class name to ``None``).

    Raises ``AttributeError`` if ``cls`` has no ``Meta`` attribute.
    """
    # NOTE(review): when applied as a decorator this runs after the class has
    # been created, so any metaclass logic that inspects the embedded flag
    # during class creation has already run -- confirm that ordering is
    # acceptable for index registration.
    cls.Meta.embedded = True
    return cls
class ExpressionProtocol(Protocol):
op: Operators
left: ExpressionOrModelField
@ -166,15 +174,16 @@ class Expression:
op: Operators
left: ExpressionOrModelField
right: ExpressionOrModelField
parents: List[Tuple[str, 'RedisModel']]
def __invert__(self):
    """Return this expression wrapped in a ``NegatedExpression`` (``~expr``)."""
    negated = NegatedExpression(self)
    return negated
def __and__(self, other: ExpressionOrModelField):
    """Combine this expression with ``other`` using AND (the ``&`` operator).

    The resulting expression has this expression as its left side, ``other``
    as its right side, and carries this expression's ``parents`` along.
    """
    # A stale return that omitted ``parents`` used to precede this one and
    # made it unreachable; only the parents-propagating form is kept.
    return Expression(left=self, op=Operators.AND, right=other, parents=self.parents)
def __or__(self, other: ExpressionOrModelField):
    """Combine this expression with ``other`` using OR (the ``|`` operator).

    The resulting expression has this expression as its left side, ``other``
    as its right side, and carries this expression's ``parents`` along.
    """
    # A stale return that omitted ``parents`` used to precede this one and
    # made it unreachable; only the parents-propagating form is kept.
    return Expression(left=self, op=Operators.OR, right=other, parents=self.parents)
@property
def name(self):
@ -189,26 +198,34 @@ ExpressionOrNegated = Union[Expression, NegatedExpression]
class ExpressionProxy:
    """
    Stand-in for a model field that turns comparisons into query expressions.

    Comparing a proxy with a value (``==``, ``!=``, ``<``, ``<=``, ``>``,
    ``>=``) returns an ``Expression`` whose left side is the wrapped field.
    ``parents`` is a list of ``(field_name, model)`` pairs that is handed to
    every expression the proxy builds.
    """

    def __init__(self, field: ModelField,
                 parents: Optional[List[Tuple[str, 'RedisModel']]] = None):
        """
        Wrap ``field``.

        ``parents`` defaults to an empty list so that both ``ExpressionProxy(field)``
        and ``ExpressionProxy(field, [])`` are accepted.
        """
        self.field = field
        self.parents = [] if parents is None else parents

    def __eq__(self, other: Any) -> Expression:  # type: ignore[override]
        """Build an EQ expression comparing the wrapped field with ``other``."""
        return Expression(left=self.field, op=Operators.EQ, right=other, parents=self.parents)

    def __ne__(self, other: Any) -> Expression:  # type: ignore[override]
        """Build an NE expression comparing the wrapped field with ``other``."""
        return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)

    def __lt__(self, other: Any) -> Expression:  # type: ignore[override]
        """Build an LT expression comparing the wrapped field with ``other``."""
        return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)

    def __le__(self, other: Any) -> Expression:  # type: ignore[override]
        """Build an LE expression comparing the wrapped field with ``other``."""
        return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)

    def __gt__(self, other: Any) -> Expression:  # type: ignore[override]
        """Build a GT expression comparing the wrapped field with ``other``."""
        return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)

    def __ge__(self, other: Any) -> Expression:  # type: ignore[override]
        """Build a GE expression comparing the wrapped field with ``other``."""
        return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)

    def __getattr__(self, item):
        """
        Resolve ``item`` on the wrapped field's type (``field.outer_type_``).

        If the attribute found there is itself an ``ExpressionProxy``, return a
        new proxy for the same field whose ``parents`` are this field's
        ``(name, outer_type_)`` pair, then the found proxy's parents, then this
        proxy's parents. Any other attribute is returned unchanged.

        Raises ``AttributeError`` if the wrapped type has no such attribute.
        """
        # __getattr__ only runs when normal lookup fails. If our own state is
        # missing (e.g. an instance created without __init__), reading
        # self.field below would re-enter this method forever.
        if item in ("field", "parents"):
            raise AttributeError(item)

        attr = getattr(self.field.outer_type_, item)
        if isinstance(attr, self.__class__):
            # Build a fresh proxy rather than editing ``attr`` in place:
            # ``attr`` is the shared proxy stored on the embedded type, so
            # mutating it made parents pile up on every access.
            own_link = (self.field.name, self.field.outer_type_)
            chain = [own_link] + list(attr.parents) + list(self.parents)
            return self.__class__(attr.field, chain)
        return attr
class QueryNotSupportedError(Exception):
@ -265,7 +282,10 @@ class FindQuery:
if self.expressions:
self._expression = reduce(operator.and_, self.expressions)
else:
self._expression = Expression(left=None, right=None, op=Operators.ALL)
# TODO: Is there a better way to support the "give me all records" query?
# Also -- if we do it this way, we need different type annotations.
self._expression = Expression(left=None, right=None, op=Operators.ALL,
parents=[])
return self._expression
@property
@ -316,7 +336,11 @@ class FindQuery:
@classmethod
def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
field_info: PydanticFieldInfo, op: Operators, value: Any) -> str:
field_info: PydanticFieldInfo, op: Operators, value: Any,
parents: List[Tuple[str, 'RedisModel']]) -> str:
if parents:
prefix = "_".join([p[0] for p in parents])
field_name = f"{prefix}_{field_name}"
result = ""
if field_type is RediSearchFieldTypes.TEXT:
result = f"@{field_name}:"
@ -427,6 +451,9 @@ class FindQuery:
field_type = cls.resolve_field_type(expression.left)
field_name = expression.left.name
field_info = expression.left.field_info
if not field_info or not getattr(field_info, "index", None):
raise QueryNotSupportedError(f"You tried to query by a field ({field_name}) "
f"that isn't indexed. See docs: TODO")
else:
raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: "
@ -454,7 +481,8 @@ class FindQuery:
if isinstance(right, ModelField):
raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
else:
result += cls.resolve_value(field_name, field_type, field_info, expression.op, right)
result += cls.resolve_value(field_name, field_type, field_info,
expression.op, right, expression.parents)
if encompassing_expression_is_negated:
result = f"-({result})"
@ -705,6 +733,7 @@ class MetaProtocol(Protocol):
primary_key_creator_cls: Type[PrimaryKeyCreator]
index_name: str
abstract: bool
embedded: bool
@dataclasses.dataclass
@ -722,6 +751,7 @@ class DefaultMeta:
primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
index_name: Optional[str] = None
abstract: Optional[bool] = False
embedded: Optional[bool] = False
class ModelMeta(ModelMetaclass):
@ -730,6 +760,11 @@ class ModelMeta(ModelMetaclass):
def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
meta = attrs.pop('Meta', None)
new_class = super().__new__(cls, name, bases, attrs, **kwargs)
# The fact that there is a Meta field and _meta field is important: a
# user may have given us a Meta object with their configuration, while
# we might have inherited _meta from a parent class, and should
# therefore use some of the inherited fields.
meta = meta or getattr(new_class, 'Meta', None)
base_meta = getattr(new_class, '_meta', None)
@ -739,8 +774,9 @@ class ModelMeta(ModelMetaclass):
elif base_meta:
new_class._meta = deepcopy(base_meta)
new_class.Meta = new_class._meta
# Unset inherited values we don't want to reuse (typically based on the model name).
new_class._meta.abstract = False
# Unset inherited values we don't want to reuse (typically based on
# the model name).
new_class._meta.embedded = False
new_class._meta.model_key_prefix = None
new_class._meta.index_name = None
else:
@ -750,7 +786,7 @@ class ModelMeta(ModelMetaclass):
# Create proxies for each model field so that we can use the field
# in queries, like Model.get(Model.field_name == 1)
for field_name, field in new_class.__fields__.items():
setattr(new_class, field_name, ExpressionProxy(field))
setattr(new_class, field_name, ExpressionProxy(field, []))
# Check if this is our FieldInfo version with extended ORM metadata.
if isinstance(field.field_info, FieldInfo):
if field.field_info.primary_key:
@ -774,8 +810,9 @@ class ModelMeta(ModelMetaclass):
new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \
f"{new_class._meta.model_key_prefix}:index"
# Not an abstract model class
if abc.ABC not in bases:
# Not an abstract model class or embedded model, so we should let the
# Migrator create indexes for it.
if abc.ABC not in bases and not new_class._meta.embedded:
key = f"{new_class.__module__}.{new_class.__qualname__}"
model_registry[key] = new_class
@ -967,7 +1004,7 @@ class HashModel(RedisModel, abc.ABC):
schema_parts = []
for name, field in cls.__fields__.items():
# TODO: Merge this code with schema_for_type()
# TODO: Merge this code with schema_for_type()?
_type = field.outer_type_
if getattr(field.field_info, 'primary_key', None):
if issubclass(_type, str):
@ -1047,6 +1084,8 @@ class JsonModel(RedisModel, abc.ABC):
schema_parts = []
json_path = "$"
if cls.__name__ == "Address":
import ipdb; ipdb.set_trace()
for name, field in cls.__fields__.items():
# TODO: Merge this code with schema_for_type()?
_type = field.outer_type_
@ -1070,21 +1109,20 @@ class JsonModel(RedisModel, abc.ABC):
log.warning("Model %s defined an empty list field: %s", cls, name)
continue
embedded_cls = embedded_cls[0]
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}[]", name, f"{name}",
# TODO: Should this have a name prefix?
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}[]", name, name,
embedded_cls, field.field_info))
elif issubclass(_type, RedisModel):
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, f"{name}", _type,
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, name, _type,
field.field_info))
return schema_parts
@classmethod
# TODO: We need both the "name" of the field (address_line_1) as we'll
# find it in the JSON document, AND the name of the field as it should
# be in the redisearch schema (address_address_line_1). Maybe both "name"
# and "name_prefix"?
def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Any,
field_info: PydanticFieldInfo) -> str:
index_field_name = f"{name_prefix}{name}"
if name == "description":
import ipdb; ipdb.set_trace()
index_field_name = f"{name_prefix}_{name}"
should_index = getattr(field_info, 'index', False)
if get_origin(typ) == list:
@ -1094,15 +1132,14 @@ class JsonModel(RedisModel, abc.ABC):
log.warning("Model %s defined an empty list field: %s", cls, name)
return ""
embedded_cls = embedded_cls[0]
# TODO: We need to pass the "JSON Path so far" which should include the
# correct syntax for an array.
return cls.schema_for_type(f"{json_path}[]", name, f"{name_prefix}{name}", embedded_cls, field_info)
return cls.schema_for_type(f"{json_path}[]", name, f"{name_prefix}{name}",
embedded_cls, field_info)
elif issubclass(typ, RedisModel):
sub_fields = []
for embedded_name, field in typ.__fields__.items():
sub_fields.append(cls.schema_for_type(f"{json_path}.{embedded_name}",
embedded_name,
f"{name_prefix}_",
f"{name_prefix}_{embedded_name}",
field.outer_type_,
field.field_info))
return " ".join(filter(None, sub_fields))