diff --git a/Makefile b/Makefile index 0ff65d0..9cd4b1d 100644 --- a/Makefile +++ b/Makefile @@ -33,6 +33,7 @@ clean: rm -rf build rm -rf dist rm -rf redis_om + rm -rf tests_sync docker-compose down diff --git a/aredis_om/connections.py b/aredis_om/connections.py index bbae0c6..940682c 100644 --- a/aredis_om/connections.py +++ b/aredis_om/connections.py @@ -2,6 +2,7 @@ import os import aioredis + URL = os.environ.get("REDIS_OM_URL", None) diff --git a/make_sync.py b/make_sync.py index a457882..3282fc2 100644 --- a/make_sync.py +++ b/make_sync.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from typing import Iterable, Optional, Union import unasync @@ -7,8 +8,45 @@ ADDITIONAL_REPLACEMENTS = { "aredis_om": "redis_om", "aioredis": "redis", ":tests.": ":tests_sync.", + "pytest_asyncio": "pytest", } +STRINGS_TO_REMOVE_FROM_SYNC_TESTS = { + "@pytest.mark.asyncio", +} + + +def remove_strings_from_files( + filepaths: Iterable[Union[bytes, str, os.PathLike]], + strings_to_remove: Iterable[str], +): + for filepath in filepaths: + tmp_filepath = f"{filepath}.tmp" + with open(filepath, "r") as read_file, open(tmp_filepath, "w") as write_file: + for line in read_file: + if line.strip() in strings_to_remove: + continue + print(line, end="", file=write_file) + os.replace(tmp_filepath, filepath) + + +def get_source_filepaths(directory: Optional[Union[bytes, str, os.PathLike]] = None): + walk_path = ( + Path(__file__).absolute().parent + if directory is None + else os.path.join(Path(__file__).absolute().parent, directory) + ) + + filepaths = [] + for root, _, filenames in os.walk(walk_path): + for filename in filenames: + if filename.rpartition(".")[-1] in ( + "py", + "pyi", + ): + filepaths.append(os.path.join(root, filename)) + return filepaths + def main(): rules = [ @@ -23,15 +61,11 @@ def main(): additional_replacements=ADDITIONAL_REPLACEMENTS, ), ] - filepaths = [] - for root, _, filenames in os.walk( - Path(__file__).absolute().parent - ): - for filename in filenames: - if filename.rpartition(".")[-1] in ("py", "pyi",): - filepaths.append(os.path.join(root, filename)) - unasync.unasync_files(filepaths, rules) + unasync.unasync_files(get_source_filepaths(), rules) + remove_strings_from_files( + get_source_filepaths("tests_sync"), STRINGS_TO_REMOVE_FROM_SYNC_TESTS + ) if __name__ == "__main__": diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..641c4b5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = strict diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index b7c9b3f..659b862 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Set from unittest import mock import pytest +import pytest_asyncio from pydantic import ValidationError from aredis_om import ( @@ -30,7 +31,7 @@ if not has_redisearch(): today = datetime.date.today() -@pytest.fixture +@pytest_asyncio.fixture async def m(key_prefix, redis): class BaseHashModel(HashModel, abc.ABC): class Meta: @@ -60,7 +61,7 @@ async def m(key_prefix, redis): ) -@pytest.fixture +@pytest_asyncio.fixture async def members(m): member1 = m.Member( first_name="Andrew", @@ -357,6 +358,7 @@ def test_validation_passes(m): ) assert member.first_name == "Andrew" + @pytest.mark.asyncio async def test_retrieve_first(m): member = m.Member( @@ -395,6 +397,7 @@ async def test_retrieve_first(m): first_one = await m.Member.find().sort_by("age").first() assert first_one == member3 + @pytest.mark.asyncio async def test_saves_model_and_creates_pk(m): member = m.Member( @@ -411,6 +414,7 @@ async def test_saves_model_and_creates_pk(m): member2 = await m.Member.get(member.pk) assert member2 == member + @pytest.mark.asyncio async def test_all_pks(m): member = m.Member( @@ -441,6 +445,7 @@ async def test_all_pks(m): assert len(pk_list) == 2 + @pytest.mark.asyncio async def test_delete(m): member = m.Member( diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 13c2834..6fcef62 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Set from unittest import mock import pytest +import pytest_asyncio from pydantic import ValidationError from aredis_om import ( @@ -32,7 +33,7 @@ if not has_redis_json(): today = datetime.date.today() -@pytest.fixture +@pytest_asyncio.fixture async def m(key_prefix, redis): class BaseJsonModel(JsonModel, abc.ABC): class Meta: @@ -94,7 +95,7 @@ def address(m): ) -@pytest.fixture() +@pytest_asyncio.fixture() async def members(address, m): member1 = m.Member( first_name="Andrew", @@ -186,6 +187,7 @@ async def test_saves_model_and_creates_pk(address, m, redis): assert member2 == member assert member2.address == address + @pytest.mark.asyncio async def test_all_pks(address, m, redis): member = m.Member( @@ -197,7 +199,7 @@ async def test_all_pks(address, m, redis): address=address, ) - await member.save() + await member.save() member1 = m.Member( first_name="Simon", @@ -216,6 +218,7 @@ async def test_all_pks(address, m, redis): assert len(pk_list) == 2 + @pytest.mark.asyncio async def test_delete(address, m, redis): member = m.Member( diff --git a/tests/test_oss_redis_features.py b/tests/test_oss_redis_features.py index 69aded9..538469d 100644 --- a/tests/test_oss_redis_features.py +++ b/tests/test_oss_redis_features.py @@ -5,6 +5,7 @@ from collections import namedtuple from typing import Optional import pytest +import pytest_asyncio from pydantic import ValidationError from aredis_om import HashModel, Migrator, NotFoundError, RedisModelError @@ -13,7 +14,7 @@ from aredis_om import HashModel, Migrator, NotFoundError, RedisModelError today = datetime.date.today() -@pytest.fixture +@pytest_asyncio.fixture async def m(key_prefix, redis): class BaseHashModel(HashModel, abc.ABC): class Meta: @@ -42,7 +43,7 @@ async def m(key_prefix, redis): ) -@pytest.fixture +@pytest_asyncio.fixture async def members(m): member1 = m.Member( first_name="Andrew", diff --git a/tests/test_pydantic_integrations.py b/tests/test_pydantic_integrations.py index 1fc9457..5ff735f 100644 --- a/tests/test_pydantic_integrations.py +++ b/tests/test_pydantic_integrations.py @@ -3,6 +3,7 @@ import datetime from collections import namedtuple import pytest +import pytest_asyncio from pydantic import EmailStr, ValidationError from aredis_om import Field, HashModel, Migrator @@ -11,7 +12,7 @@ from aredis_om import Field, HashModel, Migrator today = datetime.date.today() -@pytest.fixture +@pytest_asyncio.fixture async def m(key_prefix, redis): class BaseHashModel(HashModel, abc.ABC): class Meta: