Remove unique, unsued migration validation
This commit is contained in:
		
							parent
							
								
									85ba111260
								
							
						
					
					
						commit
						a788cbedbb
					
				
					 4 changed files with 116 additions and 86 deletions
				
			
		| 
						 | 
				
			
			@ -31,8 +31,7 @@ def schema_hash_key(index_name):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def create_index(index_name, schema, current_hash):
 | 
			
		||||
    redis.execute_command(f"ft.create {index_name} "
 | 
			
		||||
                          f"{schema}")
 | 
			
		||||
    redis.execute_command(f"ft.create {index_name} {schema}")
 | 
			
		||||
    redis.set(schema_hash_key(index_name), current_hash)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -41,7 +40,7 @@ class MigrationAction(Enum):
 | 
			
		|||
    DROP = 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass            
 | 
			
		||||
@dataclass
 | 
			
		||||
class IndexMigration:
 | 
			
		||||
    model_name: str
 | 
			
		||||
    index_name: str
 | 
			
		||||
| 
						 | 
				
			
			@ -49,16 +48,16 @@ class IndexMigration:
 | 
			
		|||
    hash: str
 | 
			
		||||
    action: MigrationAction
 | 
			
		||||
    previous_hash: Optional[str] = None
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def run(self):
 | 
			
		||||
        if self.action is MigrationAction.CREATE:
 | 
			
		||||
            self.create()
 | 
			
		||||
        elif self.action is MigrationAction.DROP:
 | 
			
		||||
            self.drop()
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def create(self):
 | 
			
		||||
        return create_index(self.index_name, self.schema, self.hash)
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def drop(self):
 | 
			
		||||
        redis.execute_command(f"FT.DROPINDEX {self.index_name}")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -86,6 +85,7 @@ class Migrator:
 | 
			
		|||
                self.migrations.append(
 | 
			
		||||
                    IndexMigration(name, cls.Meta.index_name, schema, current_hash,
 | 
			
		||||
                                   MigrationAction.CREATE))
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            stored_hash = redis.get(hash_key)
 | 
			
		||||
            schema_out_of_date = current_hash != stored_hash
 | 
			
		||||
| 
						 | 
				
			
			@ -97,17 +97,7 @@ class Migrator:
 | 
			
		|||
                                   MigrationAction.DROP, stored_hash))
 | 
			
		||||
                self.migrations.append(
 | 
			
		||||
                     IndexMigration(name, cls.Meta.index_name, schema, current_hash,
 | 
			
		||||
                                   MigrationAction.CREATE, stored_hash))
 | 
			
		||||
    
 | 
			
		||||
    @property
 | 
			
		||||
    def valid_migrations(self):
 | 
			
		||||
        return self.missing_indexes.keys() + self.out_of_date_indexes.keys()
 | 
			
		||||
        
 | 
			
		||||
    def validate_migration(self, model_class_name):
 | 
			
		||||
        if model_class_name not in self.valid_migrations:
 | 
			
		||||
            migrations = ", ".join(self.valid_migrations)
 | 
			
		||||
            raise RuntimeError(f"No migration found for {model_class_name}."
 | 
			
		||||
                               f"Valid migrations are: {migrations}")
 | 
			
		||||
                                    MigrationAction.CREATE, stored_hash))
 | 
			
		||||
 | 
			
		||||
    def run(self):
 | 
			
		||||
        # TODO: Migration history
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,7 +2,7 @@ import abc
 | 
			
		|||
import dataclasses
 | 
			
		||||
import decimal
 | 
			
		||||
import operator
 | 
			
		||||
from copy import copy
 | 
			
		||||
from copy import copy, deepcopy
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from functools import reduce
 | 
			
		||||
| 
						 | 
				
			
			@ -60,9 +60,7 @@ class Operators(Enum):
 | 
			
		|||
    NOT = 9
 | 
			
		||||
    IN = 10
 | 
			
		||||
    NOT_IN = 11
 | 
			
		||||
    GTE = 12
 | 
			
		||||
    LTE = 13
 | 
			
		||||
    LIKE = 14
 | 
			
		||||
    LIKE = 12
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
| 
						 | 
				
			
			@ -129,7 +127,7 @@ class FindQuery:
 | 
			
		|||
 | 
			
		||||
        # TODO: GEO
 | 
			
		||||
        # TODO: TAG (other than PK)
 | 
			
		||||
        if any(isinstance(field_type, t) for t in NUMERIC_TYPES):
 | 
			
		||||
        if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
 | 
			
		||||
            return RediSearchFieldTypes.NUMERIC
 | 
			
		||||
        else:
 | 
			
		||||
            return RediSearchFieldTypes.TEXT
 | 
			
		||||
