fixed a potential bug (#337)

* fixed a potential bug

Signed-off-by: wiseaidev <business@wiseai.dev>

* add unit tests

Signed-off-by: wiseaidev <business@wiseai.dev>

* remove unnecessary logic related to six module

Signed-off-by: wiseaidev <business@wiseai.dev>

* remove six from dependencies

Signed-off-by: wiseaidev <business@wiseai.dev>

* pass "ignore" as a kwarg

Signed-off-by: wiseaidev <business@wiseai.dev>

* get rid of try catch and simplify logic

Signed-off-by: wiseaidev <business@wiseai.dev>

* rm poetry.lock

Signed-off-by: wiseaidev <business@wiseai.dev>

* rm poetry.lock

Signed-off-by: wiseaidev <business@wiseai.dev>

* run black

Signed-off-by: wiseaidev <business@wiseai.dev>

* fix mypy issue

Signed-off-by: wiseaidev <business@wiseai.dev>

* adjust other tests accordingly

Signed-off-by: wiseaidev <business@wiseai.dev>
This commit is contained in:
Mahmoud Harmouch 2022-08-09 17:40:27 +03:00 committed by GitHub
parent b103adbe6d
commit ac6a75be19
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 32 deletions

View file

@ -1179,15 +1179,11 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
@classmethod @classmethod
def from_redis(cls, res: Any): def from_redis(cls, res: Any):
# TODO: Parsing logic copied from redisearch-py. Evaluate. # TODO: Parsing logic copied from redisearch-py. Evaluate.
import six
from six.moves import xrange
from six.moves import zip as izip
def to_string(s): def to_string(s):
if isinstance(s, six.string_types): if isinstance(s, (str,)):
return s return s
elif isinstance(s, six.binary_type): elif isinstance(s, bytes):
return s.decode("utf-8", "ignore") return s.decode(errors="ignore")
else: else:
return s # Not a string we care about return s # Not a string we care about
@ -1195,34 +1191,20 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
step = 2 # Because the result has content step = 2 # Because the result has content
offset = 1 # The first item is the count of total matches. offset = 1 # The first item is the count of total matches.
for i in xrange(1, len(res), step): for i in range(1, len(res), step):
fields_offset = offset
fields = dict( fields = dict(
dict( zip(
izip( map(to_string, res[i + offset][::2]),
map(to_string, res[i + fields_offset][::2]), map(to_string, res[i + offset][1::2]),
map(to_string, res[i + fields_offset][1::2]),
) )
) )
) # $ means a json entry
if fields.get("$"):
try: json_fields = json.loads(fields.pop("$"))
del fields["id"]
except KeyError:
pass
try:
fields["json"] = fields["$"]
del fields["$"]
except KeyError:
pass
if "json" in fields:
json_fields = json.loads(fields["json"])
doc = cls(**json_fields) doc = cls(**json_fields)
else: else:
doc = cls(**fields) doc = cls(**fields)
docs.append(doc) docs.append(doc)
return docs return docs

View file

@ -27,7 +27,6 @@ redis = ">=3.5.3,<5.0.0"
aioredis = "^2.0.0" aioredis = "^2.0.0"
pydantic = "^1.8.2" pydantic = "^1.8.2"
click = "^8.0.1" click = "^8.0.1"
six = "^1.16.0"
pptree = "^3.1" pptree = "^3.1"
types-redis = ">=3.5.9,<5.0.0" types-redis = ">=3.5.9,<5.0.0"
types-six = "^1.16.1" types-six = "^1.16.1"

View file

@ -43,6 +43,7 @@ async def m(key_prefix, redis):
created_on: datetime.datetime created_on: datetime.datetime
class Member(BaseHashModel): class Member(BaseHashModel):
id: int = Field(index=True)
first_name: str = Field(index=True) first_name: str = Field(index=True)
last_name: str = Field(index=True) last_name: str = Field(index=True)
email: str = Field(index=True) email: str = Field(index=True)
@ -64,6 +65,7 @@ async def m(key_prefix, redis):
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def members(m): async def members(m):
member1 = m.Member( member1 = m.Member(
id=0,
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
email="a@example.com", email="a@example.com",
@ -73,6 +75,7 @@ async def members(m):
) )
member2 = m.Member( member2 = m.Member(
id=1,
first_name="Kim", first_name="Kim",
last_name="Brookins", last_name="Brookins",
email="k@example.com", email="k@example.com",
@ -82,6 +85,7 @@ async def members(m):
) )
member3 = m.Member( member3 = m.Member(
id=2,
first_name="Andrew", first_name="Andrew",
last_name="Smith", last_name="Smith",
email="as@example.com", email="as@example.com",
@ -129,6 +133,9 @@ async def test_exact_match_queries(members, m):
).all() ).all()
assert actual == [member2] assert actual == [member2]
actual = await m.Member.find(m.Member.id == 0).all()
assert actual == [member1]
@py_test_mark_asyncio @py_test_mark_asyncio
async def test_full_text_search_queries(members, m): async def test_full_text_search_queries(members, m):
@ -176,6 +183,7 @@ async def test_tag_queries_boolean_logic(members, m):
@py_test_mark_asyncio @py_test_mark_asyncio
async def test_tag_queries_punctuation(m): async def test_tag_queries_punctuation(m):
member1 = m.Member( member1 = m.Member(
id=0,
first_name="Andrew, the Michael", first_name="Andrew, the Michael",
last_name="St. Brookins-on-Pier", last_name="St. Brookins-on-Pier",
email="a|b@example.com", # NOTE: This string uses the TAG field separator. email="a|b@example.com", # NOTE: This string uses the TAG field separator.
@ -186,6 +194,7 @@ async def test_tag_queries_punctuation(m):
await member1.save() await member1.save()
member2 = m.Member( member2 = m.Member(
id=1,
first_name="Bob", first_name="Bob",
last_name="the Villain", last_name="the Villain",
email="a|villain@example.com", # NOTE: This string uses the TAG field separator. email="a|villain@example.com", # NOTE: This string uses the TAG field separator.
@ -337,18 +346,19 @@ def test_validates_required_fields(m):
# Raises ValidationError: last_name is required # Raises ValidationError: last_name is required
# TODO: Test the error value # TODO: Test the error value
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
m.Member(first_name="Andrew", zipcode="97086", join_date=today) m.Member(id=0, first_name="Andrew", zipcode="97086", join_date=today)
def test_validates_field(m): def test_validates_field(m):
# Raises ValidationError: join_date is not a date # Raises ValidationError: join_date is not a date
# TODO: Test the error value # TODO: Test the error value
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
m.Member(first_name="Andrew", last_name="Brookins", join_date="yesterday") m.Member(id=0, first_name="Andrew", last_name="Brookins", join_date="yesterday")
def test_validation_passes(m): def test_validation_passes(m):
member = m.Member( member = m.Member(
id=0,
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
email="a@example.com", email="a@example.com",
@ -362,6 +372,7 @@ def test_validation_passes(m):
@py_test_mark_asyncio @py_test_mark_asyncio
async def test_retrieve_first(m): async def test_retrieve_first(m):
member = m.Member( member = m.Member(
id=0,
first_name="Simon", first_name="Simon",
last_name="Prickett", last_name="Prickett",
email="s@example.com", email="s@example.com",
@ -373,6 +384,7 @@ async def test_retrieve_first(m):
await member.save() await member.save()
member2 = m.Member( member2 = m.Member(
id=1,
first_name="Another", first_name="Another",
last_name="Member", last_name="Member",
email="m@example.com", email="m@example.com",
@ -384,6 +396,7 @@ async def test_retrieve_first(m):
await member2.save() await member2.save()
member3 = m.Member( member3 = m.Member(
id=2,
first_name="Third", first_name="Third",
last_name="Member", last_name="Member",
email="t@example.com", email="t@example.com",
@ -401,6 +414,7 @@ async def test_retrieve_first(m):
@py_test_mark_asyncio @py_test_mark_asyncio
async def test_saves_model_and_creates_pk(m): async def test_saves_model_and_creates_pk(m):
member = m.Member( member = m.Member(
id=0,
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
email="a@example.com", email="a@example.com",
@ -418,6 +432,7 @@ async def test_saves_model_and_creates_pk(m):
@py_test_mark_asyncio @py_test_mark_asyncio
async def test_all_pks(m): async def test_all_pks(m):
member = m.Member( member = m.Member(
id=0,
first_name="Simon", first_name="Simon",
last_name="Prickett", last_name="Prickett",
email="s@example.com", email="s@example.com",
@ -429,6 +444,7 @@ async def test_all_pks(m):
await member.save() await member.save()
member1 = m.Member( member1 = m.Member(
id=1,
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
email="a@example.com", email="a@example.com",
@ -449,6 +465,7 @@ async def test_all_pks(m):
@py_test_mark_asyncio @py_test_mark_asyncio
async def test_delete(m): async def test_delete(m):
member = m.Member( member = m.Member(
id=0,
first_name="Simon", first_name="Simon",
last_name="Prickett", last_name="Prickett",
email="s@example.com", email="s@example.com",
@ -465,6 +482,7 @@ async def test_delete(m):
@py_test_mark_asyncio @py_test_mark_asyncio
async def test_expire(m): async def test_expire(m):
member = m.Member( member = m.Member(
id=0,
first_name="Expire", first_name="Expire",
last_name="Test", last_name="Test",
email="e@example.com", email="e@example.com",
@ -529,6 +547,7 @@ def test_raises_error_with_lists(m):
@py_test_mark_asyncio @py_test_mark_asyncio
async def test_saves_many(m): async def test_saves_many(m):
member1 = m.Member( member1 = m.Member(
id=0,
first_name="Andrew", first_name="Andrew",
last_name="Brookins", last_name="Brookins",
email="a@example.com", email="a@example.com",
@ -537,6 +556,7 @@ async def test_saves_many(m):
bio="This is the user bio.", bio="This is the user bio.",
) )
member2 = m.Member( member2 = m.Member(
id=1,
first_name="Kim", first_name="Kim",
last_name="Brookins", last_name="Brookins",
email="k@example.com", email="k@example.com",