redis-om-python/aredis_om/model/migrations/migrator.py

165 lines
5.1 KiB
Python
Raw Normal View History

import hashlib
import logging
from dataclasses import dataclass
from enum import Enum
2021-11-10 00:59:10 +01:00
from typing import List, Optional
2021-11-10 00:59:10 +01:00
from aioredis import Redis, ResponseError
log = logging.getLogger(__name__)
import importlib # noqa: E402
import pkgutil # noqa: E402
class MigrationError(Exception):
    """Raised when migration setup or index creation cannot proceed.

    Within this module it signals a non-package root module passed to
    ``import_submodules`` and an attempt to create an index outside db 0.
    """
def import_submodules(root_module_name: str):
    """Import every submodule of the package ``root_module_name``, recursively.

    Importing the modules lets any models they define register themselves
    before migrations are detected.

    Raises:
        MigrationError: if ``root_module_name`` refers to a plain module
            (one without a ``__path__``) instead of a package.
    """
    # TODO: Call this without specifying a module name, to import everything?
    package = importlib.import_module(root_module_name)

    # Only packages carry ``__path__``; walk_packages needs it as a search path.
    if not hasattr(package, "__path__"):
        raise MigrationError(
            "The root module must be a Python package. "
            f"You specified: {root_module_name}"
        )

    qualified_prefix = package.__name__ + "."
    discovered = pkgutil.walk_packages(package.__path__, qualified_prefix)  # type: ignore
    for module_info in discovered:
        importlib.import_module(module_info.name)
def schema_hash_key(index_name):
    """Return the Redis key that holds the stored schema hash for ``index_name``."""
    return "{}:hash".format(index_name)
async def create_index(redis: Redis, index_name, schema, current_hash):
    """Create the RediSearch index ``index_name`` unless it already exists.

    If ``FT.INFO`` fails, the index is assumed to be missing. It is then
    created from ``schema``, and ``current_hash`` is written to
    ``schema_hash_key(index_name)`` so later runs can detect schema changes.
    If the index already exists, nothing is changed.

    Args:
        redis: client whose connection pool targets the database to use.
        index_name: name of the search index.
        schema: rendered schema text appended to the ``FT.CREATE`` command.
        current_hash: hash of ``schema`` to record alongside the index.

    Raises:
        MigrationError: if the connection is configured for a database
            other than 0.
        ResponseError: propagated if ``FT.CREATE`` itself fails.
    """
    db_number = redis.connection_pool.connection_kwargs.get("db")
    if db_number and db_number > 0:
        raise MigrationError(
            "Creating search indexes is only supported in database 0. "
            f"You attempted to create an index in database {db_number}"
        )
    try:
        # Send the index name as its own argument rather than embedding it
        # in a single space-joined command string.
        await redis.execute_command("ft.info", index_name)
    except ResponseError:
        # NOTE(review): ``schema`` is interpolated into the command string;
        # presumably the client splits it on whitespace into separate
        # arguments — confirm for schemas containing quoted/spaced values.
        await redis.execute_command(f"ft.create {index_name} {schema}")
        await redis.set(schema_hash_key(index_name), current_hash)
    else:
        log.info("Index already exists, skipping. Index name: %s", index_name)
class MigrationAction(Enum):
    """The operation an ``IndexMigration`` applies to a search index."""

    # Build the index (and record its schema hash).
    CREATE = 2
    # Remove the existing index.
    DROP = 1
@dataclass
class IndexMigration:
    """One pending change (create or drop) to a model's RediSearch index."""

    model_name: str
    index_name: str
    schema: str
    # Hash of ``schema``; passed to create_index, which stores it on creation.
    hash: str
    action: MigrationAction
    redis: Redis
    # Schema hash read from Redis when the migration was detected, if any.
    previous_hash: Optional[str] = None

    async def run(self):
        """Apply this migration by dispatching on ``action``."""
        if self.action is MigrationAction.CREATE:
            await self.create()
        elif self.action is MigrationAction.DROP:
            await self.drop()

    async def create(self):
        """Create the index; server errors are logged, not raised."""
        try:
            await create_index(self.redis, self.index_name, self.schema, self.hash)
        except ResponseError as err:
            # Best-effort by design, but a ResponseError here need not mean
            # the index already exists (e.g. an invalid schema), so include
            # the server's message to keep real failures diagnosable.
            log.info(
                "Index already exists or could not be created: %s (%s)",
                self.index_name,
                err,
            )

    async def drop(self):
        """Drop the index; server errors are logged, not raised."""
        try:
            await self.redis.execute_command("FT.DROPINDEX", self.index_name)
        except ResponseError as err:
            # Most likely the index is already gone; keep the server's
            # message in case it is something else.
            log.info("Index does not exist: %s (%s)", self.index_name, err)
class Migrator:
    """Detect and apply RediSearch index migrations for registered models."""

    def __init__(self, module=None):
        # Optional package name whose submodules are imported before
        # detection, so the models they define are registered.
        self.module = module
        self.migrations: List[IndexMigration] = []

    async def detect_migrations(self):
        """Rebuild ``self.migrations`` from the current state of each model.

        For every model in ``model_registry`` that provides a RediSearch
        schema:

        * if ``FT.INFO`` fails, the index is treated as missing and a CREATE
          migration is queued;
        * if the index exists but the stored schema hash differs from the
          hash of the current schema, a DROP followed by a CREATE is queued.

        Models whose ``redisearch_schema()`` raises ``NotImplementedError``
        are skipped.
        """
        # Start from an empty list so repeated calls (e.g. run() after an
        # explicit detect_migrations()) do not replay stale migrations.
        self.migrations = []

        # Try to load any modules found under the given path or module name.
        if self.module:
            import_submodules(self.module)

        # Import this at run-time to avoid triggering import-time side effects,
        # e.g. checks for RedisJSON, etc.
        from aredis_om.model.model import model_registry

        for name, cls in model_registry.items():
            index_name = cls.Meta.index_name
            hash_key = schema_hash_key(index_name)
            redis = cls.db()
            try:
                schema = cls.redisearch_schema()
            except NotImplementedError:
                log.info("Skipping migrations for %s", name)
                continue
            current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest()  # nosec

            try:
                await redis.execute_command("ft.info", index_name)
            except ResponseError:
                self.migrations.append(
                    IndexMigration(
                        name,
                        index_name,
                        schema,
                        current_hash,
                        MigrationAction.CREATE,
                        redis,
                    )
                )
                continue

            stored_hash = await redis.get(hash_key)
            # A client without response decoding returns bytes; comparing
            # bytes to the str hexdigest would always report a change and
            # force a needless drop/recreate on every run.
            if isinstance(stored_hash, bytes):
                stored_hash = stored_hash.decode("utf-8")
            schema_out_of_date = current_hash != stored_hash

            if schema_out_of_date:
                # TODO: Switch out schema with an alias to avoid downtime -- separate migration?
                # Order matters: the old index is dropped before it is rebuilt.
                self.migrations.append(
                    IndexMigration(
                        name,
                        index_name,
                        schema,
                        current_hash,
                        MigrationAction.DROP,
                        redis,
                        stored_hash,
                    )
                )
                self.migrations.append(
                    IndexMigration(
                        name,
                        index_name,
                        schema,
                        current_hash,
                        MigrationAction.CREATE,
                        redis,
                        stored_hash,
                    )
                )

    async def run(self):
        """Detect pending migrations and apply them in the order queued."""
        # TODO: Migration history
        # TODO: Dry run with output
        await self.detect_migrations()
        for migration in self.migrations:
            await migration.run()