| 
						 | 
				
			
			@ -159,9 +157,9 @@ class FindQuery:
 | 
			
		|||
                result += f"@{field_name}:[({value} +inf]"
 | 
			
		||||
            elif op is Operators.LT:
 | 
			
		||||
                result += f"@{field_name}:[-inf ({value}]"
 | 
			
		||||
            elif op is Operators.GTE:
 | 
			
		||||
            elif op is Operators.GE:
 | 
			
		||||
                result += f"@{field_name}:[{value} +inf]"
 | 
			
		||||
            elif op is Operators.LTE:
 | 
			
		||||
            elif op is Operators.LE:
 | 
			
		||||
                result += f"@{field_name}:[-inf {value}]"
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
| 
						 | 
				
			
			@ -264,10 +262,6 @@ class ExpressionProxy:
 | 
			
		|||
    def __ge__(self, other: Any) -> Expression:
 | 
			
		||||
        return Expression(left=self.field, op=Operators.GE, right=other)
 | 
			
		||||
 | 
			
		||||
    def __invert__(self):
 | 
			
		||||
        import ipdb; ipdb.set_trace()
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __dataclass_transform__(
 | 
			
		||||
    *,
 | 
			
		||||
| 
						 | 
				
			
			@ -283,15 +277,13 @@ class FieldInfo(PydanticFieldInfo):
 | 
			
		|||
    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
 | 
			
		||||
        primary_key = kwargs.pop("primary_key", False)
 | 
			
		||||
        sortable = kwargs.pop("sortable", Undefined)
 | 
			
		||||
        foreign_key = kwargs.pop("foreign_key", Undefined)
 | 
			
		||||
        index = kwargs.pop("index", Undefined)
 | 
			
		||||
        unique = kwargs.pop("unique", Undefined)
 | 
			
		||||
        full_text_search = kwargs.pop("full_text_search", Undefined)
 | 
			
		||||
        super().__init__(default=default, **kwargs)
 | 
			
		||||
        self.primary_key = primary_key
 | 
			
		||||
        self.sortable = sortable
 | 
			
		||||
        self.foreign_key = foreign_key
 | 
			
		||||
        self.index = index
 | 
			
		||||
        self.unique = unique
 | 
			
		||||
        self.full_text_search = full_text_search
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RelationshipInfo(Representation):
 | 
			
		||||
| 
						 | 
				
			
			@ -331,10 +323,9 @@ def Field(
 | 
			
		|||
    allow_mutation: bool = True,
 | 
			
		||||
    regex: str = None,
 | 
			
		||||
    primary_key: bool = False,
 | 
			
		||||
    unique: bool = False,
 | 
			
		||||
    foreign_key: Optional[Any] = None,
 | 
			
		||||
    sortable: Union[bool, UndefinedType] = Undefined,
 | 
			
		||||
    index: Union[bool, UndefinedType] = Undefined,
 | 
			
		||||
    full_text_search: Union[bool, UndefinedType] = Undefined,
 | 
			
		||||
    schema_extra: Optional[Dict[str, Any]] = None,
 | 
			
		||||
) -> Any:
 | 
			
		||||
    current_schema_extra = schema_extra or {}
 | 
			
		||||
| 
						 | 
				
			
			@ -359,10 +350,9 @@ def Field(
 | 
			
		|||
        allow_mutation=allow_mutation,
 | 
			
		||||
        regex=regex,
 | 
			
		||||
        primary_key=primary_key,
 | 
			
		||||
        unique=unique,
 | 
			
		||||
        foreign_key=foreign_key,
 | 
			
		||||
        sortable=sortable,
 | 
			
		||||
        index=index,
 | 
			
		||||
        full_text_search=full_text_search,
 | 
			
		||||
        **current_schema_extra,
 | 
			
		||||
    )
 | 
			
		||||
    field_info._validate()
 | 
			
		||||
| 
						 | 
				
			
			@ -394,39 +384,39 @@ class ModelMeta(ModelMetaclass):
 | 
			
		|||
        meta = meta or getattr(new_class, 'Meta', None)
 | 
			
		||||
        base_meta = getattr(new_class, '_meta', None)
 | 
			
		||||
 | 
			
		||||
        if meta and meta is not DefaultMeta:
 | 
			
		||||
        if meta and meta != DefaultMeta and meta != base_meta:
 | 
			
		||||
            new_class.Meta = meta
 | 
			
		||||
            new_class._meta = meta
 | 
			
		||||
        elif base_meta:
 | 
			
		||||
            new_class._meta = copy(base_meta)
 | 
			
		||||
            new_class._meta = deepcopy(base_meta)
 | 
			
		||||
            new_class.Meta = new_class._meta
 | 
			
		||||
            # Unset inherited values we don't want to reuse (typically based on the model name).
 | 
			
		||||
            new_class._meta.abstract = False
 | 
			
		||||
            new_class._meta.model_key_prefix = None
 | 
			
		||||
            new_class._meta.index_name = None
 | 
			
		||||
        else:
 | 
			
		||||
            new_class._meta = copy(DefaultMeta)
 | 
			
		||||
            new_class._meta = deepcopy(DefaultMeta)
 | 
			
		||||
            new_class.Meta = new_class._meta
 | 
			
		||||
 | 
			
		||||
        # Not an abstract model class
 | 
			
		||||
        if abc.ABC not in bases:
 | 
			
		||||
            key = f"{new_class.__module__}.{new_class.__qualname__}"
 | 
			
		||||
            key = f"{new_class.__module__}.{new_class.__name__}"
 | 
			
		||||
            model_registry[key] = new_class
 | 
			
		||||
 | 
			
		||||
        # Create proxies for each model field so that we can use the field
 | 
			
		||||
        # in queries, like Model.get(Model.field_name == 1)
 | 
			
		||||
        for name, field in new_class.__fields__.items():
 | 
			
		||||
            setattr(new_class, name, ExpressionProxy(field))
 | 
			
		||||
        for field_name, field in new_class.__fields__.items():
 | 
			
		||||
            setattr(new_class, field_name, ExpressionProxy(field))
 | 
			
		||||
            # Check if this is our FieldInfo version with extended ORM metadata.
 | 
			
		||||
            if isinstance(field.field_info, FieldInfo):
 | 
			
		||||
                if field.field_info.primary_key:
 | 
			
		||||
                    new_class._meta.primary_key = PrimaryKey(name=name, field=field)
 | 
			
		||||
                    new_class._meta.primary_key = PrimaryKey(name=field_name, field=field)
 | 
			
		||||
 | 
			
		||||
        if not getattr(new_class._meta, 'global_key_prefix', None):
 | 
			
		||||
            new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "")
 | 
			
		||||
        if not getattr(new_class._meta, 'model_key_prefix', None):
 | 
			
		||||
            # Don't look at the base class for this.
 | 
			
		||||
            new_class._meta.model_key_prefix = f"{new_class.__name__.lower()}"
 | 
			
		||||
            new_class._meta.model_key_prefix = f"{new_class.__module__}.{new_class.__name__}"
 | 
			
		||||
        if not getattr(new_class._meta, 'primary_key_pattern', None):
 | 
			
		||||
            new_class._meta.primary_key_pattern = getattr(base_meta, "primary_key_pattern",
 | 
			
		||||
                                                          "{pk}")
 | 
			
		||||
| 
						 | 
				
			
			@ -457,6 +447,13 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
 | 
			
		|||
        super().__init__(**data)
 | 
			
		||||
        __pydantic_self__.validate_primary_key()
 | 
			
		||||
 | 
			
		||||
    def __lt__(self, other):
 | 
			
		||||
        my_keys = set(self.__fields__.keys())
 | 
			
		||||
        other_keys = set(other.__fields__.keys())
 | 
			
		||||
        shared_keys = list(my_keys & other_keys)
 | 
			
		||||
        lt = [getattr(self, k) < getattr(other, k) for k in shared_keys]
 | 
			
		||||
        return len(lt) > len(shared_keys) / 2
 | 
			
		||||
 | 
			
		||||
    @validator("pk", always=True)
 | 
			
		||||
    def validate_pk(cls, v):
 | 
			
		||||
        if not v:
 | 
			
		||||
| 
						 | 
				
			
			@ -607,11 +604,16 @@ class HashModel(RedisModel, abc.ABC):
 | 
			
		|||
        return val
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def schema_for_type(cls, name, typ: Type):
 | 
			
		||||
    def schema_for_type(cls, name, typ: Type, field_info: FieldInfo):
 | 
			
		||||
        if any(issubclass(typ, t) for t in NUMERIC_TYPES):
 | 
			
		||||
            return f"{name} NUMERIC"
 | 
			
		||||
        elif issubclass(typ, str):
 | 
			
		||||
            if getattr(field_info, 'full_text_search', False) is True:
 | 
			
		||||
                return f"{name} TAG {name}_fts TEXT"
 | 
			
		||||
            else:
 | 
			
		||||
                return f"{name} TAG"
 | 
			
		||||
        else:
 | 
			
		||||
            return f"{name} TEXT"
 | 
			
		||||
            return f"{name} TAG"
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def schema(cls):
 | 
			
		||||
| 
						 | 
				
			
			@ -624,12 +626,12 @@ class HashModel(RedisModel, abc.ABC):
 | 
			
		|||
                if issubclass(_type, str):
 | 
			
		||||
                    redisearch_field = f"{name} TAG"
 | 
			
		||||
                else:
 | 
			
		||||
                    redisearch_field = cls.schema_for_type(name, _type)
 | 
			
		||||
                    redisearch_field = cls.schema_for_type(name, _type, field.field_info)
 | 
			
		||||
                schema_parts.append(redisearch_field)
 | 
			
		||||
            else:
 | 
			
		||||
                schema_parts.append(cls.schema_for_type(name, _type))
 | 
			
		||||
            if getattr(field.field_info, 'sortable', False):
 | 
			
		||||
                schema_parts.append("SORTABLE")
 | 
			
		||||
            elif getattr(field.field_info, 'index', None) is True:
 | 
			
		||||
                schema_parts.append(cls.schema_for_type(name, _type, field.field_info))
 | 
			
		||||
                if getattr(field.field_info, 'sortable', False) is True:
 | 
			
		||||
                    schema_parts.append("SORTABLE")
 | 
			
		||||
        return " ".join(schema_parts)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue