From 0f9f7aa868eb2c290adb294b086fa1024ebe0638 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 21 Oct 2021 13:12:54 -0700 Subject: [PATCH] Attempt run-time change of type annotations on model fields --- redis_developer/model/model.py | 18 ++++++++++++++++++ tests/test_json_model.py | 2 -- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/redis_developer/model/model.py b/redis_developer/model/model.py index 08a38cd..168f49c 100644 --- a/redis_developer/model/model.py +++ b/redis_developer/model/model.py @@ -976,6 +976,13 @@ class ModelMeta(ModelMetaclass): # in queries, like Model.get(Model.field_name == 1) for field_name, field in new_class.__fields__.items(): setattr(new_class, field_name, ExpressionProxy(field, [])) + annotation = new_class.get_annotations().get(field_name) + if annotation: + new_class.__annotations__[field_name] = Union[ + annotation, ExpressionProxy + ] + else: + new_class.__annotations__[field_name] = ExpressionProxy # Check if this is our FieldInfo version with extended ORM metadata. if isinstance(field.field_info, FieldInfo): if field.field_info.primary_key: @@ -1139,6 +1146,17 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): docs.append(doc) return docs + @classmethod + def get_annotations(cls): + d = {} + for c in cls.mro(): + try: + d.update(**c.__annotations__) + except AttributeError: + # object, at least, has no __annotations__ attribute. + pass + return d + @classmethod def add(cls, models: Sequence["RedisModel"]) -> Sequence["RedisModel"]: # TODO: Add transaction support diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 3fd420a..ffa6594 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -6,7 +6,6 @@ from typing import List, Optional from unittest import mock import pytest -import redis from pydantic import ValidationError from redis_developer.model import EmbeddedJsonModel, Field, JsonModel @@ -528,7 +527,6 @@ def test_not_found(m): m.Member.get(1000) -@pytest.mark.skip("Does not clean up after itself properly") def test_list_field_limitations(m): with pytest.raises(RedisModelError):