diff --git a/redis_developer/orm/model.py b/redis_developer/orm/model.py index 5f87994..bb54327 100644 --- a/redis_developer/orm/model.py +++ b/redis_developer/orm/model.py @@ -1,6 +1,7 @@ import abc import dataclasses import decimal +import logging import operator import re from copy import deepcopy @@ -22,7 +23,7 @@ from typing import ( Protocol, List, Type, - Pattern + Pattern, get_origin, get_args ) import uuid @@ -37,8 +38,8 @@ from pydantic.utils import Representation from .encoders import jsonable_encoder model_registry = {} - _T = TypeVar("_T") +log = logging.getLogger(__name__) class TokenEscaper: @@ -338,7 +339,6 @@ class FindQuery: field_type = cls.resolve_field_type(expression.left) field_name = expression.left.name else: - import ipdb; ipdb.set_trace() raise QueryNotSupportedError(f"A query expression should start with either a field " f"or an expression enclosed in parenthesis. See docs: " f"TODO") @@ -831,24 +831,19 @@ class HashModel(RedisModel, abc.ABC): return "" return val - @classmethod - def schema_for_type(cls, name, typ: Type, field_info: FieldInfo): - if any(issubclass(typ, t) for t in NUMERIC_TYPES): - return f"{name} NUMERIC" - elif issubclass(typ, str): - if getattr(field_info, 'full_text_search', False) is True: - return f"{name} TAG {name}_fts TEXT" - else: - return f"{name} TAG" - else: - return f"{name} TAG" - @classmethod def redisearch_schema(cls): hash_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk="")) schema_prefix = f"ON HASH PREFIX 1 {hash_prefix} SCHEMA" - schema_parts = [schema_prefix] + schema_parts = [schema_prefix] + cls.schema_for_fields() + return " ".join(schema_parts) + + @classmethod + def schema_for_fields(cls): + schema_parts = [] + for name, field in cls.__fields__.items(): + # TODO: Merge this code with schema_for_type() _type = field.outer_type_ if getattr(field.field_info, 'primary_key', None): if issubclass(_type, str): @@ -858,9 +853,49 @@ class HashModel(RedisModel, abc.ABC): schema_parts.append(redisearch_field) elif getattr(field.field_info, 'index', None) is True: schema_parts.append(cls.schema_for_type(name, _type, field.field_info)) + # TODO: Raise error if user embeds a model field or list and makes it + # sortable. Instead, the embedded model should mark individual fields + # as sortable. if getattr(field.field_info, 'sortable', False) is True: schema_parts.append("SORTABLE") - return " ".join(schema_parts) + elif get_origin(_type) == list: + embedded_cls = get_args(_type) + if not embedded_cls: + # TODO: Test if this can really happen. + log.warning("Model %s defined an empty list field: %s", cls, name) + continue + embedded_cls = embedded_cls[0] + schema_parts.append(cls.schema_for_type(name, embedded_cls, + field.field_info)) + elif issubclass(_type, RedisModel): + schema_parts.append(cls.schema_for_type(name, _type, field.field_info)) + return schema_parts + + @classmethod + def schema_for_type(cls, name, typ: Type, field_info: FieldInfo): + if get_origin(typ) == list: + embedded_cls = get_args(typ) + if not embedded_cls: + # TODO: Test if this can really happen. + log.warning("Model %s defined an empty list field: %s", cls, name) + return "" + embedded_cls = embedded_cls[0] + return cls.schema_for_type(name, embedded_cls, field_info) + elif any(issubclass(typ, t) for t in NUMERIC_TYPES): + return f"{name} NUMERIC" + elif issubclass(typ, str): + if getattr(field_info, 'full_text_search', False) is True: + return f"{name} TAG {name}_fts TEXT" + else: + return f"{name} TAG" + elif issubclass(typ, RedisModel): + sub_fields = [] + for embedded_name, field in typ.__fields__.items(): + sub_fields.append(cls.schema_for_type(f"{name}_{embedded_name}", field.outer_type_, + field.field_info)) + return " ".join(sub_fields) + else: + return f"{name} TAG" class JsonModel(RedisModel, abc.ABC): @@ -874,3 +909,88 @@ class JsonModel(RedisModel, abc.ABC): if not document: raise NotFoundError return cls.parse_raw(document) + + @classmethod + def redisearch_schema(cls): + key_prefix = cls.make_key(cls._meta.primary_key_pattern.format(pk="")) + schema_prefix = f"ON JSON PREFIX 1 {key_prefix} SCHEMA" + schema_parts = [schema_prefix] + cls.schema_for_fields() + return " ".join(schema_parts) + + @classmethod + def schema_for_fields(cls): + schema_parts = [] + json_path = "$" + + for name, field in cls.__fields__.items(): + # TODO: Merge this code with schema_for_type()? + _type = field.outer_type_ + if getattr(field.field_info, 'primary_key', None): + if issubclass(_type, str): + redisearch_field = f"{json_path}.{name} AS {name} TAG" + else: + redisearch_field = cls.schema_for_type(f"{json_path}.{name}", name, "", _type, field.field_info) + schema_parts.append(redisearch_field) + elif getattr(field.field_info, 'index', None) is True: + schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, "", _type, field.field_info)) + # TODO: Raise error if user embeds a model field or list and makes it + # sortable. Instead, the embedded model should mark individual fields + # as sortable. + if getattr(field.field_info, 'sortable', False) is True: + schema_parts.append("SORTABLE") + elif get_origin(_type) == list: + embedded_cls = get_args(_type) + if not embedded_cls: + # TODO: Test if this can really happen. + log.warning("Model %s defined an empty list field: %s", cls, name) + continue + embedded_cls = embedded_cls[0] + schema_parts.append(cls.schema_for_type(f"{json_path}.{name}[]", name, f"{name}", + embedded_cls, field.field_info)) + elif issubclass(_type, RedisModel): + schema_parts.append(cls.schema_for_type(f"{json_path}.{name}", name, f"{name}", _type, + field.field_info)) + return schema_parts + + @classmethod + # TODO: We need both the "name" of the field (address_line_1) as we'll + # find it in the JSON document, AND the name of the field as it should + # be in the redisearch schema (address_address_line_1). Maybe both "name" + # and "name_prefix"? + def schema_for_type(cls, json_path: str, name: str, name_prefix: str, typ: Type, + field_info: FieldInfo) -> str: + index_field_name = f"{name_prefix}{name}" + should_index = getattr(field_info, 'index', False) + + if get_origin(typ) == list: + embedded_cls = get_args(typ) + if not embedded_cls: + # TODO: Test if this can really happen. + log.warning("Model %s defined an empty list field: %s", cls, name) + return "" + embedded_cls = embedded_cls[0] + # TODO: We need to pass the "JSON Path so far" which should include the + # correct syntax for an array. + return cls.schema_for_type(f"{json_path}[]", name, f"{name_prefix}{name}", embedded_cls, field_info) + elif issubclass(typ, RedisModel): + sub_fields = [] + for embedded_name, field in typ.__fields__.items(): + sub_fields.append(cls.schema_for_type(f"{json_path}.{embedded_name}", + embedded_name, + f"{name_prefix}_", + field.outer_type_, + field.field_info)) + return " ".join(filter(None, sub_fields)) + elif should_index: + if any(issubclass(typ, t) for t in NUMERIC_TYPES): + return f"{json_path} AS {index_field_name} NUMERIC" + elif issubclass(typ, str): + if getattr(field_info, 'full_text_search', False) is True: + return f"{json_path} AS {index_field_name} TAG " \ + f"{json_path} AS {index_field_name}_fts TEXT" + else: + return f"{json_path} AS {index_field_name} TAG" + else: + return f"{json_path} AS {index_field_name} TAG" + + return "" diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 7f81aa6..898321d 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -74,7 +74,7 @@ def members(): def test_validates_required_fields(): - # Raises ValidationError: last_name, address are required + # Raises ValidationError: last_name is required with pytest.raises(ValidationError): Member( first_name="Andrew", diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 28d9f65..4d79565 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -2,6 +2,7 @@ import abc import decimal import datetime from typing import Optional, List +from unittest import mock import pytest import redis @@ -11,9 +12,10 @@ from redis_developer.orm import ( JsonModel, Field, ) +from redis_developer.orm.model import RedisModelError, QueryNotSupportedError, NotFoundError r = redis.Redis() -today = datetime.datetime.today() +today = datetime.date.today() class BaseJsonModel(JsonModel, abc.ABC): @@ -25,13 +27,14 @@ class Address(BaseJsonModel): address_line_1: str address_line_2: Optional[str] city: str + state: str country: str - postal_code: str + postal_code: str = Field(index=True) class Item(BaseJsonModel): price: decimal.Decimal - name: str + name: str = Field(index=True, full_text_search=True) class Order(BaseJsonModel): @@ -45,6 +48,7 @@ class Member(BaseJsonModel): last_name: str email: str = Field(index=True) join_date: datetime.date + age: int # Creates an embedded model. address: Address @@ -52,169 +56,308 @@ class Member(BaseJsonModel): # Creates an embedded list of models. orders: Optional[List[Order]] - class Meta(BaseJsonModel.Meta): - model_key_prefix = "member" # This is the default + +@pytest.fixture() +def address(): + yield Address( + address_line_1="1 Main St.", + city="Portland", + state="OR", + country="USA", + postal_code=11111 + ) -address = Address( - address_line_1="1 Main St.", - city="Happy Town", - state="WY", - postal_code=11111, - country="USA" -) +@pytest.fixture() +def members(address): + member1 = Member( + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + age=38, + join_date=today, + address=address + ) + + member2 = Member( + first_name="Kim", + last_name="Brookins", + email="k@example.com", + age=34, + join_date=today, + address=address + ) + + member3 = Member( + first_name="Andrew", + last_name="Smith", + email="as@example.com", + age=100, + join_date=today, + address=address + ) + + member1.save() + member2.save() + member3.save() + + yield member1, member2, member3 -def test_validates_required_fields(): - # Raises ValidationError: last_name, address are required +def test_validates_required_fields(address): + # Raises ValidationError address is required with pytest.raises(ValidationError): Member( first_name="Andrew", + last_name="Brookins", zipcode="97086", - join_date=today + join_date=today, ) -def test_validates_field(): +def test_validates_field(address): # Raises ValidationError: join_date is not a date with pytest.raises(ValidationError): Member( first_name="Andrew", last_name="Brookins", - join_date="yesterday" + join_date="yesterday", + address=address ) # Passes validation -def test_validation_passes(): +def test_validation_passes(address): member = Member( first_name="Andrew", last_name="Brookins", email="a@example.com", - address=address, - join_date=today + join_date=today, + age=38, + address=address ) assert member.first_name == "Andrew" -def test_gets_pk(): - new_address = Address( - address_line_1="1 Main St.", - city="Happy Town", - state="WY", - postal_code=11111, - country="USA" - ) - assert new_address.pk is not None - - -def test_saves_model(): - # Save a model instance to Redis - address.save() - address2 = Address.get(address.pk) - assert address2 == address - - -def test_saves_with_embedded_models(): +def test_saves_model_and_creates_pk(address): member = Member( first_name="Andrew", last_name="Brookins", email="a@example.com", - address=address, - join_date=datetime.date.today() + join_date=today, + age=38, + address=address ) + # Save a model instance to Redis member.save() member2 = Member.get(member.pk) + assert member2 == member assert member2.address == address -def test_saves_with_deeply_embedded_models(): - hat = Item( - name="Cool hat", - price=2.99 - ) - shoe = Item( - name="Expensive shoe", - price=299.99 - ) - order = Order( - total=302.98, - items=[hat, shoe], - created_on=today, - ) - member = Member( - first_name="Andrew", - last_name="Brookins", - email="a@example.com", - address=address, - orders=[order], - join_date=today - ) - member.save() - - member2 = Member.get(member.pk) - assert member2.orders[0] == order - assert member2.orders[0].items[0] == hat - - -# Save many model instances to Redis @pytest.mark.skip("Not implemented yet") -def test_saves_many(): +def test_saves_many(address): members = [ Member( first_name="Andrew", last_name="Brookins", email="a@example.com", + join_date=today, address=address, - join_date=today + age=38 ), Member( first_name="Kim", last_name="Brookins", email="k@example.com", + join_date=today, address=address, - join_date=today + age=34 ) ] Member.add(members) -@pytest.mark.skip("No implemented yet") -def test_updates_a_model(): - member = Member( - first_name="Andrew", - last_name="Brookins", - email="a@example.com", - address=address, - join_date=today - ) - - # Update a model instance in Redis - member.first_name = "Brian" - member.last_name = "Sam-Bodden" - member.save() +@pytest.mark.skip("Not ready yet") +def test_updates_a_model(members): + member1, member2, member3 = members # Or, with an implicit save: - member.update(first_name="Brian", last_name="Sam-Bodden") + member1.update(last_name="Smith") + assert Member.find(Member.pk == member1.pk).first() == member1 # Or, affecting multiple model instances with an implicit save: - Member.filter(Member.last_name == "Brookins").update(last_name="Sam-Bodden") + Member.find(Member.last_name == "Brookins").update(last_name="Smith") + results = Member.find(Member.last_name == "Smith") + assert sorted(results) == members + + # Or, updating a field in an embedded model: + member2.update(address__city="Happy Valley") + assert Member.find(Member.pk == member2.pk).first().address.city == "Happy Valley" -@pytest.mark.skip("Not implemented yet") -def test_exact_match_queries(): - # TODO: Should get() support expressions? I.e., ... - # What if the field wasn't unique and there were two "a@example.com" - # entries? This would raise a MultipleObjectsReturned error: - member = Member.get(Member.email == "a.m.brookins@gmail.com") +def test_paginate_query(members): + member1, member2, member3 = members + actual = Member.find().all(batch_size=1) + assert sorted(actual) == [member1, member2, member3] - # What if you know there might be multiple results? Use filter(): - members = Member.filter(Member.last_name == "Brookins") - # What if you want to only return values that don't match a query? - members = Member.exclude(Member.last_name == "Brookins") +def test_access_result_by_index_cached(members): + member1, member2, member3 = members + query = Member.find().sort_by('age') + # Load the cache, throw away the result. + assert query._model_cache == [] + query.execute() + assert query._model_cache == [member2, member1, member3] - # You can combine filer() and exclude(): - members = Member.filter(Member.last_name == "Brookins").exclude( - Member.first_name == "Andrew") + # Access an item that should be in the cache. + with mock.patch.object(query.model, 'db') as mock_db: + assert query[0] == member2 + assert not mock_db.called + + +def test_access_result_by_index_not_cached(members): + member1, member2, member3 = members + query = Member.find().sort_by('age') + + # Assert that we don't have any models in the cache yet -- we + # haven't made any requests of Redis. + assert query._model_cache == [] + assert query[0] == member2 + assert query[1] == member1 + assert query[2] == member3 + + +def test_exact_match_queries(members): + member1, member2, member3 = members + + actual = Member.find(Member.last_name == "Brookins").all() + assert sorted(actual) == [member1, member2] + + actual = Member.find( + (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")).all() + assert actual == [member2] + + actual = Member.find(~(Member.last_name == "Brookins")).all() + assert actual == [member3] + + actual = Member.find(Member.last_name != "Brookins").all() + assert actual == [member3] + + actual = Member.find( + (Member.last_name == "Brookins") & (Member.first_name == "Andrew") + | (Member.first_name == "Kim") + ).all() + assert actual == [member2, member1] + + actual = Member.find(Member.first_name == "Kim", Member.last_name == "Brookins").all() + assert actual == [member2] + + actual = Member.find(Member.address.city == "Portland").all() + assert actual == [member1, member2, member3] + + +def test_recursive_query_resolution(members): + member1, member2, member3 = members + + actual = Member.find((Member.last_name == "Brookins") | ( + Member.age == 100 + ) & (Member.last_name == "Smith")).all() + assert sorted(actual) == [member1, member2, member3] + + +def test_tag_queries_boolean_logic(members): + member1, member2, member3 = members + + actual = Member.find( + (Member.first_name == "Andrew") & + (Member.last_name == "Brookins") | (Member.last_name == "Smith")).all() + assert sorted(actual) == [member1, member3] + + +def test_tag_queries_punctuation(): + member = Member( + first_name="Andrew, the Michael", # This string uses the TAG field separator. + last_name="St. Brookins-on-Pier", + email="a@example.com", + age=38, + join_date=today + ) + member.save() + + assert Member.find(Member.first_name == "Andrew the Michael").first() == member + assert Member.find(Member.last_name == "St. Brookins-on-Pier").first() == member + assert Member.find(Member.email == "a@example.com").first() == member + + +def test_tag_queries_negation(members): + member1, member2, member3 = members + + actual = Member.find( + ~(Member.first_name == "Andrew") & + (Member.last_name == "Brookins") | (Member.last_name == "Smith")).all() + assert sorted(actual) == [member2, member3] + + actual = Member.find( + (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")).all() + assert sorted(actual) == [member3] + + +def test_numeric_queries(members): + member1, member2, member3 = members + + actual = Member.find(Member.age == 34).all() + assert actual == [member2] + + actual = Member.find(Member.age > 34).all() + assert sorted(actual) == [member1, member3] + + actual = Member.find(Member.age < 35).all() + assert actual == [member2] + + actual = Member.find(Member.age <= 34).all() + assert actual == [member2] + + actual = Member.find(Member.age >= 100).all() + assert actual == [member3] + + actual = Member.find(~(Member.age == 100)).all() + assert sorted(actual) == [member1, member2] + + +def test_sorting(members): + member1, member2, member3 = members + + actual = Member.find(Member.age > 34).sort_by('age').all() + assert sorted(actual) == [member3, member1] + + actual = Member.find(Member.age > 34).sort_by('-age').all() + assert sorted(actual) == [member1, member3] + + with pytest.raises(QueryNotSupportedError): + # This field does not exist. + Member.find().sort_by('not-a-real-field').all() + + with pytest.raises(QueryNotSupportedError): + # This field is not sortable. + Member.find().sort_by('join_date').all() + + +def test_not_found(): + with pytest.raises(NotFoundError): + # This ID does not exist. + Member.get(1000) + + +def test_schema(): + assert Member.redisearch_schema() == "ON JSON PREFIX 1 " \ + "redis-developer:tests.test_json_model.Member: " \ + "SCHEMA $.pk AS pk TAG " \ + "$.email AS email TAG " \ + "$.address.pk AS address_pk TAG " \ + "$.address.postal_code AS address_postal_code TAG " \ + "$.orders[].pk AS orders_pk TAG " \ + "$.orders[].items[].pk AS orders_items_pk TAG " \ + "$.orders[].items[].name AS orders_items_name TAG " \ + "$.orders[].items[].name AS orders_items_name_fts TEXT"