Generate a seemingly correct schema for JSON models

This commit is contained in:
Andrew Brookins 2021-10-04 13:55:33 -07:00
parent 9a47d1046d
commit 976fdf3ad4
3 changed files with 384 additions and 121 deletions

View file

@ -1,6 +1,7 @@
import abc import abc
import dataclasses import dataclasses
import decimal import decimal
import logging
import operator import operator
import re import re
from copy import deepcopy from copy import deepcopy
@ -22,7 +23,7 @@ from typing import (
Protocol, Protocol,
List, List,
Type, Type,
Pattern Pattern, get_origin, get_args
) )
import uuid import uuid
@ -37,8 +38,8 @@ from pydantic.utils import Representation
from .encoders import jsonable_encoder from .encoders import jsonable_encoder
model_registry = {} model_registry = {}
_T = TypeVar("_T") _T = TypeVar("_T")
log = logging.getLogger(__name__)
class TokenEscaper: class TokenEscaper:
@ -338,7 +339,6 @@ class FindQuery:
field_type = cls.resolve_field_type(expression.left) field_type = cls.resolve_field_type(expression.left)
field_name = expression.left.name field_name = expression.left.name
else: else:
import ipdb; ipdb.set_trace()
raise QueryNotSupportedError(f"A query expression should start with either a field " raise QueryNotSupportedError(f"A query expression should start with either a field "
f"or an expression enclosed in parenthesis. See docs: " f"or an expression enclosed in parenthesis. See docs: "
f"TODO") f"TODO")
@ -831,24 +831,19 @@ class HashModel(RedisModel, abc.ABC):
return "" return ""
return val return val
@classmethod
def schema_for_type(cls, name, typ: Type, field_info: FieldInfo):
if any(issubclass(typ, t) for t in NUMERIC_TYPES):
return f"{name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, 'full_text_search', False) is True:
return f"{name} TAG {name}_fts TEXT"
else:
return f"{name} TAG"
else:
return f"{name} TAG"
@classmethod @classmethod
def redisearch_schema(cls): def redisearch_schema(cls):
hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk="")) hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA" schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA"
schema_parts = [schema_prefix] schema_parts = [schema_prefix] + cls.schema_for_fields()
return " ".join(schema_parts)
@classmethod
def schema_for_fields(cls):
schema_parts = []
for name, field in cls.__fields__.items(): for name, field in cls.__fields__.items():
# TODO: Merge this code with schema_for_type()
_type = field.outer_type_ _type = field.outer_type_
if getattr(field.field_info, 'primary_key', None): if getattr(field.field_info, 'primary_key', None):
if issubclass(_type, str): if issubclass(_type, str):
@ -858,9 +853,49 @@ class HashModel(RedisModel, abc.ABC):
schema_parts.append(redisearch_field) schema_parts.append(redisearch_field)
elif getattr(field.field_info, 'index', None) is True: elif getattr(field.field_info, 'index', None) is True:
schema_parts.append(cls.schema_for_type(name, _type, field.field_info)) schema_parts.append(cls.schema_for_type(name, _type, field.field_info))
# TODO: Raise error if user embeds a model field or list and makes it
# sortable. Instead, the embedded model should mark individual fields
# as sortable.
if getattr(field.field_info, 'sortable', False) is True: if getattr(field.field_info, 'sortable', False) is True:
schema_parts.append("SORTABLE") schema_parts.append("SORTABLE")
return " ".join(schema_parts) elif get_origin(_type) == list:
embedded_cls = get_args(_type)
if not embedded_cls:
# TODO: Test if this can really happen.
log.warning("Model %s defined an empty list field: %s", cls, name)
continue
embedded_cls = embedded_cls[0]
schema_parts.append(cls.schema_for_type(name, embedded_cls,
field.field_info))
elif issubclass(_type, RedisModel):
schema_parts.append(cls.schema_for_type(name, _type, field.field_info))
return schema_parts
@classmethod
def schema_for_type(cls, name, typ: Type, field_info: FieldInfo):
if get_origin(typ) == list:
embedded_cls = get_args(typ)
if not embedded_cls:
# TODO: Test if this can really happen.
log.warning("Model %s defined an empty list field: %s", cls, name)
return ""
embedded_cls = embedded_cls[0]
return cls.schema_for_type(name, embedded_cls, field_info)
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
return f"{name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, 'full_text_search', False) is True:
return f"{name} TAG {name}_fts TEXT"
else:
return f"{name} TAG"
elif issubclass(typ, RedisModel):
sub_fields = []
for embedded_name, field in typ.__fields__.items():
sub_fields.append(cls.schema_for_type(f"{name}_{embedded_name}", field.outer_type_,
field.field_info))
return " ".join(sub_fields)
else:
return f"{name} TAG"
class JsonModel(RedisModel, abc.ABC): class JsonModel(RedisModel, abc.ABC):
@ -874,3 +909,88 @@ class JsonModel(RedisModel, abc.ABC):
if not document: if not document:
raise NotFoundError raise NotFoundError
return cls.parse_raw(document) return cls.parse_raw(document)
@classmethod
def redisearch_schema(cls):
key_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk=""))
schema_prefix = f"ON JSON PREFIX 1 {key_prefix} SCHEMA"
schema_parts = [schema_prefix] + cls.schema_for_fields()
return " ".join(schema_parts)
@classmethod
def schema_for_fields(cls):
schema_parts = []
json_path = "$"
for name, field in cls.__fields__.items():
# TODO: Merge this code with schema_for_type()?
_type = field.outer_type_
if getattr(field.field_info, 'primary_key', None):
if issubclass(_type, str):
redisearch_field = f"{json_path}.{name} AS {name} TAG"
else:
redisearch_field = cls.schema_for_type(f"{json_path}.{name}", name, "", _type, field.field_info)
schema_parts.append(redisearch_field)
elif getattr(field.field_info, 'index', None) is True:
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, "", _type, field.field_info))
# TODO: Raise error if user embeds a model field or list and makes it
# sortable. Instead, the embedded model should mark individual fields
# as sortable.
if getattr(field.field_info, 'sortable', False) is True:
schema_parts.append("SORTABLE")
elif get_origin(_type) == list:
embedded_cls = get_args(_type)
if not embedded_cls:
# TODO: Test if this can really happen.
log.warning("Model %s defined an empty list field: %s", cls, name)
continue
embedded_cls = embedded_cls[0]
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}[]", name, f"{name}",
embedded_cls, field.field_info))
elif issubclass(_type, RedisModel):
schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, f"{name}", _type,
field.field_info))
return schema_parts
@classmethod
# TODO: We need both the "name" of the field (address_line_1) as we'll
# find it in the JSON document, AND the name of the field as it should
# be in the redisearch schema (address_address_line_1). Maybe both "name"
# and "name_prefix"?
def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Type,
field_info: FieldInfo) -> str:
index_field_name = f"{name_prefix}{name}"
should_index = getattr(field_info, 'index', False)
if get_origin(typ) == list:
embedded_cls = get_args(typ)
if not embedded_cls:
# TODO: Test if this can really happen.
log.warning("Model %s defined an empty list field: %s", cls, name)
return ""
embedded_cls = embedded_cls[0]
# TODO: We need to pass the "JSON Path so far" which should include the
# correct syntax for an array.
return cls.schema_for_type(f"{json_path}[]", name, f"{name_prefix}{name}", embedded_cls, field_info)
elif issubclass(typ, RedisModel):
sub_fields = []
for embedded_name, field in typ.__fields__.items():
sub_fields.append(cls.schema_for_type(f"{json_path}.{embedded_name}",
embedded_name,
f"{name_prefix}_",
field.outer_type_,
field.field_info))
return " ".join(filter(None, sub_fields))
elif should_index:
if any(issubclass(typ, t) for t in NUMERIC_TYPES):
return f"{json_path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, 'full_text_search', False) is True:
return f"{json_path} AS {index_field_name} TAG " \
f"{json_path} AS {index_field_name}_fts TEXT"
else:
return f"{json_path} AS {index_field_name} TAG"
else:
return f"{json_path} AS {index_field_name} TAG"
return ""

View file

@ -74,7 +74,7 @@ def members():
def test_validates_required_fields(): def test_validates_required_fields():
# Raises ValidationError: last_name, address are required # Raises ValidationError: last_name is required
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
Member( Member(
first_name="Andrew", first_name="Andrew",

View file

@ -2,6 +2,7 @@ import abc
import decimal import decimal
import datetime import datetime
from typing import Optional, List from typing import Optional, List
from unittest import mock
import pytest import pytest
import redis import redis
@ -11,9 +12,10 @@ from redis_developer.orm import (
JsonModel, JsonModel,
Field, Field,
) )
from redis_developer.orm.model import RedisModelError, QueryNotSupportedError, NotFoundError
r = redis.Redis() r = redis.Redis()
today = datetime.datetime.today() today = datetime.date.today()
class BaseJsonModel(JsonModel, abc.ABC): class BaseJsonModel(JsonModel, abc.ABC):
@ -25,13 +27,14 @@ class Address(BaseJsonModel):
address_line_1: str address_line_1: str
address_line_2: Optional[str] address_line_2: Optional[str]
city: str city: str
state: str
country: str country: str
postal_code: str postal_code: str = Field(index=True)
class Item(BaseJsonModel): class Item(BaseJsonModel):
price: decimal.Decimal price: decimal.Decimal
name: str name: str = Field(index=True, full_text_search=True)
class Order(BaseJsonModel): class Order(BaseJsonModel):
@ -45,6 +48,7 @@ class Member(BaseJsonModel):
last_name: str last_name: str
email: str = Field(index=True) email: str = Field(index=True)
join_date: datetime.date join_date: datetime.date
age: int
# Creates an embedded model. # Creates an embedded model.
address: Address address: Address
@ -52,169 +56,308 @@ class Member(BaseJsonModel):
# Creates an embedded list of models. # Creates an embedded list of models.
orders: Optional[List[Order]] orders: Optional[List[Order]]
class Meta(BaseJsonModel.Meta):
model_key_prefix = "member" # This is the default
@pytest.fixture()
address = Address( def address():
yield Address(
address_line_1="1 Main St.", address_line_1="1 Main St.",
city="Happy Town", city="Portland",
state="WY", state="OR",
postal_code=11111, country="USA",
country="USA" postal_code=11111
)
def test_validates_required_fields():
# Raises ValidationError: last_name, address are required
with pytest.raises(ValidationError):
Member(
first_name="Andrew",
zipcode="97086",
join_date=today
) )
def test_validates_field(): @pytest.fixture()
def members(address):
member1 = Member(
first_name="Andrew",
last_name="Brookins",
email="a@example.com",
age=38,
join_date=today,
address=address
)
member2 = Member(
first_name="Kim",
last_name="Brookins",
email="k@example.com",
age=34,
join_date=today,
address=address
)
member3 = Member(
first_name="Andrew",
last_name="Smith",
email="as@example.com",
age=100,
join_date=today,
address=address
)
member1.save()
member2.save()
member3.save()
yield member1, member2, member3
def test_validates_required_fields(address):
# Raises ValidationError address is required
with pytest.raises(ValidationError):
Member(
first_name="Andrew",
last_name="Brookins",
zipcode="97086",
join_date=today,
)
def test_validates_field(address):
# Raises ValidationError: join_date is not a date # Raises ValidationError: join_date is not a date
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
Member( Member(
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
join_date="yesterday" join_date="yesterday",
address=address
) )
# Passes validation # Passes validation
def test_validation_passes(): def test_validation_passes(address):
member = Member( member = Member(
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
email="a@example.com", email="a@example.com",
address=address, join_date=today,
join_date=today age=38,
address=address
) )
assert member.first_name == "Andrew" assert member.first_name == "Andrew"
def test_gets_pk(): def test_saves_model_and_creates_pk(address):
new_address = Address(
address_line_1="1 Main St.",
city="Happy Town",
state="WY",
postal_code=11111,
country="USA"
)
assert new_address.pk is not None
def test_saves_model():
# Save a model instance to Redis
address.save()
address2 = Address.get(address.pk)
assert address2 == address
def test_saves_with_embedded_models():
member = Member( member = Member(
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
email="a@example.com", email="a@example.com",
address=address, join_date=today,
join_date=datetime.date.today() age=38,
address=address
) )
# Save a model instance to Redis
member.save() member.save()
member2 = Member.get(member.pk) member2 = Member.get(member.pk)
assert member2 == member
assert member2.address == address assert member2.address == address
def test_saves_with_deeply_embedded_models():
hat = Item(
name="Cool hat",
price=2.99
)
shoe = Item(
name="Expensive shoe",
price=299.99
)
order = Order(
total=302.98,
items=[hat, shoe],
created_on=today,
)
member = Member(
first_name="Andrew",
last_name="Brookins",
email="a@example.com",
address=address,
orders=[order],
join_date=today
)
member.save()
member2 = Member.get(member.pk)
assert member2.orders[0] == order
assert member2.orders[0].items[0] == hat
# Save many model instances to Redis
@pytest.mark.skip("Not implemented yet") @pytest.mark.skip("Not implemented yet")
def test_saves_many(): def test_saves_many(address):
members = [ members = [
Member( Member(
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
email="a@example.com", email="a@example.com",
join_date=today,
address=address, address=address,
join_date=today age=38
), ),
Member( Member(
first_name="Kim", first_name="Kim",
last_name="Brookins", last_name="Brookins",
email="k@example.com", email="k@example.com",
join_date=today,
address=address, address=address,
join_date=today age=34
) )
] ]
Member.add(members) Member.add(members)
@pytest.mark.skip("No implemented yet") @pytest.mark.skip("Not ready yet")
def test_updates_a_model(): def test_updates_a_model(members):
member = Member( member1, member2, member3 = members
first_name="Andrew",
last_name="Brookins",
email="a@example.com",
address=address,
join_date=today
)
# Update a model instance in Redis
member.first_name = "Brian"
member.last_name = "Sam-Bodden"
member.save()
# Or, with an implicit save: # Or, with an implicit save:
member.update(first_name="Brian", last_name="Sam-Bodden") member1.update(last_name="Smith")
assert Member.find(Member.pk == member1.pk).first() == member1
# Or, affecting multiple model instances with an implicit save: # Or, affecting multiple model instances with an implicit save:
Member.filter(Member.last_name == "Brookins").update(last_name="Sam-Bodden") Member.find(Member.last_name == "Brookins").update(last_name="Smith")
results = Member.find(Member.last_name == "Smith")
assert sorted(results) == members
# Or, updating a field in an embedded model:
member2.update(address__city="Happy Valley")
assert Member.find(Member.pk == member2.pk).first().address.city == "Happy Valley"
@pytest.mark.skip("Not implemented yet") def test_paginate_query(members):
def test_exact_match_queries(): member1, member2, member3 = members
# TODO: Should get() support expressions? I.e., ... actual = Member.find().all(batch_size=1)
# What if the field wasn't unique and there were two "a@example.com" assert sorted(actual) == [member1, member2, member3]
# entries? This would raise a MultipleObjectsReturned error:
member = Member.get(Member.email == "a.m.brookins@gmail.com")
# What if you know there might be multiple results? Use filter():
members = Member.filter(Member.last_name == "Brookins")
# What if you want to only return values that don't match a query? def test_access_result_by_index_cached(members):
members = Member.exclude(Member.last_name == "Brookins") member1, member2, member3 = members
query = Member.find().sort_by('age')
# Load the cache, throw away the result.
assert query._model_cache == []
query.execute()
assert query._model_cache == [member2, member1, member3]
# You can combine filer() and exclude(): # Access an item that should be in the cache.
members = Member.filter(Member.last_name == "Brookins").exclude( with mock.patch.object(query.model, 'db') as mock_db:
Member.first_name == "Andrew") assert query[0] == member2
assert not mock_db.called
def test_access_result_by_index_not_cached(members):
member1, member2, member3 = members
query = Member.find().sort_by('age')
# Assert that we don't have any models in the cache yet -- we
# haven't made any requests of Redis.
assert query._model_cache == []
assert query[0] == member2
assert query[1] == member1
assert query[2] == member3
def test_exact_match_queries(members):
member1, member2, member3 = members
actual = Member.find(Member.last_name == "Brookins").all()
assert sorted(actual) == [member1, member2]
actual = Member.find(
(Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all()
assert actual == [member2]
actual = Member.find(~(Member.last_name == "Brookins")).all()
assert actual == [member3]
actual = Member.find(Member.last_name != "Brookins").all()
assert actual == [member3]
actual = Member.find(
(Member.last_name == "Brookins") & (Member.first_name == "Andrew")
| (Member.first_name == "Kim")
).all()
assert actual == [member2, member1]
actual = Member.find(Member.first_name == "Kim", Member.last_name == "Brookins").all()
assert actual == [member2]
actual = Member.find(Member.address.city == "Portland").all()
assert actual == [member1, member2, member3]
def test_recursive_query_resolution(members):
member1, member2, member3 = members
actual = Member.find((Member.last_name == "Brookins") | (
Member.age == 100
) & (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member2, member3]
def test_tag_queries_boolean_logic(members):
member1, member2, member3 = members
actual = Member.find(
(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all()
assert sorted(actual) == [member1, member3]
def test_tag_queries_punctuation():
member = Member(
first_name="Andrew, the Michael", # This string uses the TAG field separator.
last_name="St. Brookins-on-Pier",
email="a@example.com",
age=38,
join_date=today
)
member.save()
assert Member.find(Member.first_name == "Andrew the Michael").first() == member
assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member
assert Member.find(Member.email == "a@example.com").first() == member
def test_tag_queries_negation(members):
member1, member2, member3 = members
actual = Member.find(
~(Member.first_name == "Andrew") &
(Member.last_name == "Brookins") | (Member.last_name == "Smith")).all()
assert sorted(actual) == [member2, member3]
actual = Member.find(
(Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")).all()
assert sorted(actual) == [member3]
def test_numeric_queries(members):
member1, member2, member3 = members
actual = Member.find(Member.age == 34).all()
assert actual == [member2]
actual = Member.find(Member.age > 34).all()
assert sorted(actual) == [member1, member3]
actual = Member.find(Member.age < 35).all()
assert actual == [member2]
actual = Member.find(Member.age <= 34).all()
assert actual == [member2]
actual = Member.find(Member.age >= 100).all()
assert actual == [member3]
actual = Member.find(~(Member.age == 100)).all()
assert sorted(actual) == [member1, member2]
def test_sorting(members):
member1, member2, member3 = members
actual = Member.find(Member.age > 34).sort_by('age').all()
assert sorted(actual) == [member3, member1]
actual = Member.find(Member.age > 34).sort_by('-age').all()
assert sorted(actual) == [member1, member3]
with pytest.raises(QueryNotSupportedError):
# This field does not exist.
Member.find().sort_by('not-a-real-field').all()
with pytest.raises(QueryNotSupportedError):
# This field is not sortable.
Member.find().sort_by('join_date').all()
def test_not_found():
with pytest.raises(NotFoundError):
# This ID does not exist.
Member.get(1000)
def test_schema():
assert Member.redisearch_schema() == "ON JSON PREFIX 1 " \
"redis-developer:tests.test_json_model.Member: " \
"SCHEMA $.pk AS pk TAG " \
"$.email AS email TAG " \
"$.address.pk AS address_pk TAG " \
"$.address.postal_code AS address_postal_code TAG " \
"$.orders[].pk AS orders_pk TAG " \
"$.orders[].items[].pk AS orders_items_pk TAG " \
"$.orders[].items[].name AS orders_items_name TAG " \
"$.orders[].items[].name AS orders_items_name_fts TEXT"