# type: ignore

import abc
import dataclasses
import datetime
import decimal
from collections import namedtuple
from typing import Dict, List, Optional, Set
from unittest import mock

import pytest
from pydantic import ValidationError

from aredis_om import (
    EmbeddedJsonModel,
    Field,
    JsonModel,
    Migrator,
    NotFoundError,
    QueryNotSupportedError,
    RedisModelError,
)

# We need to run this check as sync code (during tests) even in async mode
# because we call it in the top-level module scope.
from redis_om import has_redis_json


if not has_redis_json():
    pytestmark = pytest.mark.skip

today = datetime.date.today()


@pytest.fixture
async def m(key_prefix, redis):
    """Build the model classes shared by the tests in this module.

    The classes are declared inside the fixture so that each one picks up
    the ``key_prefix`` fixture value as its global key prefix. After the
    classes are declared the migrator is run, and the classes are returned
    bundled in a ``Models`` namedtuple so tests can reach them as
    ``m.Member``, ``m.Address``, and so on.
    """

    class BaseJsonModel(JsonModel, abc.ABC):
        class Meta:
            global_key_prefix = key_prefix

    class Note(EmbeddedJsonModel):
        # TODO: This was going to be a full-text search example, but
        #  we can't index embedded documents for full-text search in
        #  the preview release.
        description: str = Field(index=True)
        created_on: datetime.datetime

    class Address(EmbeddedJsonModel):
        address_line_1: str
        address_line_2: Optional[str]
        city: str = Field(index=True)
        state: str
        country: str
        postal_code: str = Field(index=True)
        note: Optional[Note]

    class Item(EmbeddedJsonModel):
        price: decimal.Decimal
        name: str = Field(index=True)

    class Order(EmbeddedJsonModel):
        items: List[Item]
        created_on: datetime.datetime

    class Member(BaseJsonModel):
        first_name: str = Field(index=True)
        last_name: str = Field(index=True)
        email: str = Field(index=True)
        join_date: datetime.date
        age: int = Field(index=True)
        bio: Optional[str] = Field(index=True, full_text_search=True, default="")

        # Creates an embedded model.
        address: Address

        # Creates an embedded list of models.
        orders: Optional[List[Order]]

    await Migrator().run()

    return namedtuple(
        "Models", ["BaseJsonModel", "Note", "Address", "Item", "Order", "Member"]
    )(BaseJsonModel, Note, Address, Item, Order, Member)


@pytest.fixture()
def address(m):
    """Yield an unsaved ``Address`` instance used as the members' address."""
    yield m.Address(
        address_line_1="1 Main St.",
        city="Portland",
        state="OR",
        country="USA",
        # ``postal_code`` is declared as ``str``; pass a string rather than
        # relying on implicit int -> str coercion by the validation layer.
        postal_code="11111",
    )


@pytest.fixture()
async def members(address, m):
    """Save three ``Member`` instances and yield them as a tuple.

    The ages (38, 34, 100) are distinct, so sorting by age gives the order
    ``member2, member1, member3``.
    """
    member1 = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        age=38,
        join_date=today,
        address=address,
    )

    member2 = m.Member(
        first_name="Kim",
        last_name="Brookins",
        email="k@example.com",
        age=34,
        join_date=today,
        address=address,
    )

    member3 = m.Member(
        first_name="Andrew",
        last_name="Smith",
        email="as@example.com",
        age=100,
        join_date=today,
        address=address,
    )

    await member1.save()
    await member2.save()
    await member3.save()

    yield member1, member2, member3


@pytest.mark.asyncio
async def test_validates_required_fields(address, m):
    """Constructing a Member without its required fields fails validation."""
    # Raises ValidationError: address is required
    with pytest.raises(ValidationError):
        m.Member(
            first_name="Andrew",
            last_name="Brookins",
            zipcode="97086",
            join_date=today,
        )


@pytest.mark.asyncio
async def test_validates_field(address, m):
    """A value that cannot be parsed as a date is rejected for join_date."""
    # Raises ValidationError: join_date is not a date
    with pytest.raises(ValidationError):
        m.Member(
            first_name="Andrew",
            last_name="Brookins",
            join_date="yesterday",
            address=address,
        )


@pytest.mark.asyncio
async def test_validation_passes(address, m):
    """A Member built with all required fields is created successfully."""
    member = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        join_date=today,
        age=38,
        address=address,
    )
    assert member.first_name == "Andrew"


