"""Reusable Alembic env.py helper for per-service migrations. Each service calls ``run_alembic(config, model_modules, table_names)`` from its own ``alembic/env.py``. The helper: * Imports only the requested model modules (so ``Base.metadata`` sees only the tables that belong to the service). * Uses an ``include_name`` callback to filter ``CREATE TABLE`` to only the service's tables (belt-and-suspenders on top of the import filter). * Reads ``ALEMBIC_DATABASE_URL`` for the connection string. """ from __future__ import annotations import importlib import os import sys from typing import Sequence from alembic import context from sqlalchemy import engine_from_config, pool def run_alembic( config, model_modules: Sequence[str], table_names: frozenset[str], ) -> None: """Run Alembic migrations filtered to *table_names*. Parameters ---------- config: The ``alembic.config.Config`` instance (``context.config``). model_modules: Dotted module paths to import so that ``Base.metadata`` is populated (e.g. ``["shared.models.user", "blog.models"]``). table_names: The set of table names this service owns. Only these tables will be created / altered / dropped. """ # Ensure project root is importable project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) if project_root not in sys.path: sys.path.insert(0, project_root) # Import models so Base.metadata sees the tables for mod in model_modules: try: importlib.import_module(mod) except ImportError: pass # OK in Docker images that don't ship sibling apps from shared.db.base import Base target_metadata = Base.metadata # ---- include_name filter ------------------------------------------------ def _include_name(name, type_, parent_names): if type_ == "table": return name in table_names # Always include indexes/constraints that belong to included tables return True # ---- connection URL ----------------------------------------------------- def _get_url() -> str: return os.getenv( "ALEMBIC_DATABASE_URL", os.getenv("DATABASE_URL", config.get_main_option("sqlalchemy.url") or ""), ) # ---- offline / online --------------------------------------------------- if context.is_offline_mode(): context.configure( url=_get_url(), target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"}, compare_type=True, include_name=_include_name, ) with context.begin_transaction(): context.run_migrations() else: url = _get_url() if url: config.set_main_option("sqlalchemy.url", url) connectable = engine_from_config( config.get_section(config.config_ini_section, {}), prefix="sqlalchemy.", poolclass=pool.NullPool, ) with connectable.connect() as connection: context.configure( connection=connection, target_metadata=target_metadata, compare_type=True, include_name=_include_name, ) with context.begin_transaction(): context.run_migrations()