bmc_hub/scripts/run_migrations.py

160 lines
4.6 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
import argparse
import logging
import os
import re
import sys
from pathlib import Path
import psycopg2
# Repository root: this file is <root>/scripts/run_migrations.py, so parents[1]
# (the parent of the scripts/ directory) is the repo root. Migration paths given
# on the command line are interpreted relative to it.
ROOT = Path(__file__).resolve().parents[1]
# Put the repo root on sys.path so the project's `app` package is importable
# when this script is executed directly rather than as part of the package.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from app.core.config import settings  # noqa: E402 - must follow the sys.path tweak above

# Message-only log format so output reads as plain CLI text ("OK ...", "FAIL ...").
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

# File names that start with one or more digits and end in ".sql",
# e.g. "145_sag_start_date.sql".
NUMBERED_SQL_RE = re.compile(r"^\d+.*\.sql$")
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the migration runner.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            original zero-argument call did.

    Returns:
        A namespace with ``files`` (list of str), ``all`` (bool), ``module``
        (list of str), ``dry_run`` (bool) and ``stop_on_error`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Run BMC Hub SQL migrations against the configured PostgreSQL database."
    )
    parser.add_argument(
        "files",
        nargs="*",
        help="Specific SQL files to run, relative to repo root (for example migrations/145_sag_start_date.sql).",
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="Run all numbered SQL files from ./migrations in numeric order. Default when no files are provided.",
    )
    parser.add_argument(
        "--module",
        action="append",
        # argparse copies this list before appending, so the shared default is
        # not mutated across calls.
        default=[],
        help="Also run numbered SQL files from a module migration directory, relative to repo root.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the files that would run without executing them.",
    )
    parser.add_argument(
        "--stop-on-error",
        action="store_true",
        help="Stop immediately on the first migration failure.",
    )
    return parser.parse_args(argv)
def patch_database_url_for_local_dev() -> None:
if "@postgres" in settings.DATABASE_URL:
logger.info("Patching DATABASE_URL for local run")
settings.DATABASE_URL = settings.DATABASE_URL.replace("@postgres", "@localhost").replace(":5432", ":5433")
def collect_numbered_sql(directory: Path) -> list[Path]:
    """Return the numbered ``*.sql`` files directly inside *directory*, in run order.

    Files are ordered by their leading integer (so ``9_x.sql`` precedes
    ``10_x.sql``), with the full file name as a tie-breaker.
    """

    def run_order(sql_file: Path) -> tuple[int, str]:
        # Only called on names that matched NUMBERED_SQL_RE, so the leading
        # digit group is always present.
        leading_number = re.match(r"^(\d+)", sql_file.name).group(1)
        return int(leading_number), sql_file.name

    candidates = (
        entry
        for entry in directory.glob("*.sql")
        if NUMBERED_SQL_RE.match(entry.name)
    )
    return sorted(candidates, key=run_order)
def resolve_explicit_files(file_args: list[str], root: "Path | None" = None) -> list[Path]:
    """Resolve migration file arguments to absolute paths of existing files.

    Args:
        file_args: Paths as given on the command line, joined onto *root*
            (an absolute argument replaces *root* entirely, per ``Path`` joining).
        root: Base directory for relative arguments; defaults to the repo
            root ``ROOT``.

    Returns:
        The resolved paths, in the order given.

    Raises:
        FileNotFoundError: If an argument does not name an existing regular file.
    """
    base = ROOT if root is None else root
    resolved: list[Path] = []
    for raw in file_args:
        path = (base / raw).resolve()
        # is_file() rather than exists(): a directory passes exists() and would
        # only blow up later, mid-run, when read as SQL text.
        if not path.is_file():
            raise FileNotFoundError(f"Migration file not found: {raw}")
        resolved.append(path)
    return resolved
def build_file_list(args: argparse.Namespace) -> list[Path]:
    """Assemble the ordered, de-duplicated list of SQL files to run.

    Selection rules:
      * With no explicit files, or when ``--all`` is given, every numbered
        file in ``<ROOT>/migrations`` is included first, in numeric order.
      * Explicit files (resolved relative to ROOT) follow; any that are
        already in the numbered set keep their numeric position.
      * Each ``--module`` directory contributes its numbered files last.

    Raises:
        FileNotFoundError: If an explicit file or a module directory is missing.
    """
    files: list[Path] = []
    # `--all` was previously parsed but ignored; honour it alongside explicit files.
    # getattr keeps hand-built namespaces without an `all` attribute working.
    if getattr(args, "all", False) or not args.files:
        files.extend(collect_numbered_sql(ROOT / "migrations"))
    if args.files:
        files.extend(resolve_explicit_files(args.files))

    for module_dir in args.module:
        path = (ROOT / module_dir).resolve()
        if not path.is_dir():
            raise FileNotFoundError(f"Module migration directory not found: {module_dir}")
        files.extend(collect_numbered_sql(path))

    # Remove duplicates while preserving first-seen order (dicts keep insertion order).
    return list(dict.fromkeys(files))
def run_files(files: list[Path], dry_run: bool, stop_on_error: bool) -> int:
if not files:
logger.info("No migration files selected.")
return 0
if dry_run:
for path in files:
logger.info("DRY %s", path.relative_to(ROOT))
return 0
conn = psycopg2.connect(settings.DATABASE_URL)
conn.autocommit = False
cur = conn.cursor()
failures: list[tuple[Path, str]] = []
try:
for path in files:
rel = path.relative_to(ROOT)
sql = path.read_text(encoding="utf-8")
try:
cur.execute(sql)
conn.commit()
logger.info("OK %s", rel)
except Exception as exc:
conn.rollback()
message = str(exc).strip().splitlines()[0]
failures.append((path, message))
logger.error("FAIL %s: %s", rel, message)
if stop_on_error:
break
finally:
cur.close()
conn.close()
if failures:
logger.error("")
logger.error("Failed migrations:")
for path, message in failures:
logger.error("- %s: %s", path.relative_to(ROOT), message)
return 1
logger.info("")
logger.info("All selected migrations completed successfully.")
return 0
def main() -> int:
    """Run the migration CLI and return its process exit code.

    Steps: parse options, rewrite the database URL for local runs when
    applicable, build the file list, then execute (or dry-run) it. The
    return value is whatever ``run_files`` returns.
    """
    options = parse_args()
    patch_database_url_for_local_dev()
    selected = build_file_list(options)
    return run_files(
        selected,
        dry_run=options.dry_run,
        stop_on_error=options.stop_on_error,
    )


if __name__ == "__main__":
    raise SystemExit(main())