@pytest.mark.asyncio
async def test_saves_model_and_creates_pk(address, m, redis):
    """A saved Member can be fetched by pk and equals the original."""
    await Migrator().run()

    member = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        join_date=today,
        age=38,
        address=address,
    )
    # Save a model instance to Redis
    await member.save()

    member2 = await m.Member.get(member.pk)
    assert member2 == member
    assert member2.address == address
@pytest.mark.asyncio
async def test_delete(address, m, redis):
    """Deleting a saved Member by pk returns 1."""
    member = m.Member(
        first_name="Simon",
        last_name="Prickett",
        email="s@example.com",
        join_date=today,
        age=38,
        address=address,
    )

    await member.save()
    response = await m.Member.delete(member.pk)
    assert response == 1


@pytest.mark.asyncio
async def test_saves_many_implicit_pipeline(address, m):
    """``Member.add`` without a pipeline saves every model and returns them."""
    member1 = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        join_date=today,
        address=address,
        age=38,
    )
    member2 = m.Member(
        first_name="Kim",
        last_name="Brookins",
        email="k@example.com",
        join_date=today,
        address=address,
        age=34,
    )
    members = [member1, member2]
    result = await m.Member.add(members)
    assert result == [member1, member2]

    assert await m.Member.get(pk=member1.pk) == member1
    assert await m.Member.get(pk=member2.pk) == member2


@pytest.mark.asyncio
async def test_saves_many_explicit_transaction(address, m):
    """``Member.add`` works both on its own and with a caller-owned pipeline."""
    member1 = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        join_date=today,
        address=address,
        age=38,
    )
    member2 = m.Member(
        first_name="Kim",
        last_name="Brookins",
        email="k@example.com",
        join_date=today,
        address=address,
        age=34,
    )
    members = [member1, member2]
    result = await m.Member.add(members)
    assert result == [member1, member2]

    assert await m.Member.get(pk=member1.pk) == member1
    assert await m.Member.get(pk=member2.pk) == member2

    # Test the explicit pipeline path -- here, we add multiple Members
    # using a single Redis transaction, with MULTI/EXEC.
    async with m.Member.db().pipeline(transaction=True) as pipeline:
        # Capture the return value of *this* call so the assertion below
        # checks the pipeline path instead of the earlier, stale result.
        result = await m.Member.add(members, pipeline=pipeline)
        assert result == [member1, member2]
        assert await pipeline.execute() == ["OK", "OK"]

    assert await m.Member.get(pk=member1.pk) == member1
    assert await m.Member.get(pk=member2.pk) == member2


async def save(members):
    """Save each model in ``members`` in order and return ``members``."""
    for member in members:
        await member.save()
    return members


@pytest.mark.asyncio
async def test_updates_a_model(members, m):
    """``update`` persists changes to top-level and embedded fields."""
    member1, member2, _ = await save(members)

    # Update a field directly on the model
    await member1.update(last_name="Apples to oranges")
    member = await m.Member.get(member1.pk)
    assert member.last_name == "Apples to oranges"

    # Update a field in an embedded model
    await member2.update(address__city="Happy Valley")
    member = await m.Member.get(member2.pk)
    assert member.address.city == "Happy Valley"


@pytest.mark.asyncio
async def test_paginate_query(members, m):
    """``all(batch_size=1)`` still returns every result, in sorted order."""
    member1, member2, member3 = members
    actual = await m.Member.find().sort_by("age").all(batch_size=1)
    assert actual == [member2, member1, member3]


@pytest.mark.asyncio
async def test_access_result_by_index_cached(members, m):
    """``get_item`` on an executed query does not call the model's ``db``."""
    member1, member2, member3 = members
    query = m.Member.find().sort_by("age")
    # Load the cache, throw away the result.
    assert query._model_cache == []
    await query.execute()
    assert query._model_cache == [member2, member1, member3]

    # Access an item that should be in the cache.
    with mock.patch.object(query.model, "db") as mock_db:
        assert await query.get_item(0) == member2
        assert not mock_db.called


@pytest.mark.asyncio
async def test_access_result_by_index_not_cached(members, m):
    """``get_item`` returns the right model when nothing is cached yet."""
    member1, member2, member3 = members
    query = m.Member.find().sort_by("age")

    # Assert that we don't have any models in the cache yet -- we
    # haven't made any requests of Redis.
    assert query._model_cache == []
    assert await query.get_item(0) == member2
    assert await query.get_item(1) == member1
    assert await query.get_item(2) == member3


@pytest.mark.asyncio
async def test_in_query(members, m):
    """The ``<<`` operator matches models whose pk is in the given list."""
    member1, member2, member3 = members
    actual = await (
        m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
        .sort_by("age")
        .all()
    )
    assert actual == [member2, member1, member3]


