diff --git a/make_sync.py b/make_sync.py index 3282fc2..881a74f 100644 --- a/make_sync.py +++ b/make_sync.py @@ -1,6 +1,5 @@ import os from pathlib import Path -from typing import Iterable, Optional, Union import unasync @@ -9,44 +8,9 @@ ADDITIONAL_REPLACEMENTS = { "aioredis": "redis", ":tests.": ":tests_sync.", "pytest_asyncio": "pytest", + "py_test_mark_asyncio": "py_test_mark_sync", } -STRINGS_TO_REMOVE_FROM_SYNC_TESTS = { - "@pytest.mark.asyncio", -} - - -def remove_strings_from_files( - filepaths: Iterable[Union[bytes, str, os.PathLike]], - strings_to_remove: Iterable[str], -): - for filepath in filepaths: - tmp_filepath = f"{filepath}.tmp" - with open(filepath, "r") as read_file, open(tmp_filepath, "w") as write_file: - for line in read_file: - if line.strip() in strings_to_remove: - continue - print(line, end="", file=write_file) - os.replace(tmp_filepath, filepath) - - -def get_source_filepaths(directory: Optional[Union[bytes, str, os.PathLike]] = None): - walk_path = ( - Path(__file__).absolute().parent - if directory is None - else os.path.join(Path(__file__).absolute().parent, directory) - ) - - filepaths = [] - for root, _, filenames in os.walk(walk_path): - for filename in filenames: - if filename.rpartition(".")[-1] in ( - "py", - "pyi", - ): - filepaths.append(os.path.join(root, filename)) - return filepaths - def main(): rules = [ @@ -61,11 +25,15 @@ def main(): additional_replacements=ADDITIONAL_REPLACEMENTS, ), ] + filepaths = [] + for root, _, filenames in os.walk( + Path(__file__).absolute().parent + ): + for filename in filenames: + if filename.rpartition(".")[-1] in ("py", "pyi",): + filepaths.append(os.path.join(root, filename)) - unasync.unasync_files(get_source_filepaths(), rules) - remove_strings_from_files( - get_source_filepaths("tests_sync"), STRINGS_TO_REMOVE_FROM_SYNC_TESTS - ) + unasync.unasync_files(filepaths, rules) if __name__ == "__main__": diff --git a/tests/conftest.py b/tests/conftest.py index bd530bb..9f067a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,14 @@ from aredis_om import get_redis_connection TEST_PREFIX = "redis-om:testing" +py_test_mark_asyncio = pytest.mark.asyncio + + +# "pytest_mark_sync" causes problem in pytest +def py_test_mark_sync(f): + return f # no-op decorator + + @pytest.fixture(scope="session") def event_loop(request): loop = asyncio.get_event_loop_policy().new_event_loop() diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 659b862..0a79aa6 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -23,7 +23,7 @@ from aredis_om import ( # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. from redis_om import has_redisearch - +from tests.conftest import py_test_mark_asyncio if not has_redisearch(): pytestmark = pytest.mark.skip @@ -96,7 +96,7 @@ async def members(m): yield member1, member2, member3 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_exact_match_queries(members, m): member1, member2, member3 = members @@ -130,7 +130,7 @@ async def test_exact_match_queries(members, m): assert actual == [member2] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_full_text_search_queries(members, m): member1, member2, member3 = members @@ -143,7 +143,7 @@ async def test_full_text_search_queries(members, m): assert actual == [member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_recursive_query_resolution(members, m): member1, member2, member3 = members @@ -158,7 +158,7 @@ async def test_recursive_query_resolution(members, m): assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_boolean_logic(members, m): member1, member2, member3 = members @@ -173,7 +173,7 @@ async def test_tag_queries_boolean_logic(members, m): assert actual == [member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_punctuation(m): member1 = m.Member( first_name="Andrew, the Michael", @@ -211,7 +211,7 @@ async def test_tag_queries_punctuation(m): assert results == [member2] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_negation(members, m): member1, member2, member3 = members @@ -283,7 +283,7 @@ async def test_tag_queries_negation(members, m): assert actual == [member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_numeric_queries(members, m): member1, member2, member3 = members @@ -314,7 +314,7 @@ async def test_numeric_queries(members, m): assert actual == [member2, member1] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_sorting(members, m): member1, member2, member3 = members @@ -359,7 +359,7 @@ def test_validation_passes(m): assert member.first_name == "Andrew" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_retrieve_first(m): member = m.Member( first_name="Simon", @@ -398,7 +398,7 @@ async def test_retrieve_first(m): assert first_one == member3 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_model_and_creates_pk(m): member = m.Member( first_name="Andrew", @@ -415,7 +415,7 @@ async def test_saves_model_and_creates_pk(m): assert member2 == member -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_all_pks(m): member = m.Member( first_name="Simon", @@ -446,7 +446,7 @@ async def test_all_pks(m): assert len(pk_list) == 2 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_delete(m): member = m.Member( first_name="Simon", @@ -462,7 +462,7 @@ async def test_delete(m): assert response == 1 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_expire(m): member = m.Member( first_name="Expire", @@ -526,7 +526,7 @@ def test_raises_error_with_lists(m): friend_ids: List[str] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_many(m): member1 = m.Member( first_name="Andrew", @@ -552,7 +552,7 @@ async def test_saves_many(m): assert await m.Member.get(pk=member2.pk) == member2 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_updates_a_model(members, m): member1, member2, member3 = members await member1.update(last_name="Smith") @@ -560,14 +560,14 @@ async def test_updates_a_model(members, m): assert member.last_name == "Smith" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_paginate_query(members, m): member1, member2, member3 = members actual = await m.Member.find().sort_by("age").all(batch_size=1) assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_access_result_by_index_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") @@ -582,7 +582,7 @@ async def test_access_result_by_index_cached(members, m): assert not mock_db.called -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_access_result_by_index_not_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 6fcef62..8a114f9 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -25,7 +25,7 @@ from aredis_om import ( # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. from redis_om import has_redis_json - +from tests.conftest import py_test_mark_asyncio if not has_redis_json(): pytestmark = pytest.mark.skip @@ -131,7 +131,7 @@ async def members(address, m): yield member1, member2, member3 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_validates_required_fields(address, m): # Raises ValidationError address is required with pytest.raises(ValidationError): @@ -143,7 +143,7 @@ async def test_validates_required_fields(address, m): ) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_validates_field(address, m): # Raises ValidationError: join_date is not a date with pytest.raises(ValidationError): @@ -155,7 +155,7 @@ async def test_validates_field(address, m): ) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_validation_passes(address, m): member = m.Member( first_name="Andrew", @@ -168,7 +168,7 @@ async def test_validation_passes(address, m): assert member.first_name == "Andrew" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_model_and_creates_pk(address, m, redis): await Migrator().run() @@ -188,7 +188,7 @@ async def test_saves_model_and_creates_pk(address, m, redis): assert member2.address == address -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_all_pks(address, m, redis): member = m.Member( first_name="Andrew", @@ -219,7 +219,7 @@ async def test_all_pks(address, m, redis): assert len(pk_list) == 2 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_delete(address, m, redis): member = m.Member( first_name="Simon", @@ -235,7 +235,7 @@ async def test_delete(address, m, redis): assert response == 1 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_many_implicit_pipeline(address, m): member1 = m.Member( first_name="Andrew", @@ -261,7 +261,7 @@ async def test_saves_many_implicit_pipeline(address, m): assert await m.Member.get(pk=member2.pk) == member2 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_many_explicit_transaction(address, m): member1 = m.Member( first_name="Andrew", @@ -303,7 +303,7 @@ async def save(members): return members -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_updates_a_model(members, m): member1, member2, member3 = await save(members) @@ -318,14 +318,14 @@ async def test_updates_a_model(members, m): assert member.address.city == "Happy Valley" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_paginate_query(members, m): member1, member2, member3 = members actual = await m.Member.find().sort_by("age").all(batch_size=1) assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_access_result_by_index_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") @@ -340,7 +340,7 @@ async def test_access_result_by_index_cached(members, m): assert not mock_db.called -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_access_result_by_index_not_cached(members, m): member1, member2, member3 = members query = m.Member.find().sort_by("age") @@ -353,7 +353,7 @@ async def test_access_result_by_index_not_cached(members, m): assert await query.get_item(2) == member3 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_in_query(members, m): member1, member2, member3 = members actual = await ( @@ -364,7 +364,7 @@ async def test_in_query(members, m): assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_update_query(members, m): member1, member2, member3 = members await m.Member.find(m.Member.pk << [member1.pk, member2.pk, member3.pk]).update( @@ -379,7 +379,7 @@ async def test_update_query(members, m): assert all([m.first_name == "Bobby" for m in actual]) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_exact_match_queries(members, m): member1, member2, member3 = members @@ -418,7 +418,7 @@ async def test_exact_match_queries(members, m): assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_recursive_query_expression_resolution(members, m): member1, member2, member3 = members @@ -433,7 +433,7 @@ async def test_recursive_query_expression_resolution(members, m): assert actual == [member2, member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_recursive_query_field_resolution(members, m): member1, _, _ = members member1.address.note = m.Note( @@ -458,7 +458,7 @@ async def test_recursive_query_field_resolution(members, m): assert actual[0].orders[0].items[0].name == "Ball" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_full_text_search(members, m): member1, member2, _ = members await member1.update(bio="Hates sunsets, likes beaches") @@ -471,7 +471,7 @@ async def test_full_text_search(members, m): assert actual == [member2] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_boolean_logic(members, m): member1, member2, member3 = members @@ -486,7 +486,7 @@ async def test_tag_queries_boolean_logic(members, m): assert actual == [member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_punctuation(address, m): member1 = m.Member( first_name="Andrew, the Michael", @@ -527,7 +527,7 @@ async def test_tag_queries_punctuation(address, m): ] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_tag_queries_negation(members, m): member1, member2, member3 = members @@ -599,7 +599,7 @@ async def test_tag_queries_negation(members, m): assert actual == [member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_numeric_queries(members, m): member1, member2, member3 = members @@ -630,7 +630,7 @@ async def test_numeric_queries(members, m): assert actual == [member1, member3] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_sorting(members, m): member1, member2, member3 = members @@ -649,14 +649,14 @@ async def test_sorting(members, m): await m.Member.find().sort_by("join_date").all() -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_not_found(m): with pytest.raises(NotFoundError): # This ID does not exist. await m.Member.get(1000) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_list_field_limitations(m, redis): with pytest.raises(RedisModelError): @@ -709,7 +709,7 @@ async def test_list_field_limitations(m, redis): assert actual == [witch] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_allows_dataclasses(m): @dataclasses.dataclass class Address: @@ -727,7 +727,7 @@ async def test_allows_dataclasses(m): assert member2.address.address_line_1 == "hey" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_allows_and_serializes_dicts(m): class ValidMember(m.BaseJsonModel): address: Dict[str, str] @@ -740,7 +740,7 @@ async def test_allows_and_serializes_dicts(m): assert member2.address["address_line_1"] == "hey" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_allows_and_serializes_sets(m): class ValidMember(m.BaseJsonModel): friend_ids: Set[int] @@ -753,7 +753,7 @@ async def test_allows_and_serializes_sets(m): assert member2.friend_ids == {1, 2} -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_allows_and_serializes_lists(m): class ValidMember(m.BaseJsonModel): friend_ids: List[int] @@ -766,7 +766,7 @@ async def test_allows_and_serializes_lists(m): assert member2.friend_ids == [1, 2] -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_schema(m, key_prefix): # We need to build the key prefix because it will differ based on whether # these tests were copied into the tests_sync folder and unasynce'd. diff --git a/tests/test_oss_redis_features.py b/tests/test_oss_redis_features.py index 538469d..bdad16c 100644 --- a/tests/test_oss_redis_features.py +++ b/tests/test_oss_redis_features.py @@ -9,6 +9,7 @@ import pytest_asyncio from pydantic import ValidationError from aredis_om import HashModel, Migrator, NotFoundError, RedisModelError +from tests.conftest import py_test_mark_asyncio today = datetime.date.today() @@ -75,14 +76,14 @@ async def members(m): yield member1, member2, member3 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_all_keys(members, m): pks = sorted([pk async for pk in await m.Member.all_pks()]) assert len(pks) == 3 assert pks == sorted([m.pk for m in members]) -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_not_found(m): with pytest.raises(NotFoundError): # This ID does not exist. @@ -114,7 +115,7 @@ def test_validation_passes(m): assert member.first_name == "Andrew" -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_model_and_creates_pk(m): member = m.Member( first_name="Andrew", @@ -144,7 +145,7 @@ def test_raises_error_with_embedded_models(m): address: Address -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_saves_many(m): member1 = m.Member( first_name="Andrew", @@ -168,7 +169,7 @@ async def test_saves_many(m): assert await m.Member.get(pk=member2.pk) == member2 -@pytest.mark.asyncio +@py_test_mark_asyncio async def test_updates_a_model(members, m): member1, member2, member3 = members await member1.update(last_name="Smith")