# type: ignore
"""Tests for ``HashModel`` using the async ``aredis_om`` API.

The whole module is skipped when ``has_redisearch()`` returns False.
"""

import abc
import dataclasses
import datetime
import decimal
from collections import namedtuple
from typing import Dict, List, Optional, Set
from unittest import mock

import pytest
from pydantic import ValidationError

from aredis_om import (
    Field,
    HashModel,
    Migrator,
    QueryNotSupportedError,
    RedisModelError,
)

# We need to run this check as sync code (during tests) even in async mode
# because we call it in the top-level module scope.
from redis_om import has_redisearch


if not has_redisearch():
    pytestmark = pytest.mark.skip

today = datetime.date.today()


@pytest.fixture
async def m(key_prefix, redis):
    """Define the test models under ``key_prefix`` and run the migrator.

    Returns a ``Models`` namedtuple exposing ``BaseHashModel``, ``Order`` and
    ``Member``.
    """
    # NOTE(review): ``redis`` is not used in the body; it is presumably
    # requested for its setup/teardown side effects — confirm in conftest.

    class BaseHashModel(HashModel, abc.ABC):
        class Meta:
            global_key_prefix = key_prefix

    class Order(BaseHashModel):
        total: decimal.Decimal
        currency: str
        created_on: datetime.datetime

    class Member(BaseHashModel):
        first_name: str = Field(index=True)
        last_name: str = Field(index=True)
        email: str = Field(index=True)
        join_date: datetime.date
        age: int = Field(index=True, sortable=True)
        bio: str = Field(index=True, full_text_search=True)

        class Meta:
            model_key_prefix = "member"
            primary_key_pattern = ""

    # Create the search indexes for the models defined above.
    await Migrator().run()

    return namedtuple("Models", ["BaseHashModel", "Order", "Member"])(
        BaseHashModel, Order, Member
    )


@pytest.fixture
async def members(m):
    """Save three ``Member`` instances and yield them as a tuple.

    Ages are 38, 34 and 100, so sorting by age yields member2, member1,
    member3.
    """
    member1 = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        age=38,
        join_date=today,
        bio="This is member 1 whose greatness makes him the life and soul of any party he goes to.",
    )

    member2 = m.Member(
        first_name="Kim",
        last_name="Brookins",
        email="k@example.com",
        age=34,
        join_date=today,
        bio="This is member 2 who can be quite anxious until you get to know them.",
    )

    member3 = m.Member(
        first_name="Andrew",
        last_name="Smith",
        email="as@example.com",
        age=100,
        join_date=today,
        bio="This is member 3 who is a funny and lively sort of person.",
    )
    await member1.save()
    await member2.save()
    await member3.save()

    yield member1, member2, member3


@pytest.mark.asyncio
async def test_exact_match_queries(members, m):
    """Equality, inequality, negation and boolean combinations on TAG fields."""
    member1, member2, member3 = members

    actual = await m.Member.find(m.Member.last_name == "Brookins").sort_by("age").all()
    assert actual == [member2, member1]

    actual = await m.Member.find(
        (m.Member.last_name == "Brookins") & ~(m.Member.first_name == "Andrew")
    ).all()
    assert actual == [member2]

    actual = await m.Member.find(~(m.Member.last_name == "Brookins")).all()
    assert actual == [member3]

    actual = await m.Member.find(m.Member.last_name != "Brookins").all()
    assert actual == [member3]

    actual = await (
        m.Member.find(
            (m.Member.last_name == "Brookins") & (m.Member.first_name == "Andrew")
            | (m.Member.first_name == "Kim")
        )
        .sort_by("age")
        .all()
    )
    assert actual == [member2, member1]

    # Multiple positional expressions passed to find() are combined.
    actual = await m.Member.find(
        m.Member.first_name == "Kim", m.Member.last_name == "Brookins"
    ).all()
    assert actual == [member2]


@pytest.mark.asyncio
async def test_full_text_search_queries(members, m):
    """Full-text (``%``) matches and their negation on the ``bio`` field."""
    member1, member2, member3 = members

    actual = await m.Member.find(m.Member.bio % "great").all()

    assert actual == [member1]

    # Sort explicitly: the expected list has two elements, and without
    # sort_by the result order is not specified by the query.
    actual = await m.Member.find(~(m.Member.bio % "anxious")).sort_by("age").all()

    assert actual == [member1, member3]


@pytest.mark.asyncio
async def test_recursive_query_resolution(members, m):
    """A nested OR/AND expression resolves to the expected members."""
    member1, member2, member3 = members

    actual = await (
        m.Member.find(
            (m.Member.last_name == "Brookins")
            | (m.Member.age == 100) & (m.Member.last_name == "Smith")
        )
        .sort_by("age")
        .all()
    )
    assert actual == [member2, member1, member3]


@pytest.mark.asyncio
async def test_tag_queries_boolean_logic(members, m):
    """``&`` binds tighter than ``|`` when combining TAG expressions."""
    member1, member2, member3 = members

    actual = await (
        m.Member.find(
            (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
            | (m.Member.last_name == "Smith")
        )
        .sort_by("age")
        .all()
    )
    assert actual == [member1, member3]


@pytest.mark.asyncio
async def test_tag_queries_punctuation(m):
    """Exact-match queries work on values with punctuation and the TAG separator."""
    member1 = m.Member(
        first_name="Andrew, the Michael",
        last_name="St. Brookins-on-Pier",
        email="a|b@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
        bio="This is a test user on our system.",
    )
    await member1.save()

    member2 = m.Member(
        first_name="Bob",
        last_name="the Villain",
        email="a|villain@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
        bio="This is a villain, they are a really bad person!",
    )
    await member2.save()

    result = await m.Member.find(m.Member.first_name == "Andrew, the Michael").first()
    assert result == member1

    result = await m.Member.find(m.Member.last_name == "St. Brookins-on-Pier").first()
    assert result == member1

    # Values containing the internal TAG separator can still be indexed and
    # queried on single-value exact-match fields such as indexed strings.
    # The workaround queries for the union of the parts on either side of
    # the separator.
    results = await m.Member.find(m.Member.email == "a|b@example.com").all()
    assert results == [member1]

    results = await m.Member.find(m.Member.email == "a|villain@example.com").all()
    assert results == [member2]


@pytest.mark.asyncio
async def test_tag_queries_negation(members, m):
    """Negated TAG expressions, alone and combined with AND/OR."""
    member1, member2, member3 = members

    # Query tree:
    #
    #          ┌first_name
    #   NOT EQ┤
    #          └Andrew
    query = m.Member.find(~(m.Member.first_name == "Andrew"))
    assert await query.all() == [member2]

    # Query tree:
    #
    #             ┌first_name
    #      ┌NOT EQ┤
    #      |      └Andrew
    #   AND┤
    #      |  ┌last_name
    #      └EQ┤
    #         └Brookins
    query = m.Member.find(
        ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
    )
    assert await query.all() == [member2]

    # Query tree:
    #
    #             ┌first_name
    #      ┌NOT EQ┤
    #      |      └Andrew
    #   AND┤
    #      |     ┌last_name
    #      |  ┌EQ┤
    #      |  |  └Brookins
    #      └OR┤
    #         |  ┌last_name
    #         └EQ┤
    #            └Smith
    query = m.Member.find(
        ~(m.Member.first_name == "Andrew")
        & ((m.Member.last_name == "Brookins") | (m.Member.last_name == "Smith"))
    )
    assert await query.all() == [member2]

    # Query tree:
    #
    #                 ┌first_name
    #          ┌NOT EQ┤
    #          |      └Andrew
    #      ┌AND┤
    #      |   |  ┌last_name
    #      |   └EQ┤
    #      |      └Brookins
    #    OR┤
    #      |  ┌last_name
    #      └EQ┤
    #         └Smith
    query = m.Member.find(
        ~(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
        | (m.Member.last_name == "Smith")
    )
    assert await query.sort_by("age").all() == [member2, member3]

    actual = await m.Member.find(
        (m.Member.first_name == "Andrew") & ~(m.Member.last_name == "Brookins")
    ).all()
    assert actual == [member3]


@pytest.mark.asyncio
async def test_numeric_queries(members, m):
    """Comparison operators and negation on the NUMERIC ``age`` field."""
    member1, member2, member3 = members

    actual = await m.Member.find(m.Member.age == 34).all()
    assert actual == [member2]

    actual = await m.Member.find(m.Member.age > 34).sort_by("age").all()
    assert actual == [member1, member3]

    actual = await m.Member.find(m.Member.age < 35).all()
    assert actual == [member2]

    actual = await m.Member.find(m.Member.age <= 34).all()
    assert actual == [member2]

    actual = await m.Member.find(m.Member.age >= 100).all()
    assert actual == [member3]

    actual = await m.Member.find(m.Member.age != 34).sort_by("age").all()
    assert actual == [member1, member3]

    actual = await m.Member.find(~(m.Member.age == 100)).sort_by("age").all()
    assert actual == [member2, member1]

    actual = (
        await m.Member.find(m.Member.age > 30, m.Member.age < 40).sort_by("age").all()
    )
    assert actual == [member2, member1]


@pytest.mark.asyncio
async def test_sorting(members, m):
    """Ascending/descending sort, and errors for unknown or unsortable fields."""
    member1, member2, member3 = members

    actual = await m.Member.find(m.Member.age > 34).sort_by("age").all()
    assert actual == [member1, member3]

    actual = await m.Member.find(m.Member.age > 34).sort_by("-age").all()
    assert actual == [member3, member1]

    with pytest.raises(QueryNotSupportedError):
        # This field does not exist.
        await m.Member.find().sort_by("not-a-real-field").all()

    with pytest.raises(QueryNotSupportedError):
        # This field is not sortable.
        await m.Member.find().sort_by("join_date").all()


def test_validates_required_fields(m):
    """Omitting a required field (``last_name``) raises ``ValidationError``."""
    # Every other field is supplied with a valid value, so the error can only
    # come from the missing last_name.
    # TODO: Test the error value
    with pytest.raises(ValidationError):
        m.Member(
            first_name="Andrew",
            email="a@example.com",
            join_date=today,
            age=38,
            bio="This is the bio field.",
        )


def test_validates_field(m):
    """A value that is not a date for ``join_date`` raises ``ValidationError``."""
    # Every other field is supplied with a valid value, so the error can only
    # come from join_date.
    # TODO: Test the error value
    with pytest.raises(ValidationError):
        m.Member(
            first_name="Andrew",
            last_name="Brookins",
            email="a@example.com",
            join_date="yesterday",
            age=38,
            bio="This is the bio field.",
        )


def test_validation_passes(m):
    """A fully populated ``Member`` validates and keeps its field values."""
    member = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        join_date=today,
        age=38,
        bio="This is the bio field.",
    )
    assert member.first_name == "Andrew"


@pytest.mark.asyncio
async def test_retrieve_first(m):
    """``first()`` on an age-sorted query returns the youngest member."""
    member = m.Member(
        first_name="Simon",
        last_name="Prickett",
        email="s@example.com",
        join_date=today,
        age=99,
        bio="This is the bio field for this user.",
    )

    await member.save()

    member2 = m.Member(
        first_name="Another",
        last_name="Member",
        email="m@example.com",
        join_date=today,
        age=98,
        bio="This is the bio field for this user.",
    )

    await member2.save()

    member3 = m.Member(
        first_name="Third",
        last_name="Member",
        email="t@example.com",
        join_date=today,
        age=97,
        bio="This is the bio field for this user.",
    )

    await member3.save()

    first_one = await m.Member.find().sort_by("age").first()
    assert first_one == member3


@pytest.mark.asyncio
async def test_saves_model_and_creates_pk(m):
    """Saving assigns a primary key that ``get`` can load the model back by."""
    member = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        join_date=today,
        age=38,
        bio="This is the bio field for this user.",
    )
    # Save a model instance to Redis
    await member.save()

    member2 = await m.Member.get(member.pk)
    assert member2 == member