@pytest.mark.asyncio
async def test_update_query(members, m):
    """``find(...).update(...)`` changes the field on every matched model."""
    member1, member2, member3 = members
    await m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update(
        first_name="Bobby"
    )
    actual = await (
        m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk])
        .sort_by("age")
        .all()
    )
    assert len(actual) == 3
    assert all(member.first_name == "Bobby" for member in actual)


@pytest.mark.asyncio
async def test_exact_match_queries(members, m):
    """Equality, inequality, negation and boolean combinations on fields."""
    member1, member2, member3 = members

    actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
    assert actual == [member2, member1]

    actual = await m.Member.find(
        (m.Member.last_name == "Brookins") & ~(m.Member.first_name == "Andrew")
    ).all()
    assert actual == [member2]

    actual = await m.Member.find(~(m.Member.last_name == "Brookins")).all()
    assert actual == [member3]

    actual = await m.Member.find(m.Member.last_name != "Brookins").all()
    assert actual == [member3]

    actual = await (
        m.Member.find(
            (m.Member.last_name == "Brookins") & (m.Member.first_name == "Andrew")
            | (m.Member.first_name == "Kim")
        )
        .sort_by("age")
        .all()
    )
    assert actual == [member2, member1]

    actual = await m.Member.find(
        m.Member.first_name == "Kim", m.Member.last_name == "Brookins"
    ).all()
    assert actual == [member2]

    actual = (
        await m.Member.find(m.Member.address.city == "Portland").sort_by("age").all()
    )
    assert actual == [member2, member1, member3]


@pytest.mark.asyncio
async def test_recursive_query_expression_resolution(members, m):
    """An OR whose right-hand side is itself an AND expression resolves."""
    member1, member2, member3 = members

    actual = await (
        m.Member.find(
            (m.Member.last_name == "Brookins")
            | (m.Member.age == 100) & (m.Member.last_name == "Smith")
        )
        .sort_by("age")
        .all()
    )
    assert actual == [member2, member1, member3]


@pytest.mark.asyncio
async def test_recursive_query_field_resolution(members, m):
    """Queries can reach fields nested in embedded models and lists."""
    member1, _, _ = members
    member1.address.note = m.Note(
        description="Weird house", created_on=datetime.datetime.now()
    )
    await member1.save()
    actual = await m.Member.find(
        m.Member.address.note.description == "Weird house"
    ).all()
    assert actual == [member1]

    member1.orders = [
        m.Order(
            items=[m.Item(price=10.99, name="Ball")],
            total=10.99,
            created_on=datetime.datetime.now(),
        )
    ]
    await member1.save()
    actual = await m.Member.find(m.Member.orders.items.name == "Ball").all()
    assert actual == [member1]
    assert actual[0].orders[0].items[0].name == "Ball"


@pytest.mark.asyncio
async def test_full_text_search(members, m):
    """The ``%`` operator matches words inside the full-text ``bio`` field."""
    member1, member2, _ = members
    await member1.update(bio="Hates sunsets, likes beaches")
    await member2.update(bio="Hates beaches, likes forests")

    actual = await m.Member.find(m.Member.bio % "beaches").sort_by("age").all()
    assert actual == [member2, member1]

    actual = await m.Member.find(m.Member.bio % "forests").all()
    assert actual == [member2]


