diff --git a/Makefile b/Makefile index 0ff65d0..9cd4b1d 100644 --- a/Makefile +++ b/Makefile @@ -33,6 +33,7 @@ clean: rm -rf build rm -rf dist rm -rf redis_om + rm -rf tests_sync docker-compose down diff --git a/aredis_om/connections.py b/aredis_om/connections.py index bbae0c6..940682c 100644 --- a/aredis_om/connections.py +++ b/aredis_om/connections.py @@ -2,6 +2,7 @@ import os import aioredis + URL = os.environ.get("REDIS_OM_URL", None) diff --git a/make_sync.py b/make_sync.py index a457882..881a74f 100644 --- a/make_sync.py +++ b/make_sync.py @@ -7,6 +7,8 @@ ADDITIONAL_REPLACEMENTS = { "aredis_om": "redis_om", "aioredis": "redis", ":tests.": ":tests_sync.", + "pytest_asyncio": "pytest", + "py_test_mark_asyncio": "py_test_mark_sync", } diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..641c4b5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = strict diff --git a/tests/conftest.py b/tests/conftest.py index bd530bb..9f067a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,14 @@ from aredis_om import get_redis_connection TEST_PREFIX = "redis-om:testing" +py_test_mark_asyncio = pytest.mark.asyncio + + +# "pytest_mark_sync" causes problem in pytest +def py_test_mark_sync(f): + return f # no-op decorator + + @pytest.fixture(scope="session") def event_loop(request): loop = asyncio.get_event_loop_policy().new_event_loop() diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index b7c9b3f..0a79aa6 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Set from unittest import mock import pytest +import pytest_asyncio from pydantic import ValidationError from aredis_om import ( @@ -22,7 +23,7 @@ from aredis_om import ( # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. from redis_om import has_redisearch - +from tests.conftest import py_test_mark_asyncio if not has_redisearch(): pytestmark = pytest.mark.skip @@ -30,7 +31,7 @@ if not has_redisearch(): today = datetime.date.today() -@pytest.fixture +@pytest_asyncio.fixture async def m(key_prefix, redis): class BaseHashModel(HashModel, abc.ABC): class Meta: @@ -60,7 +61,7 @@ async def m(key_prefix, redis): ) -@pytest.fixture +@pytest_asyncio.fixture async def members(m): member1 = m.Member( first_name="Andrew", @@ -95,7 +96,7 @@ async def members(m): yield member1, member2, member3 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_exact_match_queries(members, m): member1, member2, member3 = members @@ -129,7 +130,7 @@ async def test_exact_match_queries(members, m): assert actual == [member2] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_full_text_search_queries(members, m): member1, member2, member3 = members @@ -142,7 +143,7 @@ async def test_full_text_search_queries(members, m): assert actual == [member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_recursive_query_resolution(members, m): member1, member2, member3 = members @@ -157,7 +158,7 @@ async def test_recursive_query_resolution(members, m): assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_boolean_logic(members, m): member1, member2, member3 = members @@ -172,7 +173,7 @@ async def test_tag_queries_boolean_logic(members, m): assert actual == [member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_punctuation(m): member1 = m.Member( first_name="Andrew, the Michael", @@ -210,7 +211,7 @@ async def test_tag_queries_punctuation(m): assert results == [member2] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_negation(members, m): member1, member2, member3 = members @@ -282,7 +283,7 @@ async def test_tag_queries_negation(members, m): assert actual == [member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_numeric_queries(members, m): member1, member2, member3 = members @@ -313,7 +314,7 @@ async def test_numeric_queries(members, m): assert actual == [member2, member1] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_sorting(members, m): member1, member2, member3 = members @@ -357,7 +358,8 @@ def test_validation_passes(m): ) assert member.first_name == "Andrew" -@pytest.mark.asyncio + +@py_test_mark_asyncio async def test_retrieve_first(m): member = m.Member( first_name="Simon", @@ -395,7 +397,8 @@ async def test_retrieve_first(m): first_one = await m.Member.find().sort_by("age").first() assert first_one == member3 -@pytest.mark.asyncio + +@py_test_mark_asyncio async def test_saves_model_and_creates_pk(m): member = m.Member( first_name="Andrew", @@ -411,7 +414,8 @@ async def test_saves_model_and_creates_pk(m): member2 = await m.Member.get(member.pk) assert member2 == member -@pytest.mark.asyncio + +@py_test_mark_asyncio async def test_all_pks(m): member = m.Member( first_name="Simon", @@ -441,7 +445,8 @@ async def test_all_pks(m): assert len(pk_list) == 2 -@pytest.mark.asyncio + +@py_test_mark_asyncio async def test_delete(m): member = m.Member( first_name="Simon", @@ -457,7 +462,7 @@ async def test_delete(m): assert response == 1 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_expire(m): member = m.Member( first_name="Expire", @@ -521,7 +526,7 @@ def test_raises_error_with_lists(m): friend_ids: List[str] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_many(m): member1 = m.Member( first_name="Andrew", @@ -547,7 +552,7 @@ async def test_saves_many(m): assert await m.Member.get(pk=member2.pk) == member2 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_updates_a_model(members, m): member1, member2, member3 = members await member1.update(last_name="Smith") @@ -555,14 +560,14 @@ async def test_updates_a_model(members, m): assert member.last_name == "Smith" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_paginate_query(members, m): member1, member2, member3 = members actual = await m.Member.find().sort_by("age").all(batch_size=1) assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_access_result_by_index_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") @@ -577,7 +582,7 @@ async def test_access_result_by_index_cached(members, m): assert not mock_db.called -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_access_result_by_index_not_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 13c2834..8a114f9 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Set from unittest import mock import pytest +import pytest_asyncio from pydantic import ValidationError from aredis_om import ( @@ -24,7 +25,7 @@ from aredis_om import ( # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. from redis_om import has_redis_json - +from tests.conftest import py_test_mark_asyncio if not has_redis_json(): pytestmark = pytest.mark.skip @@ -32,7 +33,7 @@ if not has_redis_json(): today = datetime.date.today() -@pytest.fixture +@pytest_asyncio.fixture async def m(key_prefix, redis): class BaseJsonModel(JsonModel, abc.ABC): class Meta: @@ -94,7 +95,7 @@ def address(m): ) -@pytest.fixture() +@pytest_asyncio.fixture() async def members(address, m): member1 = m.Member( first_name="Andrew", @@ -130,7 +131,7 @@ async def members(address, m): yield member1, member2, member3 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_validates_required_fields(address, m): # Raises ValidationError address is required with pytest.raises(ValidationError): @@ -142,7 +143,7 @@ async def test_validates_required_fields(address, m): ) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_validates_field(address, m): # Raises ValidationError: join_date is not a date with pytest.raises(ValidationError): @@ -154,7 +155,7 @@ async def test_validates_field(address, m): ) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_validation_passes(address, m): member = m.Member( first_name="Andrew", @@ -167,7 +168,7 @@ async def test_validation_passes(address, m): assert member.first_name == "Andrew" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_model_and_creates_pk(address, m, redis): await Migrator().run() @@ -186,7 +187,8 @@ async def test_saves_model_and_creates_pk(address, m, redis): assert member2 == member assert member2.address == address -@pytest.mark.asyncio + +@py_test_mark_asyncio async def test_all_pks(address, m, redis): member = m.Member( first_name="Andrew", @@ -197,7 +199,7 @@ async def test_all_pks(address, m, redis): address=address, ) - await member.save() + await member.save() member1 = m.Member( first_name="Simon", @@ -216,7 +218,8 @@ async def test_all_pks(address, m, redis): assert len(pk_list) == 2 -@pytest.mark.asyncio + +@py_test_mark_asyncio async def test_delete(address, m, redis): member = m.Member( first_name="Simon", @@ -232,7 +235,7 @@ async def test_delete(address, m, redis): assert response == 1 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_many_implicit_pipeline(address, m): member1 = m.Member( first_name="Andrew", @@ -258,7 +261,7 @@ async def test_saves_many_implicit_pipeline(address, m): assert await m.Member.get(pk=member2.pk) == member2 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_many_explicit_transaction(address, m): member1 = m.Member( first_name="Andrew", @@ -300,7 +303,7 @@ async def save(members): return members -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_updates_a_model(members, m): member1, member2, member3 = await save(members) @@ -315,14 +318,14 @@ async def test_updates_a_model(members, m): assert member.address.city == "Happy Valley" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_paginate_query(members, m): member1, member2, member3 = members actual = await m.Member.find().sort_by("age").all(batch_size=1) assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_access_result_by_index_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") @@ -337,7 +340,7 @@ async def test_access_result_by_index_cached(members, m): assert not mock_db.called -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_access_result_by_index_not_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") @@ -350,7 +353,7 @@ async def test_access_result_by_index_not_cached(members, m): assert await query.get_item(2) == member3 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_in_query(members, m): member1, member2, member3 = members actual = await ( @@ -361,7 +364,7 @@ async def test_in_query(members, m): assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_update_query(members, m): member1, member2, member3 = members await m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update( @@ -376,7 +379,7 @@ async def test_update_query(members, m): assert all([m.first_name == "Bobby" for m in actual]) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_exact_match_queries(members, m): member1, member2, member3 = members @@ -415,7 +418,7 @@ async def test_exact_match_queries(members, m): assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_recursive_query_expression_resolution(members, m): member1, member2, member3 = members @@ -430,7 +433,7 @@ async def test_recursive_query_expression_resolution(members, m): assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_recursive_query_field_resolution(members, m): member1, _, _ = members member1.address.note = m.Note( @@ -455,7 +458,7 @@ async def test_recursive_query_field_resolution(members, m): assert actual[0].orders[0].items[0].name == "Ball" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_full_text_search(members, m): member1, member2, _ = members await member1.update(bio="Hates sunsets, likes beaches") @@ -468,7 +471,7 @@ async def test_full_text_search(members, m): assert actual == [member2] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_boolean_logic(members, m): member1, member2, member3 = members @@ -483,7 +486,7 @@ async def test_tag_queries_boolean_logic(members, m): assert actual == [member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_punctuation(address, m): member1 = m.Member( first_name="Andrew, the Michael", @@ -524,7 +527,7 @@ async def test_tag_queries_punctuation(address, m): ] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_negation(members, m): member1, member2, member3 = members @@ -596,7 +599,7 @@ async def test_tag_queries_negation(members, m): assert actual == [member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_numeric_queries(members, m): member1, member2, member3 = members @@ -627,7 +630,7 @@ async def test_numeric_queries(members, m): assert actual == [member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_sorting(members, m): member1, member2, member3 = members @@ -646,14 +649,14 @@ async def test_sorting(members, m): await m.Member.find().sort_by("join_date").all() -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_not_found(m): with pytest.raises(NotFoundError): # This ID does not exist. await m.Member.get(1000) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_list_field_limitations(m, redis): with pytest.raises(RedisModelError): @@ -706,7 +709,7 @@ async def test_list_field_limitations(m, redis): assert actual == [witch] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_allows_dataclasses(m): @dataclasses.dataclass class Address: @@ -724,7 +727,7 @@ async def test_allows_dataclasses(m): assert member2.address.address_line_1 == "hey" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_allows_and_serializes_dicts(m): class ValidMember(m.BaseJsonModel): address: Dict[str, str] @@ -737,7 +740,7 @@ async def test_allows_and_serializes_dicts(m): assert member2.address["address_line_1"] == "hey" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_allows_and_serializes_sets(m): class ValidMember(m.BaseJsonModel): friend_ids: Set[int] @@ -750,7 +753,7 @@ async def test_allows_and_serializes_sets(m): assert member2.friend_ids == {1, 2} -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_allows_and_serializes_lists(m): class ValidMember(m.BaseJsonModel): friend_ids: List[int] @@ -763,7 +766,7 @@ async def test_allows_and_serializes_lists(m): assert member2.friend_ids == [1, 2] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_schema(m, key_prefix): # We need to build the key prefix because it will differ based on whether # these tests were copied into the tests_sync folder and unasynce'd. diff --git a/tests/test_oss_redis_features.py b/tests/test_oss_redis_features.py index 69aded9..bdad16c 100644 --- a/tests/test_oss_redis_features.py +++ b/tests/test_oss_redis_features.py @@ -5,15 +5,17 @@ from collections import namedtuple from typing import Optional import pytest +import pytest_asyncio from pydantic import ValidationError from aredis_om import HashModel, Migrator, NotFoundError, RedisModelError +from tests.conftest import py_test_mark_asyncio today = datetime.date.today() -@pytest.fixture +@pytest_asyncio.fixture async def m(key_prefix, redis): class BaseHashModel(HashModel, abc.ABC): class Meta: @@ -42,7 +44,7 @@ async def m(key_prefix, redis): ) -@pytest.fixture +@pytest_asyncio.fixture async def members(m): member1 = m.Member( first_name="Andrew", @@ -74,14 +76,14 @@ async def members(m): yield member1, member2, member3 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_all_keys(members, m): pks = sorted([pk async for pk in await m.Member.all_pks()]) assert len(pks) == 3 assert pks == sorted([m.pk for m in members]) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_not_found(m): with pytest.raises(NotFoundError): # This ID does not exist. @@ -113,7 +115,7 @@ def test_validation_passes(m): assert member.first_name == "Andrew" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_model_and_creates_pk(m): member = m.Member( first_name="Andrew", @@ -143,7 +145,7 @@ def test_raises_error_with_embedded_models(m): address: Address -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_many(m): member1 = m.Member( first_name="Andrew", @@ -167,7 +169,7 @@ async def test_saves_many(m): assert await m.Member.get(pk=member2.pk) == member2 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_updates_a_model(members, m): member1, member2, member3 = members await member1.update(last_name="Smith") diff --git a/tests/test_pydantic_integrations.py b/tests/test_pydantic_integrations.py index 1fc9457..5ff735f 100644 --- a/tests/test_pydantic_integrations.py +++ b/tests/test_pydantic_integrations.py @@ -3,6 +3,7 @@ import datetime from collections import namedtuple import pytest +import pytest_asyncio from pydantic import EmailStr, ValidationError from aredis_om import Field, HashModel, Migrator @@ -11,7 +12,7 @@ from aredis_om import Field, HashModel, Migrator today = datetime.date.today() -@pytest.fixture +@pytest_asyncio.fixture async def m(key_prefix, redis): class BaseHashModel(HashModel, abc.ABC): class Meta: