# redis-om-python/tests/test_hash_model.py
# Snapshot: 2022-03-29 22:15:05 +01:00 (544 lines, 15 KiB, Python)

# type: ignore
import abc
import dataclasses
import datetime
import decimal
from collections import namedtuple
from typing import Dict, List, Optional, Set
from unittest import mock
import pytest
from pydantic import ValidationError
from aredis_om import (
Field,
HashModel,
Migrator,
QueryNotSupportedError,
RedisModelError,
)
# We need to run this check as sync code (during tests) even in async mode
# because we call it in the top-level module scope.
from redis_om import has_redisearch
# Evaluate RediSearch availability once at import time; when it is missing,
# every test in this module is skipped via the module-level ``pytestmark``.
_REDISEARCH_AVAILABLE = has_redisearch()
if not _REDISEARCH_AVAILABLE:
    pytestmark = pytest.mark.skip

# One shared date so every model built in these tests carries the same value.
today = datetime.date.today()
@pytest.fixture
async def m(key_prefix, redis):
    """Define the hash models used by this module and run the migrator.

    Returns a ``Models`` namedtuple exposing ``BaseHashModel``, ``Order`` and
    ``Member``. The ``redis`` fixture is requested but not used in the body.
    """

    class BaseHashModel(HashModel, abc.ABC):
        # Abstract base: all concrete models share the per-test key prefix.
        class Meta:
            global_key_prefix = key_prefix

    class Order(BaseHashModel):
        total: decimal.Decimal
        currency: str
        created_on: datetime.datetime

    class Member(BaseHashModel):
        first_name: str = Field(index=True)
        last_name: str = Field(index=True)
        email: str = Field(index=True)
        join_date: datetime.date
        age: int = Field(index=True)
        bio: str = Field(index=True, full_text_search=True)

        class Meta:
            model_key_prefix = "member"
            primary_key_pattern = ""

    # Presumably builds the search indexes for the models above -- the
    # query tests depend on it having run.
    await Migrator().run()

    Models = namedtuple("Models", ["BaseHashModel", "Order", "Member"])
    return Models(BaseHashModel, Order, Member)
@pytest.fixture
async def members(m):
    """Build three ``Member`` records, save them in order, and yield them."""
    rows = [
        dict(
            first_name="Andrew",
            last_name="Brookins",
            email="a@example.com",
            age=38,
            bio="This is member 1 whose greatness makes him the life and soul of any party he goes to.",
        ),
        dict(
            first_name="Kim",
            last_name="Brookins",
            email="k@example.com",
            age=34,
            bio="This is member 2 who can be quite anxious until you get to know them.",
        ),
        dict(
            first_name="Andrew",
            last_name="Smith",
            email="as@example.com",
            age=100,
            bio="This is member 3 who is a funny and lively sort of person.",
        ),
    ]
    # All three are constructed first, then persisted one after another.
    created = [m.Member(join_date=today, **row) for row in rows]
    for member in created:
        await member.save()
    yield tuple(created)
@pytest.mark.asyncio
async def test_exact_match_queries(members, m):
    """Exact-match queries combined with &, |, ~, != and implicit AND."""
    member1, member2, member3 = members
    Member = m.Member

    expectations = [
        # Single equality, ordered by age.
        (
            Member.find(Member.last_name == "Brookins").sort_by("age"),
            [member2, member1],
        ),
        # Equality AND a negated equality.
        (
            Member.find(
                (Member.last_name == "Brookins") & ~(Member.first_name == "Andrew")
            ),
            [member2],
        ),
        # Negation spelled with ~ and with != must agree.
        (Member.find(~(Member.last_name == "Brookins")), [member3]),
        (Member.find(Member.last_name != "Brookins"), [member3]),
        # (A & B) | C -- & binds tighter than |.
        (
            Member.find(
                ((Member.last_name == "Brookins") & (Member.first_name == "Andrew"))
                | (Member.first_name == "Kim")
            ).sort_by("age"),
            [member2, member1],
        ),
        # Several positional expressions are combined with AND.
        (
            Member.find(Member.first_name == "Kim", Member.last_name == "Brookins"),
            [member2],
        ),
    ]
    for query, expected in expectations:
        assert await query.all() == expected
@pytest.mark.asyncio
async def test_full_text_search_queries(members, m):
    """Full-text (``%``) matches on ``bio``, plain and negated.

    The negated query matches two members. It is sorted by age because the
    assertion compares an ordered list, and without an explicit sort the order
    of multiple search results is not guaranteed.
    """
    member1, member2, member3 = members

    actual = await m.Member.find(m.Member.bio % "great").all()
    assert actual == [member1]

    # member1 (38) sorts before member3 (100).
    actual = await m.Member.find(~(m.Member.bio % "anxious")).sort_by("age").all()
    assert actual == [member1, member3]
@pytest.mark.asyncio
async def test_recursive_query_resolution(members, m):
    """A nested A | (B & C) expression resolves to the union of both branches."""
    member1, member2, member3 = members
    Member = m.Member

    # Explicit parentheses show Python's precedence: & binds tighter than |.
    brookins_or_old_smith = (Member.last_name == "Brookins") | (
        (Member.age == 100) & (Member.last_name == "Smith")
    )
    by_age = Member.find(brookins_or_old_smith).sort_by("age")

    assert await by_age.all() == [member2, member1, member3]
@pytest.mark.asyncio
async def test_tag_queries_boolean_logic(members, m):
    """(A & B) | C over TAG fields returns members matching either branch."""
    member1, member2, member3 = members
    Member = m.Member

    andrew_brookins = (Member.first_name == "Andrew") & (
        Member.last_name == "Brookins"
    )
    any_smith = Member.last_name == "Smith"
    by_age = Member.find(andrew_brookins | any_smith).sort_by("age")

    assert await by_age.all() == [member1, member3]
@pytest.mark.asyncio
async def test_tag_queries_punctuation(m):
    """Exact-match queries work on values with punctuation and the TAG separator."""
    andrew = m.Member(
        first_name="Andrew, the Michael",
        last_name="St. Brookins-on-Pier",
        email="a|b@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
        bio="This is a test user on our system.",
    )
    await andrew.save()

    villain = m.Member(
        first_name="Bob",
        last_name="the Villain",
        email="a|villain@example.com",  # NOTE: This string uses the TAG field separator.
        age=38,
        join_date=today,
        bio="This is a villain, they are a really bad person!",
    )
    await villain.save()

    # Commas, periods and hyphens inside an indexed string must still match exactly.
    for punctuated in (
        m.Member.find(m.Member.first_name == "Andrew, the Michael"),
        m.Member.find(m.Member.last_name == "St. Brookins-on-Pier"),
    ):
        assert await punctuated.first() == andrew

    # Values that contain the internal TAG separator in a single-value
    # exact-match field (such as an indexed string) still match exactly. The
    # library works around the separator by querying for the union of the two
    # values on either side of it.
    assert await m.Member.find(m.Member.email == "a|b@example.com").all() == [andrew]
    assert await m.Member.find(
        m.Member.email == "a|villain@example.com"
    ).all() == [villain]
@pytest.mark.asyncio
async def test_tag_queries_negation(members, m):
    """Negated TAG expressions, alone and nested inside AND / OR trees."""
    member1, member2, member3 = members
    Member = m.Member

    #          ┌first_name
    #   NOT EQ┤
    #          └Andrew
    not_andrew = Member.find(~(Member.first_name == "Andrew"))
    assert await not_andrew.all() == [member2]

    #              ┌first_name
    #       ┌NOT EQ┤
    #       |      └Andrew
    #    AND┤
    #       |  ┌last_name
    #       └EQ┤
    #          └Brookins
    not_andrew_but_brookins = Member.find(
        ~(Member.first_name == "Andrew") & (Member.last_name == "Brookins")
    )
    assert await not_andrew_but_brookins.all() == [member2]

    #              ┌first_name
    #       ┌NOT EQ┤
    #       |      └Andrew
    #    AND┤
    #       |      ┌last_name
    #       |   ┌EQ┤
    #       |   |  └Brookins
    #       └OR┤
    #           |  ┌last_name
    #           └EQ┤
    #              └Smith
    not_andrew_in_either_family = Member.find(
        ~(Member.first_name == "Andrew")
        & ((Member.last_name == "Brookins") | (Member.last_name == "Smith"))
    )
    assert await not_andrew_in_either_family.all() == [member2]

    #                  ┌first_name
    #           ┌NOT EQ┤
    #           |      └Andrew
    #       ┌AND┤
    #       |   |  ┌last_name
    #       |   └EQ┤
    #       |      └Brookins
    #    OR┤
    #       |  ┌last_name
    #       └EQ┤
    #          └Smith
    mixed_tree = Member.find(
        (~(Member.first_name == "Andrew") & (Member.last_name == "Brookins"))
        | (Member.last_name == "Smith")
    )
    assert await mixed_tree.sort_by("age").all() == [member2, member3]

    # Negation on the right-hand side of an AND.
    andrew_not_brookins = Member.find(
        (Member.first_name == "Andrew") & ~(Member.last_name == "Brookins")
    )
    assert await andrew_not_brookins.all() == [member3]
@pytest.mark.asyncio
async def test_numeric_queries(members, m):
    """Numeric comparisons: ==, >, <, <=, >=, !=, negation and implicit AND."""
    member1, member2, member3 = members
    age = m.Member.age
    find = m.Member.find

    # Ages: member2=34, member1=38, member3=100.
    expectations = [
        (find(age == 34), [member2]),
        (find(age > 34).sort_by("age"), [member1, member3]),
        (find(age < 35), [member2]),
        (find(age <= 34), [member2]),
        (find(age >= 100), [member3]),
        (find(age != 34).sort_by("age"), [member1, member3]),
        (find(~(age == 100)).sort_by("age"), [member2, member1]),
        (find(age > 30, age < 40).sort_by("age"), [member2, member1]),
    ]
    for query, expected in expectations:
        assert await query.all() == expected
@pytest.mark.asyncio
async def test_sorting(members, m):
    """Ascending/descending sort, and rejection of unusable sort fields."""
    member1, member2, member3 = members
    older_than_34 = m.Member.find(m.Member.age > 34)

    assert await older_than_34.sort_by("age").all() == [member1, member3]
    # A leading "-" requests descending order.
    assert await older_than_34.sort_by("-age").all() == [member3, member1]

    # The first field does not exist; the second exists but is not sortable.
    for unusable_field in ("not-a-real-field", "join_date"):
        with pytest.raises(QueryNotSupportedError):
            await m.Member.find().sort_by(unusable_field).all()
def test_validates_required_fields(m):
    """Omitting a required field raises ValidationError that names that field.

    ``last_name`` (among others) is omitted; the error must report it, which
    resolves the old "test the error value" TODO.
    """
    with pytest.raises(ValidationError) as exc_info:
        m.Member(first_name="Andrew", zipcode="97086", join_date=today)

    # Each error's "loc" is a tuple whose first element is the field name.
    failed_fields = {error["loc"][0] for error in exc_info.value.errors()}
    assert "last_name" in failed_fields
def test_validates_field(m):
    """An unparseable ``join_date`` is rejected even when all other fields are valid.

    Every other field gets a valid value so that ``join_date`` is the only
    possible cause of the ValidationError. Otherwise missing required fields
    could raise it instead. The error must name ``join_date`` and nothing else.
    """
    with pytest.raises(ValidationError) as exc_info:
        m.Member(
            first_name="Andrew",
            last_name="Brookins",
            email="a@example.com",
            join_date="yesterday",
            age=38,
            bio="This is the bio field.",
        )

    failed_fields = [error["loc"][0] for error in exc_info.value.errors()]
    assert failed_fields == ["join_date"]
def test_validation_passes(m):
    """A fully and correctly populated Member constructs without error."""
    valid_fields = {
        "first_name": "Andrew",
        "last_name": "Brookins",
        "email": "a@example.com",
        "join_date": today,
        "age": 38,
        "bio": "This is the bio field.",
    }
    member = m.Member(**valid_fields)
    assert member.first_name == valid_fields["first_name"]
@pytest.mark.asyncio
async def test_saves_model_and_creates_pk(m):
    """A saved model can be fetched back by its primary key and compares equal."""
    original = m.Member(
        first_name="Andrew",
        last_name="Brookins",
        email="a@example.com",
        join_date=today,
        age=38,
        bio="This is the bio field for this user.",
    )
    await original.save()

    fetched = await m.Member.get(original.pk)
    assert fetched == original
@pytest.mark.asyncio
async def test_delete(m):
    """Deleting by primary key removes exactly one key from Redis.

    Besides the count returned by ``delete``, the key's absence is checked
    directly. A correct count alone would not prove the record is gone.
    """
    member = m.Member(
        first_name="Simon",
        last_name="Prickett",
        email="s@example.com",
        join_date=today,
        age=97,
        bio="This is a test user to be deleted.",
    )
    await member.save()

    response = await m.Member.delete(member.pk)
    assert response == 1
    # EXISTS returns the number of given keys that exist: 0 once deleted.
    assert await m.Member.db().exists(member.key()) == 0
@pytest.mark.asyncio
async def test_expire(m):
    """``expire(60)`` gives the saved key a TTL of at most 60 seconds."""
    member = m.Member(
        first_name="Expire",
        last_name="Test",
        email="e@example.com",
        join_date=today,
        age=93,
        bio="This is a test user for expiry",
    )
    await member.save()
    await member.expire(60)

    ttl = await m.Member.db().ttl(member.key())
    # TTL is reported in whole seconds. It must be positive (a TTL is set and
    # the key still exists) and cannot exceed the value just requested.
    assert 0 < ttl <= 60
def test_raises_error_with_embedded_models(m):
    """Declaring a hash-model field typed as another model raises RedisModelError."""

    # A perfectly valid model on its own...
    class Address(m.BaseHashModel):
        address_line_1: str
        address_line_2: Optional[str]
        city: str
        country: str
        postal_code: str

    # ...but it cannot be embedded as a field of another hash model.
    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            address: Address
def test_raises_error_with_dataclasses(m):
    """A dataclass-typed field is rejected when the hash model class is created."""

    @dataclasses.dataclass
    class Address:
        address_line_1: str

    # The error is raised by the class statement itself, not by instantiation.
    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            address: Address
def test_raises_error_with_dicts(m):
    """A mapping-typed field is rejected when the hash model class is created."""
    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            address: Dict[str, str]
def test_raises_error_with_sets(m):
    """A set-typed field is rejected when the hash model class is created."""
    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            friend_ids: Set[str]
def test_raises_error_with_lists(m):
    """A list-typed field is rejected when the hash model class is created."""
    with pytest.raises(RedisModelError):

        class InvalidMember(m.BaseHashModel):
            friend_ids: List[str]
@pytest.mark.asyncio
async def test_saves_many(m):
    """``Member.add`` saves a batch, returns it, and each entry is retrievable."""
    batch = [
        m.Member(
            first_name="Andrew",
            last_name="Brookins",
            email="a@example.com",
            join_date=today,
            age=38,
            bio="This is the user bio.",
        ),
        m.Member(
            first_name="Kim",
            last_name="Brookins",
            email="k@example.com",
            join_date=today,
            age=34,
            bio="This is the bio for Kim.",
        ),
    ]
    member1, member2 = batch

    saved = await m.Member.add(batch)
    assert saved == [member1, member2]

    for member in (member1, member2):
        assert await m.Member.get(pk=member.pk) == member
@pytest.mark.asyncio
async def test_updates_a_model(members, m):
    """``update`` persists the changed field so a fresh ``get`` sees it."""
    member1, _member2, _member3 = members

    await member1.update(last_name="Smith")

    reloaded = await m.Member.get(member1.pk)
    assert reloaded.last_name == "Smith"
@pytest.mark.asyncio
async def test_paginate_query(members, m):
    """Fetching with ``batch_size=1`` still returns every result, in sort order."""
    member1, member2, member3 = members
    by_age = m.Member.find().sort_by("age")

    # Assumes batch_size is the page size, so three results span several pages.
    everyone = await by_age.all(batch_size=1)
    assert everyone == [member2, member1, member3]
@pytest.mark.asyncio
async def test_access_result_by_index_cached(members, m):
    """Once ``execute`` fills the cache, ``get_item`` does not touch ``db``."""
    member1, member2, member3 = members
    query = m.Member.find().sort_by("age")

    # Cache starts empty; execute() populates it with the sorted results.
    assert query._model_cache == []
    await query.execute()
    assert query._model_cache == [member2, member1, member3]

    # With ``db`` replaced by a mock, a cached read must not call it.
    with mock.patch.object(query.model, "db") as patched_db:
        assert await query.get_item(0) == member2
        assert not patched_db.called
@pytest.mark.asyncio
async def test_access_result_by_index_not_cached(members, m):
    """``get_item`` on a query with an empty cache returns the right models."""
    member1, member2, member3 = members
    query = m.Member.find().sort_by("age")

    # Nothing has been requested from Redis yet, so the cache is empty.
    assert query._model_cache == []

    for position, expected in enumerate([member2, member1, member3]):
        assert await query.get_item(position) == expected
def test_schema(m):
    """``redisearch_schema`` renders the expected FT.CREATE schema clause."""

    class Address(m.BaseHashModel):
        a_string: str = Field(index=True)
        a_full_text_string: str = Field(index=True, full_text_search=True)
        an_integer: int = Field(index=True, sortable=True)
        a_float: float = Field(index=True)
        another_integer: int
        another_float: float

    # The key prefix is derived rather than hard-coded because it differs
    # when these tests are copied into tests_sync and converted with unasync.
    key_prefix = Address.make_key(Address._meta.primary_key_pattern.format(pk=""))

    # Only indexed fields appear. A full-text string is indexed both as a TAG
    # and as a TEXT alias with an "_fts" suffix.
    schema_fields = " ".join(
        [
            "pk TAG SEPARATOR |",
            "a_string TAG SEPARATOR |",
            "a_full_text_string TAG SEPARATOR |",
            "a_full_text_string AS a_full_text_string_fts TEXT",
            "an_integer NUMERIC SORTABLE",
            "a_float NUMERIC",
        ]
    )
    expected = f"ON HASH PREFIX 1 {key_prefix} SCHEMA {schema_fields}"
    assert Address.redisearch_schema() == expected