@pytest.mark.asyncio
async def test_all_pks(m):
    """``all_pks`` yields one primary key per saved member."""
    member = m.Member(
        first_name="Simon",
        last_name="Prickett",
        email="s@example.com",
        join_date=today,
        age=97,
        bio="This is a test user to be deleted.",
    )

    await member.save()

    member1 = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        join_date=today,
        age=38,
        bio="This is a test user to be deleted.",
    )

    await member1.save()

    # all_pks() is awaited to obtain an async iterator of primary keys.
    pk_list = [pk async for pk in await m.Member.all_pks()]

    assert len(pk_list) == 2


@pytest.mark.asyncio
async def test_delete(m):
    """Deleting by primary key reports one removed key."""
    member = m.Member(
        first_name="Simon",
        last_name="Prickett",
        email="s@example.com",
        join_date=today,
        age=97,
        bio="This is a test user to be deleted.",
    )

    await member.save()
    response = await m.Member.delete(member.pk)
    assert response == 1


@pytest.mark.asyncio
async def test_expire(m):
    """``expire`` sets a positive TTL on the model's key."""
    member = m.Member(
        first_name="Expire",
        last_name="Test",
        email="e@example.com",
        join_date=today,
        age=93,
        bio="This is a test user for expiry",
    )

    await member.save()
    await member.expire(60)

    ttl = await m.Member.db().ttl(member.key())
    assert ttl > 0


def test_raises_error_with_embedded_models(m):
    """A HashModel field typed as another model raises ``RedisModelError``."""

    class Address(m.BaseHashModel):
        address_line_1: str
        address_line_2: Optional[str]
        city: str
        country: str
        postal_code: str

    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            address: Address


def test_raises_error_with_dataclasses(m):
    """A HashModel field typed as a dataclass raises ``RedisModelError``."""

    @dataclasses.dataclass
    class Address:
        address_line_1: str

    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            address: Address


def test_raises_error_with_dicts(m):
    """A HashModel field typed as ``Dict`` raises ``RedisModelError``."""
    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            address: Dict[str, str]


def test_raises_error_with_sets(m):
    """A HashModel field typed as ``Set`` raises ``RedisModelError``."""
    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            friend_ids: Set[str]


def test_raises_error_with_lists(m):
    """A HashModel field typed as ``List`` raises ``RedisModelError``."""
    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            friend_ids: List[str]


@pytest.mark.asyncio
async def test_saves_many(m):
    """``add`` saves several models and returns them in order."""
    member1 = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        join_date=today,
        age=38,
        bio="This is the user bio.",
    )
    member2 = m.Member(
        first_name="Kim",
        last_name="Brookins",
        email="k@example.com",
        join_date=today,
        age=34,
        bio="This is the bio for Kim.",
    )
    members = [member1, member2]

    result = await m.Member.add(members)
    assert result == [member1, member2]

    assert await m.Member.get(pk=member1.pk) == member1
    assert await m.Member.get(pk=member2.pk) == member2


@pytest.mark.asyncio
async def test_updates_a_model(members, m):
    """``update`` persists the changed field value."""
    member1, member2, member3 = members
    await member1.update(last_name="Smith")
    member = await m.Member.get(member1.pk)
    assert member.last_name == "Smith"


@pytest.mark.asyncio
async def test_paginate_query(members, m):
    """Fetching with ``batch_size=1`` still returns every result in order."""
    member1, member2, member3 = members
    actual = await m.Member.find().sort_by("age").all(batch_size=1)
    assert actual == [member2, member1, member3]


@pytest.mark.asyncio
async def test_access_result_by_index_cached(members, m):
    """After ``execute()``, ``get_item`` is served from the model cache."""
    member1, member2, member3 = members
    query = m.Member.find().sort_by("age")
    # Load the cache, throw away the result.
    assert query._model_cache == []
    await query.execute()
    assert query._model_cache == [member2, member1, member3]

    # Access an item that should be in the cache.
    with mock.patch.object(query.model, "db") as mock_db:
        assert await query.get_item(0) == member2
        assert not mock_db.called


@pytest.mark.asyncio
async def test_access_result_by_index_not_cached(members, m):
    """``get_item`` works before any results have been cached."""
    member1, member2, member3 = members
    query = m.Member.find().sort_by("age")

    # Assert that we don't have any models in the cache yet -- we
    # haven't made any requests of Redis.
    assert query._model_cache == []
    assert await query.get_item(0) == member2
    assert await query.get_item(1) == member1
    assert await query.get_item(2) == member3


def test_schema(m):
    """``redisearch_schema()`` renders the expected FT.CREATE schema clause."""

    class Address(m.BaseHashModel):
        a_string: str = Field(index=True)
        a_full_text_string: str = Field(index=True, full_text_search=True)
        an_integer: int = Field(index=True, sortable=True)
        a_float: float = Field(index=True)
        another_integer: int
        another_float: float

    # We need to build the key prefix because it will differ based on whether
    # these tests were copied into the tests_sync folder and converted with
    # unasync.
    key_prefix = Address.make_key(Address._meta.primary_key_pattern.format(pk=""))
    assert (
        Address.redisearch_schema()
        == f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | a_string TAG SEPARATOR | a_full_text_string TAG SEPARATOR | a_full_text_string AS a_full_text_string_fts TEXT an_integer NUMERIC SORTABLE a_float NUMERIC"
    )