Broken schema generation
This commit is contained in:
		
							parent
							
								
									8f32b359f0
								
							
						
					
					
						commit
						5d05de95f8
					
				
					 5 changed files with 234 additions and 58 deletions
				
			
		| 
						 | 
				
			
			@ -94,6 +94,14 @@ class Operators(Enum):
 | 
			
		|||
ExpressionOrModelField = Union['Expression', 'NegatedExpression', ModelField]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def embedded(cls):
 | 
			
		||||
    """
 | 
			
		||||
    Mark a model as embedded to avoid creating multiple indexes if the model is
 | 
			
		||||
    only ever used embedded within other models.
 | 
			
		||||
    """
 | 
			
		||||
    setattr(cls.Meta, 'embedded', True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ExpressionProtocol(Protocol):
 | 
			
		||||
    op: Operators
 | 
			
		||||
    left: ExpressionOrModelField
 | 
			
		||||
| 
						 | 
				
			
			@ -166,15 +174,16 @@ class Expression:
 | 
			
		|||
    op: Operators
 | 
			
		||||
    left: ExpressionOrModelField
 | 
			
		||||
    right: ExpressionOrModelField
 | 
			
		||||
    parents: List[Tuple[str, 'RedisModel']]
 | 
			
		||||
 | 
			
		||||
    def __invert__(self):
 | 
			
		||||
        return NegatedExpression(self)
 | 
			
		||||
 | 
			
		||||
    def __and__(self, other: ExpressionOrModelField):
 | 
			
		||||
        return Expression(left=self, op=Operators.AND, right=other)
 | 
			
		||||
        return Expression(left=self, op=Operators.AND, right=other, parents=self.parents)
 | 
			
		||||
 | 
			
		||||
    def __or__(self, other: ExpressionOrModelField):
 | 
			
		||||
        return Expression(left=self, op=Operators.OR, right=other)
 | 
			
		||||
        return Expression(left=self, op=Operators.OR, right=other, parents=self.parents)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def name(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -189,26 +198,34 @@ ExpressionOrNegated = Union[Expression, NegatedExpression]
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ExpressionProxy:
 | 
			
		||||
    def __init__(self, field: ModelField):
 | 
			
		||||
    def __init__(self, field: ModelField, parents: List[Tuple[str, 'RedisModel']]):
 | 
			
		||||
        self.field = field
 | 
			
		||||
        self.parents = parents
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
			
		||||
        return Expression(left=self.field, op=Operators.EQ, right=other)
 | 
			
		||||
        return Expression(left=self.field, op=Operators.EQ, right=other, parents=self.parents)
 | 
			
		||||
 | 
			
		||||
    def __ne__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
			
		||||
        return Expression(left=self.field, op=Operators.NE, right=other)
 | 
			
		||||
        return Expression(left=self.field, op=Operators.NE, right=other, parents=self.parents)
 | 
			
		||||
 | 
			
		||||
    def __lt__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
			
		||||
        return Expression(left=self.field, op=Operators.LT, right=other)
 | 
			
		||||
        return Expression(left=self.field, op=Operators.LT, right=other, parents=self.parents)
 | 
			
		||||
 | 
			
		||||
    def __le__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
			
		||||
        return Expression(left=self.field, op=Operators.LE, right=other)
 | 
			
		||||
        return Expression(left=self.field, op=Operators.LE, right=other, parents=self.parents)
 | 
			
		||||
 | 
			
		||||
    def __gt__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
			
		||||
        return Expression(left=self.field, op=Operators.GT, right=other)
 | 
			
		||||
        return Expression(left=self.field, op=Operators.GT, right=other, parents=self.parents)
 | 
			
		||||
 | 
			
		||||
    def __ge__(self, other: Any) -> Expression:  # type: ignore[override]
 | 
			
		||||
        return Expression(left=self.field, op=Operators.GE, right=other)
 | 
			
		||||
        return Expression(left=self.field, op=Operators.GE, right=other, parents=self.parents)
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, item):
 | 
			
		||||
        attr = getattr(self.field.outer_type_, item)
 | 
			
		||||
        if isinstance(attr, self.__class__):
 | 
			
		||||
            attr.parents.insert(0, (self.field.name, self.field.outer_type_))
 | 
			
		||||
            attr.parents = attr.parents + self.parents
 | 
			
		||||
        return attr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QueryNotSupportedError(Exception):
 | 
			
		||||
| 
						 | 
				
			
			@ -265,7 +282,10 @@ class FindQuery:
 | 
			
		|||
        if self.expressions:
 | 
			
		||||
            self._expression = reduce(operator.and_, self.expressions)
 | 
			
		||||
        else:
 | 
			
		||||
            self._expression = Expression(left=None, right=None, op=Operators.ALL)
 | 
			
		||||
            # TODO: Is there a better way to support the "give me all records" query?
 | 
			
		||||
            # Also -- if we do it this way, we need different type annotations.
 | 
			
		||||
            self._expression = Expression(left=None, right=None, op=Operators.ALL,
 | 
			
		||||
                                          parents=[])
 | 
			
		||||
        return self._expression
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
| 
						 | 
				
			
			@ -316,7 +336,11 @@ class FindQuery:
 | 
			
		|||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def resolve_value(cls, field_name: str, field_type: RediSearchFieldTypes,
 | 
			
		||||
                      field_info: PydanticFieldInfo, op: Operators, value: Any) -> str:
 | 
			
		||||
                      field_info: PydanticFieldInfo, op: Operators, value: Any,
 | 
			
		||||
                      parents: List[Tuple[str, 'RedisModel']]) -> str:
 | 
			
		||||
        if parents:
 | 
			
		||||
            prefix = "_".join([p[0] for p in parents])
 | 
			
		||||
            field_name = f"{prefix}_{field_name}"
 | 
			
		||||
        result = ""
 | 
			
		||||
        if field_type is RediSearchFieldTypes.TEXT:
 | 
			
		||||
            result = f"@{field_name}:"
 | 
			
		||||