@pytest.mark.asyncio
async def test_tag_queries_boolean_logic(members, m):
    """``A & B | C`` on tag fields returns matches of either side of the OR."""
    member1, member2, member3 = members

    actual = (
        await m.Member.find(
            (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
            | (m.Member.last_name == "Smith")
        )
        .sort_by("age")
        .all()
    )
    assert actual == [member1, member3]


@pytest.mark.asyncio
async def test_tag_queries_punctuation(address, m):
    """Exact-match queries work for values containing punctuation."""
    member1 = m.Member(
        first_name="Andrew, the Michael",
        last_name="St. Brookins-on-Pier",
        email="a|b@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
        address=address,
    )
    await member1.save()

    member2 = m.Member(
        first_name="Bob",
        last_name="the Villain",
        email="a|villain@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
        address=address,
    )
    await member2.save()

    assert (
        await m.Member.find(m.Member.first_name == "Andrew, the Michael").first()
        == member1
    )
    assert (
        await m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first()
        == member1
    )

    # Notice that when we index and query multiple values that use the internal
    # TAG separator for single-value exact-match fields, like an indexed string,
    # the queries will succeed. We apply a workaround that queries for the union
    # of the two values separated by the tag separator.
    assert await m.Member.find(m.Member.email == "a|b@example.com").all() == [member1]
    assert await m.Member.find(m.Member.email == "a|villain@example.com").all() == [
        member2
    ]


@pytest.mark.asyncio
async def test_tag_queries_negation(members, m):
    """Negated tag expressions, alone and combined with AND / OR."""
    member1, member2, member3 = members

    #        ┌first_name
    #  NOT EQ┤
    #        └Andrew
    query = m.Member.find(~(m.Member.first_name == "Andrew"))
    assert await query.all() == [member2]

    #            ┌first_name
    #     ┌NOT EQ┤
    #     |      └Andrew
    #  AND┤
    #     |  ┌last_name
    #     └EQ┤
    #        └Brookins
    query = m.Member.find(
        ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
    )
    assert await query.all() == [member2]

    #            ┌first_name
    #     ┌NOT EQ┤
    #     |      └Andrew
    #  AND┤
    #     |     ┌last_name
    #     |  ┌EQ┤
    #     |  |  └Brookins
    #     └OR┤
    #        |  ┌last_name
    #        └EQ┤
    #           └Smith
    query = m.Member.find(
        ~(m.Member.first_name == "Andrew")
        & ((m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith"))
    )
    assert await query.all() == [member2]

    #               ┌first_name
    #        ┌NOT EQ┤
    #        |      └Andrew
    #    ┌AND┤
    #    |   |  ┌last_name
    #    |   └EQ┤
    #    |      └Brookins
    #  OR┤
    #    |  ┌last_name
    #    └EQ┤
    #       └Smith
    query = m.Member.find(
        ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
        | (m.Member.last_name == "Smith")
    )
    assert await query.sort_by("age").all() == [member2, member3]

    actual = await m.Member.find(
        (m.Member.first_name == "Andrew") & ~(m.Member.last_name == "Brookins")
    ).all()
    assert actual == [member3]


@pytest.mark.asyncio
async def test_numeric_queries(members, m):
    """Comparison operators and negation on the numeric ``age`` field."""
    member1, member2, member3 = members

    actual = await m.Member.find(m.Member.age == 34).all()
    assert actual == [member2]

    actual = await m.Member.find(m.Member.age > 34).sort_by("age").all()
    assert actual == [member1, member3]

    actual = await m.Member.find(m.Member.age < 35).all()
    assert actual == [member2]

    actual = await m.Member.find(m.Member.age <= 34).all()
    assert actual == [member2]

    actual = await m.Member.find(m.Member.age >= 100).all()
    assert actual == [member3]

    actual = await m.Member.find(~(m.Member.age == 100)).sort_by("age").all()
    assert actual == [member2, member1]

    actual = (
        await m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all()
    )
    assert actual == [member2, member1]

    actual = await m.Member.find(m.Member.age != 34).sort_by("age").all()
    assert actual == [member1, member3]


@pytest.mark.asyncio
async def test_sorting(members, m):
    """Ascending/descending sort works; invalid sort fields raise."""
    member1, member2, member3 = members

    actual = await m.Member.find(m.Member.age > 34).sort_by("age").all()
    assert actual == [member1, member3]

    actual = await m.Member.find(m.Member.age > 34).sort_by("-age").all()
    assert actual == [member3, member1]

    with pytest.raises(QueryNotSupportedError):
        # This field does not exist.
        await m.Member.find().sort_by("not-a-real-field").all()

    with pytest.raises(QueryNotSupportedError):
        # This field is not sortable.
        await m.Member.find().sort_by("join_date").all()


@pytest.mark.asyncio
async def test_not_found(m):
    """Fetching a pk that was never saved raises ``NotFoundError``."""
    with pytest.raises(NotFoundError):
        # This ID does not exist.
        await m.Member.get(1000)


@pytest.mark.asyncio
async def test_list_field_limitations(m, redis):
    """Unsupported list-field index options raise at class-definition time."""
    with pytest.raises(RedisModelError):

        class SortableTarotWitch(m.BaseJsonModel):
            # We support indexing lists of strings for equality and membership
            # queries. Sorting is not supported, but is planned.
            tarot_cards: List[str] = Field(index=True, sortable=True)

    with pytest.raises(RedisModelError):

        class SortableFullTextSearchAlchemicalWitch(m.BaseJsonModel):
            # We don't support indexing a list of strings for full-text search
            # queries. Support for this feature is not planned.
            potions: List[str] = Field(index=True, full_text_search=True)

    with pytest.raises(RedisModelError):

        class NumerologyWitch(m.BaseJsonModel):
            # We don't support indexing a list of numbers. Support for this
            # feature is To Be Determined.
            lucky_numbers: List[int] = Field(index=True)

    with pytest.raises(RedisModelError):

        class ReadingWithPrice(EmbeddedJsonModel):
            gold_coins_charged: int = Field(index=True)

        class TarotWitchWhoCharges(m.BaseJsonModel):
            tarot_cards: List[str] = Field(index=True)

            # The preview release does not support indexing numeric fields on models
            # found within a list or tuple. This is the same limitation that stops
            # us from indexing plain lists (or tuples) containing numeric values.
            # The fate of this feature is To Be Determined.
            readings: List[ReadingWithPrice]

    class TarotWitch(m.BaseJsonModel):
        # We support indexing lists of strings for equality and membership
        # queries. Sorting is not supported, but is planned.
        tarot_cards: List[str] = Field(index=True)

    # We need to import and run this manually because we defined
    # our model classes within a function that runs after the test
    # suite's migrator has already looked for migrations to run.
    await Migrator().run()

    witch = TarotWitch(tarot_cards=["death"])
    await witch.save()
    actual = await TarotWitch.find(TarotWitch.tarot_cards << "death").all()
    assert actual == [witch]


@pytest.mark.asyncio
async def test_allows_dataclasses(m):
    """A stdlib dataclass can be used as a field type and round-trips."""

    @dataclasses.dataclass
    class Address:
        address_line_1: str

    class ValidMember(m.BaseJsonModel):
        address: Address

    address = Address(address_line_1="hey")
    member = ValidMember(address=address)
    await member.save()

    member2 = await ValidMember.get(member.pk)
    assert member2 == member
    assert member2.address.address_line_1 == "hey"


@pytest.mark.asyncio
async def test_allows_and_serializes_dicts(m):
    """A ``Dict[str, str]`` field round-trips through save and get."""

    class ValidMember(m.BaseJsonModel):
        address: Dict[str, str]

    member = ValidMember(address={"address_line_1": "hey"})
    await member.save()

    member2 = await ValidMember.get(member.pk)
    assert member2 == member
    assert member2.address["address_line_1"] == "hey"


@pytest.mark.asyncio
async def test_allows_and_serializes_sets(m):
    """A ``Set[int]`` field round-trips through save and get."""

    class ValidMember(m.BaseJsonModel):
        friend_ids: Set[int]

    member = ValidMember(friend_ids={1, 2})
    await member.save()

    member2 = await ValidMember.get(member.pk)
    assert member2 == member
    assert member2.friend_ids == {1, 2}


@pytest.mark.asyncio
async def test_allows_and_serializes_lists(m):
    """A ``List[int]`` field round-trips through save and get."""

    class ValidMember(m.BaseJsonModel):
        friend_ids: List[int]

    member = ValidMember(friend_ids=[1, 2])
    await member.save()

    member2 = await ValidMember.get(member.pk)
    assert member2 == member
    assert member2.friend_ids == [1, 2]


@pytest.mark.asyncio
async def test_schema(m, key_prefix):
    """``Member.redisearch_schema()`` emits the expected index definition."""
    # We need to build the key prefix because it will differ based on whether
    # these tests were copied into the tests_sync folder and unasync'd.
    key_prefix = m.Member.make_key(m.Member._meta.primary_key_pattern.format(pk=""))

    # One schema clause per entry; the clauses are separated by single spaces
    # in the string produced by the model.
    expected_schema = " ".join(
        [
            f"ON JSON PREFIX 1 {key_prefix} SCHEMA",
            "$.pk AS pk TAG SEPARATOR |",
            "$.first_name AS first_name TAG SEPARATOR |",
            "$.last_name AS last_name TAG SEPARATOR |",
            "$.email AS email TAG SEPARATOR |",
            "$.age AS age NUMERIC",
            "$.bio AS bio TAG SEPARATOR |",
            "$.bio AS bio_fts TEXT",
            "$.address.pk AS address_pk TAG SEPARATOR |",
            "$.address.city AS address_city TAG SEPARATOR |",
            "$.address.postal_code AS address_postal_code TAG SEPARATOR |",
            "$.address.note.pk AS address_note_pk TAG SEPARATOR |",
            "$.address.note.description AS address_note_description TAG SEPARATOR |",
            "$.orders[*].pk AS orders_pk TAG SEPARATOR |",
            "$.orders[*].items[*].pk AS orders_items_pk TAG SEPARATOR |",
            "$.orders[*].items[*].name AS orders_items_name TAG SEPARATOR |",
        ]
    )
    assert m.Member.redisearch_schema() == expected_schema