#!/usr/bin/env python3
"""Run numbered SQL migration files against the configured PostgreSQL database.

Files are selected either explicitly on the command line or by scanning
``<repo root>/migrations`` (plus any ``--module`` directories) for ``*.sql``
files whose names start with a number.  Each file is executed in its own
transaction: it is committed on success and rolled back on failure.
"""
import argparse
import logging
import os
import re
import sys
from pathlib import Path

import psycopg2

# The repository root is one level above the directory containing this script.
# It is put on sys.path so that the ``app`` package can be imported below.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.core.config import settings  # noqa: E402  (needs ROOT on sys.path)

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

# File names that start with one or more digits and end in ".sql".
NUMBERED_SQL_RE = re.compile(r"^\d+.*\.sql$")
# Captures the leading number of a migration file name for numeric sorting.
_LEADING_NUMBER_RE = re.compile(r"^(\d+)")


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the migration runner."""
    parser = argparse.ArgumentParser(
        description="Run BMC Hub SQL migrations against the configured PostgreSQL database."
    )
    parser.add_argument(
        "files",
        nargs="*",
        help="Specific SQL files to run, relative to repo root (for example migrations/145_sag_start_date.sql).",
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="Run all numbered SQL files from ./migrations in numeric order. Default when no files are provided.",
    )
    parser.add_argument(
        "--module",
        action="append",
        default=[],
        help="Also run numbered SQL files from a module migration directory, relative to repo root.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the files that would run without executing them.",
    )
    parser.add_argument(
        "--stop-on-error",
        action="store_true",
        help="Stop immediately on the first migration failure.",
    )
    return parser.parse_args()


def patch_database_url_for_local_dev() -> None:
    """Rewrite ``settings.DATABASE_URL`` when it points at the ``postgres`` host.

    If the URL contains ``@postgres``, that text is replaced with
    ``@localhost`` and every occurrence of ``:5432`` with ``:5433``.
    """
    if "@postgres" in settings.DATABASE_URL:
        logger.info("Patching DATABASE_URL for local run")
        # NOTE(review): the port substitution is a plain text replace over the
        # whole URL, so a ":5432" inside the credentials would be rewritten
        # too -- confirm this cannot occur with the project's URLs.
        settings.DATABASE_URL = settings.DATABASE_URL.replace(
            "@postgres", "@localhost"
        ).replace(":5432", ":5433")


def _migration_sort_key(path: Path) -> tuple[int, str]:
    """Return ``(leading number, file name)`` for a numbered migration file."""
    leading = _LEADING_NUMBER_RE.match(path.name)
    # Callers only pass names accepted by NUMBERED_SQL_RE; fall back to 0
    # rather than crashing if that invariant is ever broken.
    number = int(leading.group(1)) if leading else 0
    return number, path.name


def collect_numbered_sql(directory: Path) -> list[Path]:
    """Return the numbered ``*.sql`` files directly inside *directory*.

    The result is sorted by the leading number of the file name (compared as
    an integer, so ``9_x.sql`` sorts before ``10_y.sql``) and then by name.
    """
    numbered = [p for p in directory.glob("*.sql") if NUMBERED_SQL_RE.match(p.name)]
    numbered.sort(key=_migration_sort_key)
    return numbered


def resolve_explicit_files(file_args: list[str]) -> list[Path]:
    """Resolve file arguments (relative to the repo root) to absolute paths.

    Raises:
        FileNotFoundError: if an argument does not name an existing regular
            file (directories are rejected here instead of failing on read).
    """
    resolved: list[Path] = []
    for raw in file_args:
        path = (ROOT / raw).resolve()
        if not path.is_file():
            raise FileNotFoundError(f"Migration file not found: {raw}")
        resolved.append(path)
    return resolved


def build_file_list(args: argparse.Namespace) -> list[Path]:
    """Build the ordered, duplicate-free list of migration files to run.

    Order: all numbered files from ``./migrations`` (when ``--all`` is given
    or no explicit files were passed), then the explicit files, then the
    numbered files of each ``--module`` directory.

    Raises:
        FileNotFoundError: if an explicit file or a module directory is missing.
    """
    files: list[Path] = []
    # ``--all`` used to be parsed but never consulted; honour it even when
    # explicit files are given, as its help text promises.  getattr keeps
    # hand-built Namespaces without an ``all`` attribute working.
    if getattr(args, "all", False) or not args.files:
        files.extend(collect_numbered_sql(ROOT / "migrations"))
    files.extend(resolve_explicit_files(args.files))

    for module_dir in args.module:
        path = (ROOT / module_dir).resolve()
        if not path.is_dir():
            raise FileNotFoundError(f"Module migration directory not found: {module_dir}")
        files.extend(collect_numbered_sql(path))

    # dict.fromkeys preserves first-seen order while dropping duplicates.
    return list(dict.fromkeys(files))


def _display_path(path: Path) -> Path:
    """Return *path* relative to ROOT when it lies inside it, else unchanged."""
    try:
        return path.relative_to(ROOT)
    except ValueError:
        # Absolute or "../" arguments can point outside the repository.
        return path


def _first_line(exc: BaseException) -> str:
    """Return the first line of the exception message, or its type name if empty."""
    lines = str(exc).strip().splitlines()
    return lines[0] if lines else type(exc).__name__


def run_files(files: list[Path], dry_run: bool, stop_on_error: bool) -> int:
    """Execute the given SQL files and return a process exit code.

    Args:
        files: Migration files to run, in order.
        dry_run: When true, only log the files; no connection is opened.
        stop_on_error: When true, stop after the first failing file.

    Returns:
        0 when nothing was selected, on a dry run, or when every file
        succeeded; 1 when at least one file failed.
    """
    if not files:
        logger.info("No migration files selected.")
        return 0

    if dry_run:
        for path in files:
            logger.info("DRY %s", _display_path(path))
        return 0

    failures: list[tuple[Path, str]] = []
    conn = psycopg2.connect(settings.DATABASE_URL)
    try:
        conn.autocommit = False
        with conn.cursor() as cur:
            for path in files:
                rel = _display_path(path)
                try:
                    # Reading happens inside the try so an unreadable file is
                    # reported like any other failed migration.
                    cur.execute(path.read_text(encoding="utf-8"))
                    conn.commit()
                except Exception as exc:
                    # Per-file boundary: undo the partial work, record, go on.
                    conn.rollback()
                    message = _first_line(exc)
                    failures.append((path, message))
                    logger.error("FAIL %s: %s", rel, message)
                    if stop_on_error:
                        break
                else:
                    logger.info("OK %s", rel)
    finally:
        conn.close()

    if failures:
        logger.error("")
        logger.error("Failed migrations:")
        for path, message in failures:
            logger.error("- %s: %s", _display_path(path), message)
        return 1

    logger.info("")
    logger.info("All selected migrations completed successfully.")
    return 0


def main() -> int:
    """Entry point: parse arguments, select files, run them, return exit code."""
    args = parse_args()
    patch_database_url_for_local_dev()
    try:
        files = build_file_list(args)
    except FileNotFoundError as exc:
        # A bad path is a usage error; report it without a traceback.
        logger.error("%s", exc)
        return 1
    return run_files(files, dry_run=args.dry_run, stop_on_error=args.stop_on_error)


if __name__ == "__main__":
    raise SystemExit(main())