| 
						 | 
				
			
			@ -427,6 +451,9 @@ class FindQuery:
 | 
			
		|||
            field_type = cls.resolve_field_type(expression.left)
 | 
			
		||||
            field_name = expression.left.name
 | 
			
		||||
            field_info = expression.left.field_info
 | 
			
		||||
            if not field_info or not getattr(field_info, "index", None):
 | 
			
		||||
                raise QueryNotSupportedError(f"You tried to query by a field ({field_name}) "
 | 
			
		||||
                                             f"that isn't indexed. See docs: TODO")
 | 
			
		||||
        else:
 | 
			
		||||
            raise QueryNotSupportedError(f"A query expression should start with either a field "
 | 
			
		||||
                                         f"or an expression enclosed in parenthesis. See docs: "
 | 
			
		||||
| 
						 | 
				
			
			@ -454,7 +481,8 @@ class FindQuery:
 | 
			
		|||
            if isinstance(right, ModelField):
 | 
			
		||||
                raise QueryNotSupportedError("Comparing fields is not supported. See docs: TODO")
 | 
			
		||||
            else:
 | 
			
		||||
                result += cls.resolve_value(field_name, field_type, field_info, expression.op, right)
 | 
			
		||||
                result += cls.resolve_value(field_name, field_type, field_info,
 | 
			
		||||
                                            expression.op, right, expression.parents)
 | 
			
		||||
 | 
			
		||||
        if encompassing_expression_is_negated:
 | 
			
		||||
            result = f"-({result})"
 | 
			
		||||
| 
						 | 
				
			
			@ -705,6 +733,7 @@ class MetaProtocol(Protocol):
 | 
			
		|||
    primary_key_creator_cls: Type[PrimaryKeyCreator]
 | 
			
		||||
    index_name: str
 | 
			
		||||
    abstract: bool
 | 
			
		||||
    embedded: bool
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
| 
						 | 
				
			
			@ -722,6 +751,7 @@ class DefaultMeta:
 | 
			
		|||
    primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
 | 
			
		||||
    index_name: Optional[str] = None
 | 
			
		||||
    abstract: Optional[bool] = False
 | 
			
		||||
    embedded: Optional[bool] = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelMeta(ModelMetaclass):
 | 
			
		||||
| 
						 | 
				
			
			@ -730,6 +760,11 @@ class ModelMeta(ModelMetaclass):
 | 
			
		|||
    def __new__(cls, name, bases, attrs, **kwargs):  # noqa C901
 | 
			
		||||
        meta = attrs.pop('Meta', None)
 | 
			
		||||
        new_class = super().__new__(cls, name, bases, attrs, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # The fact that there is a Meta field and _meta field is important: a
 | 
			
		||||
        # user may have given us a Meta object with their configuration, while
 | 
			
		||||
        # we might have inherited _meta from a parent class, and should
 | 
			
		||||
        # therefore use some of the inherited fields.
 | 
			
		||||
        meta = meta or getattr(new_class, 'Meta', None)
 | 
			
		||||
        base_meta = getattr(new_class, '_meta', None)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -739,8 +774,9 @@ class ModelMeta(ModelMetaclass):
 | 
			
		|||
        elif base_meta:
 | 
			
		||||
            new_class._meta = deepcopy(base_meta)
 | 
			
		||||
            new_class.Meta = new_class._meta
 | 
			
		||||
            # Unset inherited values we don't want to reuse (typically based on the model name).
 | 
			
		||||
            new_class._meta.abstract = False
 | 
			
		||||
            # Unset inherited values we don't want to reuse (typically based on
 | 
			
		||||
            # the model name).
 | 
			
		||||
            new_class._meta.embedded = False
 | 
			
		||||
            new_class._meta.model_key_prefix = None
 | 
			
		||||
            new_class._meta.index_name = None
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -750,7 +786,7 @@ class ModelMeta(ModelMetaclass):
 | 
			
		|||
        # Create proxies for each model field so that we can use the field
 | 
			
		||||
        # in queries, like Model.get(Model.field_name == 1)
 | 
			
		||||
        for field_name, field in new_class.__fields__.items():
 | 
			
		||||
            setattr(new_class, field_name, ExpressionProxy(field))
 | 
			
		||||
            setattr(new_class, field_name, ExpressionProxy(field, []))
 | 
			
		||||
            # Check if this is our FieldInfo version with extended ORM metadata.
 | 
			
		||||
            if isinstance(field.field_info, FieldInfo):
 | 
			
		||||
                if field.field_info.primary_key:
 | 
			
		||||
| 
						 | 
				
			
			@ -774,8 +810,9 @@ class ModelMeta(ModelMetaclass):
 | 
			
		|||
            new_class._meta.index_name = f"{new_class._meta.global_key_prefix}:" \
 | 
			
		||||
                                         f"{new_class._meta.model_key_prefix}:index"
 | 
			
		||||
 | 
			
		||||
        # Not an abstract model class
 | 
			
		||||
        if abc.ABC not in bases:
 | 
			
		||||
        # Not an abstract model class or embedded model, so we should let the
 | 
			
		||||
        # Migrator create indexes for it.
 | 
			
		||||
        if abc.ABC not in bases and not new_class._meta.embedded:
 | 
			
		||||
            key = f"{new_class.__module__}.{new_class.__qualname__}"
 | 
			
		||||
            model_registry[key] = new_class
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -967,7 +1004,7 @@ class HashModel(RedisModel, abc.ABC):
 | 
			
		|||
        schema_parts = []
 | 
			
		||||
 | 
			
		||||
        for name, field in cls.__fields__.items():
 | 
			
		||||
            # TODO: Merge this code with schema_for_type()
 | 
			
		||||
            # TODO: Merge this code with schema_for_type()?
 | 
			
		||||
            _type = field.outer_type_
 | 
			
		||||
            if getattr(field.field_info, 'primary_key', None):
 | 
			
		||||
                if issubclass(_type, str):
 | 
			
		||||
| 
						 | 
				
			
			@ -1047,6 +1084,8 @@ class JsonModel(RedisModel, abc.ABC):
 | 
			
		|||
        schema_parts = []
 | 
			
		||||
        json_path = "$"
 | 
			
		||||
 | 
			
		||||
        if cls.__name__ == "Address":
 | 
			
		||||
            import ipdb; ipdb.set_trace()
 | 
			
		||||
        for name, field in cls.__fields__.items():
 | 
			
		||||
            # TODO: Merge this code with schema_for_type()?
 | 
			
		||||
            _type = field.outer_type_
 | 
			
		||||
| 
						 | 
				
			
			@ -1070,21 +1109,20 @@ class JsonModel(RedisModel, abc.ABC):
 | 
			
		|||
                    log.warning("Model %s defined an empty list field: %s", cls, name)
 | 
			
		||||
                    continue
 | 
			
		||||
                embedded_cls = embedded_cls[0]
 | 
			
		||||
                schema_parts.append(cls.schema_for_type(f"{json_path}.{name}[]", name, f"{name}",
 | 
			
		||||
                # TODO: Should this have a name prefix?
 | 
			
		||||
                schema_parts.append(cls.schema_for_type(f"{json_path}.{name}[]", name, name,
 | 
			
		||||
                                                        embedded_cls, field.field_info))
 | 
			
		||||
            elif issubclass(_type, RedisModel):
 | 
			
		||||
                schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, f"{name}", _type,
 | 
			
		||||
                schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, name, _type,
 | 
			
		||||
                                                        field.field_info))
 | 
			
		||||
        return schema_parts
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    # TODO: We need both the "name" of the field (address_line_1) as we'll
 | 
			
		||||
    #  find it in the JSON document, AND the name of the field as it should
 | 
			
		||||
    #  be in the redisearch schema (address_address_line_1). Maybe both "name"
 | 
			
		||||
    #  and "name_prefix"?
 | 
			
		||||
    def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Any,
 | 
			
		||||
                        field_info: PydanticFieldInfo) -> str:
 | 
			
		||||
        index_field_name = f"{name_prefix}{name}"
 | 
			
		||||
        if name == "description":
 | 
			
		||||
            import ipdb; ipdb.set_trace()
 | 
			
		||||
        index_field_name = f"{name_prefix}_{name}"
 | 
			
		||||
        should_index = getattr(field_info, 'index', False)
 | 
			
		||||
 | 
			
		||||
        if get_origin(typ) == list:
 | 
			
		||||
| 
						 | 
				
			
			@ -1094,15 +1132,14 @@ class JsonModel(RedisModel, abc.ABC):
 | 
			
		|||
                log.warning("Model %s defined an empty list field: %s", cls, name)
 | 
			
		||||
                return ""
 | 
			
		||||
            embedded_cls = embedded_cls[0]
 | 
			
		||||
            # TODO: We need to pass the "JSON Path so far" which should include the
 | 
			
		||||
            #  correct syntax for an array.
 | 
			
		||||
            return cls.schema_for_type(f"{json_path}[]", name, f"{name_prefix}{name}", embedded_cls, field_info)
 | 
			
		||||
            return cls.schema_for_type(f"{json_path}[]", name, f"{name_prefix}{name}",
 | 
			
		||||
                                       embedded_cls, field_info)
 | 
			
		||||
        elif issubclass(typ, RedisModel):
 | 
			
		||||
            sub_fields = []
 | 
			
		||||
            for embedded_name, field in typ.__fields__.items():
 | 
			
		||||
                sub_fields.append(cls.schema_for_type(f"{json_path}.{embedded_name}",
 | 
			
		||||
                                                      embedded_name,
 | 
			
		||||
                                                      f"{name_prefix}_",
 | 
			
		||||
                                                      f"{name_prefix}_{embedded_name}",
 | 
			
		||||
                                                      field.outer_type_,
 | 
			
		||||
                                                      field.field_info))
 | 
			
		||||
            return " ".join(filter(None, sub_fields))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue