Remove unique field option and unused migration validation

Andrew Brookins 2021-09-17 09:27:11 -07:00
parent 85ba111260
commit a788cbedbb
4 changed files with 116 additions and 86 deletions


@@ -31,8 +31,7 @@ def schema_hash_key(index_name):
def create_index(index_name, schema, current_hash):
redis.execute_command(f"ft.create {index_name} "
f"{schema}")
redis.execute_command(f"ft.create {index_name} {schema}")
redis.set(schema_hash_key(index_name), current_hash)
@@ -86,6 +85,7 @@ class Migrator:
self.migrations.append(
IndexMigration(name, cls.Meta.index_name, schema, current_hash,
MigrationAction.CREATE))
continue
stored_hash = redis.get(hash_key)
schema_out_of_date = current_hash != stored_hash
@@ -99,16 +99,6 @@
IndexMigration(name, cls.Meta.index_name, schema, current_hash,
MigrationAction.CREATE, stored_hash))
@property
def valid_migrations(self):
return self.missing_indexes.keys() + self.out_of_date_indexes.keys()
def validate_migration(self, model_class_name):
if model_class_name not in self.valid_migrations:
migrations = ", ".join(self.valid_migrations)
raise RuntimeError(f"No migration found for {model_class_name}."
f"Valid migrations are: {migrations}")
def run(self):
# TODO: Migration history
# TODO: Dry run with output
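The surviving migrator logic hinges on comparing a hash of the generated schema against the hash stored at `schema_hash_key(index_name)`. A minimal sketch of that check (the hash function and the key format are not shown in this diff, so both are assumptions, as is a Redis client created with `decode_responses=True`):

```python
import hashlib

import redis

client = redis.Redis(decode_responses=True)  # assumed connection setup


def schema_hash_key(index_name: str) -> str:
    # Assumed key format; the diff only shows this helper being called.
    return f"{index_name}:hash"


def index_needs_migration(index_name: str, schema: str) -> bool:
    current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest()  # assumed hash
    stored_hash = client.get(schema_hash_key(index_name))
    # No stored hash -> the index was never created; a mismatch -> the schema changed.
    return stored_hash is None or stored_hash != current_hash
```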


@@ -2,7 +2,7 @@ import abc
import dataclasses
import decimal
import operator
from copy import copy
from copy import copy, deepcopy
from dataclasses import dataclass
from enum import Enum
from functools import reduce
@@ -60,9 +60,7 @@ class Operators(Enum):
NOT = 9
IN = 10
NOT_IN = 11
GTE = 12
LTE = 13
LIKE = 14
LIKE = 12
@dataclass
@@ -129,7 +127,7 @@ class FindQuery:
# TODO: GEO
# TODO: TAG (other than PK)
if any(isinstance(field_type, t) for t in NUMERIC_TYPES):
if any(issubclass(field_type, t) for t in NUMERIC_TYPES):
return RediSearchFieldTypes.NUMERIC
else:
return RediSearchFieldTypes.TEXT
@@ -159,9 +157,9 @@ class FindQuery:
result += f"@{field_name}:[({value} +inf]"
elif op is Operators.LT:
result += f"@{field_name}:[-inf ({value}]"
elif op is Operators.GTE:
elif op is Operators.GE:
result += f"@{field_name}:[{value} +inf]"
elif op is Operators.LTE:
elif op is Operators.LE:
result += f"@{field_name}:[-inf {value}]"
return result
@@ -264,10 +262,6 @@ class ExpressionProxy:
def __ge__(self, other: Any) -> Expression:
return Expression(left=self.field, op=Operators.GE, right=other)
def __invert__(self):
import ipdb; ipdb.set_trace()
pass
def __dataclass_transform__(
*,
@@ -283,15 +277,13 @@ class FieldInfo(PydanticFieldInfo):
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False)
sortable = kwargs.pop("sortable", Undefined)
foreign_key = kwargs.pop("foreign_key", Undefined)
index = kwargs.pop("index", Undefined)
unique = kwargs.pop("unique", Undefined)
full_text_search = kwargs.pop("full_text_search", Undefined)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.sortable = sortable
self.foreign_key = foreign_key
self.index = index
self.unique = unique
self.full_text_search = full_text_search
class RelationshipInfo(Representation):
@@ -331,10 +323,9 @@ def Field(
allow_mutation: bool = True,
regex: str = None,
primary_key: bool = False,
unique: bool = False,
foreign_key: Optional[Any] = None,
sortable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
full_text_search: Union[bool, UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
@@ -359,10 +350,9 @@ def Field(
allow_mutation=allow_mutation,
regex=regex,
primary_key=primary_key,
unique=unique,
foreign_key=foreign_key,
sortable=sortable,
index=index,
full_text_search=full_text_search,
**current_schema_extra,
)
field_info._validate()
@@ -394,39 +384,39 @@ class ModelMeta(ModelMetaclass):
meta = meta or getattr(new_class, 'Meta', None)
base_meta = getattr(new_class, '_meta', None)
if meta and meta is not DefaultMeta:
if meta and meta != DefaultMeta and meta != base_meta:
new_class.Meta = meta
new_class._meta = meta
elif base_meta:
new_class._meta = copy(base_meta)
new_class._meta = deepcopy(base_meta)
new_class.Meta = new_class._meta
# Unset inherited values we don't want to reuse (typically based on the model name).
new_class._meta.abstract = False
new_class._meta.model_key_prefix = None
new_class._meta.index_name = None
else:
new_class._meta = copy(DefaultMeta)
new_class._meta = deepcopy(DefaultMeta)
new_class.Meta = new_class._meta
# Not an abstract model class
if abc.ABC not in bases:
key = f"{new_class.__module__}.{new_class.__qualname__}"
key = f"{new_class.__module__}.{new_class.__name__}"
model_registry[key] = new_class
# Create proxies for each model field so that we can use the field
# in queries, like Model.get(Model.field_name == 1)
for name, field in new_class.__fields__.items():
setattr(new_class, name, ExpressionProxy(field))
for field_name, field in new_class.__fields__.items():
setattr(new_class, field_name, ExpressionProxy(field))
# Check if this is our FieldInfo version with extended ORM metadata.
if isinstance(field.field_info, FieldInfo):
if field.field_info.primary_key:
new_class._meta.primary_key = PrimaryKey(name=name, field=field)
new_class._meta.primary_key = PrimaryKey(name=field_name, field=field)
if not getattr(new_class._meta, 'global_key_prefix', None):
new_class._meta.global_key_prefix = getattr(base_meta, "global_key_prefix", "")
if not getattr(new_class._meta, 'model_key_prefix', None):
# Don't look at the base class for this.
new_class._meta.model_key_prefix = f"{new_class.__name__.lower()}"
new_class._meta.model_key_prefix = f"{new_class.__module__}.{new_class.__name__}"
if not getattr(new_class._meta, 'primary_key_pattern', None):
new_class._meta.primary_key_pattern = getattr(base_meta, "primary_key_pattern",
"{pk}")
@@ -457,6 +447,13 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
super().__init__(**data)
__pydantic_self__.validate_primary_key()
def __lt__(self, other):
my_keys = set(self.__fields__.keys())
other_keys = set(other.__fields__.keys())
shared_keys = list(my_keys & other_keys)
lt = [getattr(self, k) < getattr(other, k) for k in shared_keys]
return len(lt) > len(shared_keys) / 2
@validator("pk", always=True)
def validate_pk(cls, v):
if not v:
@@ -607,11 +604,16 @@ class HashModel(RedisModel, abc.ABC):
return val
@classmethod
def schema_for_type(cls, name, typ: Type):
def schema_for_type(cls, name, typ: Type, field_info: FieldInfo):
if any(issubclass(typ, t) for t in NUMERIC_TYPES):
return f"{name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, 'full_text_search', False) is True:
return f"{name} TAG {name}_fts TEXT"
else:
return f"{name} TEXT"
return f"{name} TAG"
else:
return f"{name} TAG"
@classmethod
def schema(cls):
@@ -624,11 +626,11 @@ class HashModel(RedisModel, abc.ABC):
if issubclass(_type, str):
redisearch_field = f"{name} TAG"
else:
redisearch_field = cls.schema_for_type(name, _type)
redisearch_field = cls.schema_for_type(name, _type, field.field_info)
schema_parts.append(redisearch_field)
else:
schema_parts.append(cls.schema_for_type(name, _type))
if getattr(field.field_info, 'sortable', False):
elif getattr(field.field_info, 'index', None) is True:
schema_parts.append(cls.schema_for_type(name, _type, field.field_info))
if getattr(field.field_info, 'sortable', False) is True:
schema_parts.append("SORTABLE")
return " ".join(schema_parts)


@@ -31,14 +31,47 @@ class Order(BaseHashModel):
class Member(BaseHashModel):
first_name: str
last_name: str
email: str = Field(unique=True, index=True)
email: str = Field(index=True)
join_date: datetime.date
age: int
class Meta:
model_key_prefix = "member"
primary_key_pattern = ""
@pytest.fixture()
def members():
member1 = Member(
first_name="Andrew",
last_name="Brookins",
email="a@example.com",
age=38,
join_date=today
)
member2 = Member(
first_name="Kim",
last_name="Brookins",
email="k@example.com",
age=34,
join_date=today
)
member3 = Member(
first_name="Andrew",
last_name="Smith",
email="as@example.com",
age=100,
join_date=today
)
member1.save()
member2.save()
member3.save()
yield member1, member2, member3
def test_validates_required_fields():
# Raises ValidationError: last_name, address are required
with pytest.raises(ValidationError):
@@ -65,7 +98,8 @@ def test_validation_passes():
first_name="Andrew",
last_name="Brookins",
email="a@example.com",
join_date=today
join_date=today,
age=38
)
assert member.first_name == "Andrew"
@@ -75,7 +109,8 @@ def test_saves_model_and_creates_pk():
first_name="Andrew",
last_name="Brookins",
email="a@example.com",
join_date=today
join_date=today,
age=38
)
# Save a model instance to Redis
member.save()
@@ -137,38 +172,14 @@ def test_updates_a_model():
Member.find(Member.last_name == "Brookins").update(last_name="Smith")
def test_exact_match_queries():
member1 = Member(
first_name="Andrew",
last_name="Brookins",
email="a@example.com",
join_date=today
)
def test_exact_match_queries(members):
member1, member2, member3 = members
member2 = Member(
first_name="Kim",
last_name="Brookins",
email="k@example.com",
join_date=today
)
member3 = Member(
first_name="Andrew",
last_name="Smith",
email="as@example.com",
join_date=today
)
member1.save()
member2.save()
member3.save()
# # TODO: How to help IDEs know that last_name is not a str, but a wrapped expression?
actual = Member.find(Member.last_name == "Brookins")
assert actual == [member2, member1]
assert actual == sorted([member1, member2])
actual = Member.find(
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew"))
assert actual == [member2]
actual = Member.find(~(Member.last_name == "Brookins"))
@@ -187,12 +198,39 @@ def test_exact_match_queries():
assert actual == member2
def test_numeric_queries(members):
member1, member2, member3 = members
actual = Member.find_one(Member.age == 34)
assert actual == member2
actual = Member.find(Member.age > 34)
assert sorted(actual) == [member1, member3]
actual = Member.find(Member.age < 35)
assert actual == [member2]
actual = Member.find(Member.age <= 34)
assert actual == [member2]
actual = Member.find(Member.age >= 100)
assert actual == [member3]
actual = Member.find(~(Member.age == 100))
assert sorted(actual) == [member1, member2]
def test_schema():
class Address(BaseHashModel):
a_string: str
an_integer: int
a_float: float
a_string: str = Field(index=True)
a_full_text_string: str = Field(index=True, full_text_search=True)
an_integer: int = Field(index=True, sortable=True)
a_float: float = Field(index=True)
another_integer: int
another_float: float
# TODO: Fix
assert Address.schema() == "ON HASH PREFIX 1 redis-developer:basehashmodel: SCHEMA pk TAG SORTABLE a_string TEXT an_integer NUMERIC " \
assert Address.schema() == "ON HASH PREFIX 1 redis-developer:address: " \
"SCHEMA pk TAG a_string TAG a_full_text_string TAG " \
"a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE " \
"a_float NUMERIC"


@@ -43,7 +43,7 @@ class Order(BaseJsonModel):
class Member(BaseJsonModel):
first_name: str
last_name: str
email: str = Field(unique=True, index=True)
email: str = Field(index=True)
join_date: datetime.date
# Creates an embedded model.
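
Taken together, the tests above exercise the field options and query operators touched by this commit. A rough usage sketch follows; it is not code from the commit, and the import path and the `Product` model are assumptions for illustration:

```python
import datetime

from redis_developer.orm import HashModel, Field  # import path assumed


class Product(HashModel):  # hypothetical model, for illustration only
    # index=True is now required for a field to appear in the RediSearch schema.
    name: str = Field(index=True, full_text_search=True)  # -> "name TAG name_fts TEXT"
    sku: str = Field(index=True)                          # -> "sku TAG"
    price: float = Field(index=True, sortable=True)       # -> "price NUMERIC SORTABLE"
    released: datetime.date                               # not indexed


# Numeric comparisons translate to RediSearch ranges, e.g. price > 10 -> "@price:[(10 +inf]".
expensive = Product.find(Product.price > 10)
cheapest = Product.find_one(Product.price <= 10)
```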