#!/usr/bin/env python3 import argparse import json import logging import re import sys from pathlib import Path import psycopg2 from psycopg2.extras import RealDictCursor ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from app.core.config import settings logging.basicConfig(level=logging.INFO, format="%(message)s") logger = logging.getLogger(__name__) NUMBERED_SQL_RE = re.compile(r"^\d+.*\.sql$") CREATE_TABLE_RE = re.compile( r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*\(", re.IGNORECASE, ) ADD_COLUMN_RE = re.compile( r"ALTER\s+TABLE\s+(?:IF\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)\s+ADD\s+COLUMN\s+(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)", re.IGNORECASE, ) CREATE_INDEX_RE = re.compile( r"CREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)\s+ON\s+([A-Za-z_][A-Za-z0-9_]*)", re.IGNORECASE, ) SKIP_COLUMN_LINE_RE = re.compile( r"^(?:CONSTRAINT|PRIMARY\s+KEY|FOREIGN\s+KEY|UNIQUE|CHECK|CASE|WHEN|ELSE|END)\b", re.IGNORECASE, ) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Validate BMC Hub SQL migrations against current PostgreSQL schema." ) parser.add_argument( "--json", action="store_true", help="Output report in JSON format.", ) parser.add_argument( "--strict-indexes", action="store_true", default=False, help="Treat missing indexes as failure (default: False).", ) parser.add_argument( "--module", action="append", default=[], help="Also parse numbered SQL files from a module migration directory, relative to repo root.", ) return parser.parse_args() def patch_database_url_for_local_dev() -> None: if "@postgres" in settings.DATABASE_URL: logger.info("Patching DATABASE_URL for local run") settings.DATABASE_URL = settings.DATABASE_URL.replace("@postgres", "@localhost").replace(":5432", ":5433") def collect_numbered_sql(directory: Path) -> list[Path]: files = [p for p in directory.glob("*.sql") if NUMBERED_SQL_RE.match(p.name)] def _sort_key(path: Path) -> tuple[int, str]: match = re.match(r"^(\d+)", path.name) prefix = int(match.group(1)) if match else 0 return (prefix, path.name) files.sort(key=_sort_key) return files def build_file_list(args: argparse.Namespace) -> list[Path]: files: list[Path] = [] root_migrations = ROOT / "migrations" files.extend(collect_numbered_sql(root_migrations)) for module_dir in args.module: path = (ROOT / module_dir).resolve() if not path.exists() or not path.is_dir(): raise FileNotFoundError(f"Module migration directory not found: {module_dir}") files.extend(collect_numbered_sql(path)) unique_files: list[Path] = [] seen: set[Path] = set() for path in files: if path not in seen: unique_files.append(path) seen.add(path) return unique_files def strip_sql_comments(sql: str) -> str: sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL) sql = re.sub(r"--[^\n]*", "", sql) return sql def extract_create_table_block(sql: str, table_name: str, start_pos: int) -> str: """Return create-table body between the first opening '(' and matching ')'.""" open_paren = sql.find("(", start_pos) if open_paren == -1: return "" depth = 0 for i in range(open_paren, len(sql)): ch = sql[i] if ch == "(": depth += 1 elif ch == ")": depth -= 1 if depth == 0: return sql[open_paren + 1:i] return "" def parse_columns_from_create_block(block: str) -> set[str]: columns: set[str] = set() known_types = { "serial", "bigserial", "smallint", "integer", "bigint", "numeric", "decimal", "real", "double", "varchar", "character", "text", "boolean", "bool", "date", "timestamp", "time", "json", "jsonb", "uuid" } for raw_line in block.splitlines(): line = raw_line.strip().rstrip(",") if not line: continue if SKIP_COLUMN_LINE_RE.match(line): continue tokens = line.replace("(", " ").split() if len(tokens) < 2: continue second = tokens[1].strip().lower() second_base = re.sub(r"[^a-z]", "", second) if second_base and second_base not in known_types: continue match = re.match(r"^\"?([A-Za-z_][A-Za-z0-9_]*)\"?\s+", line) if match: columns.add(match.group(1)) return columns def parse_expected_schema(files: list[Path]) -> tuple[dict[str, set[str]], set[str]]: expected_tables: dict[str, set[str]] = {} expected_indexes: set[str] = set() for path in files: sql = strip_sql_comments(path.read_text(encoding="utf-8")) for match in CREATE_TABLE_RE.finditer(sql): table_name = match.group(1) if table_name not in expected_tables: expected_tables[table_name] = set() block = extract_create_table_block(sql, table_name, match.end() - 1) expected_tables[table_name].update(parse_columns_from_create_block(block)) for match in ADD_COLUMN_RE.finditer(sql): table_name = match.group(1) column_name = match.group(2) if table_name not in expected_tables: expected_tables[table_name] = set() expected_tables[table_name].add(column_name) for match in CREATE_INDEX_RE.finditer(sql): index_name = match.group(1) expected_indexes.add(index_name) return expected_tables, expected_indexes def query_actual_schema() -> tuple[set[str], dict[str, set[str]], set[str]]: conn = psycopg2.connect(settings.DATABASE_URL) conn.autocommit = True try: with conn.cursor(cursor_factory=RealDictCursor) as cur: cur.execute( """ SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE' ORDER BY table_name """ ) actual_tables = {row["table_name"] for row in cur.fetchall()} cur.execute( """ SELECT table_name, column_name FROM information_schema.columns WHERE table_schema = 'public' ORDER BY table_name, ordinal_position """ ) actual_columns: dict[str, set[str]] = {} for row in cur.fetchall(): table_name = row["table_name"] if table_name not in actual_columns: actual_columns[table_name] = set() actual_columns[table_name].add(row["column_name"]) cur.execute( """ SELECT indexname FROM pg_indexes WHERE schemaname = 'public' ORDER BY indexname """ ) actual_indexes = {row["indexname"] for row in cur.fetchall()} return actual_tables, actual_columns, actual_indexes finally: conn.close() def compare_expected_vs_actual( expected_tables: dict[str, set[str]], expected_indexes: set[str], actual_tables: set[str], actual_columns: dict[str, set[str]], actual_indexes: set[str], ) -> dict[str, list[str]]: missing_tables: list[str] = [] missing_columns: list[str] = [] missing_indexes: list[str] = [] for table_name in sorted(expected_tables.keys()): if table_name not in actual_tables: missing_tables.append(table_name) continue expected_cols = expected_tables.get(table_name, set()) current_cols = actual_columns.get(table_name, set()) for col in sorted(expected_cols): if col not in current_cols: missing_columns.append(f"{table_name}.{col}") for index_name in sorted(expected_indexes): if index_name not in actual_indexes: missing_indexes.append(index_name) return { "missing_tables": missing_tables, "missing_columns": missing_columns, "missing_indexes": missing_indexes, } def print_report(report: dict[str, list[str]]) -> None: if not report["missing_tables"] and not report["missing_columns"] and not report["missing_indexes"]: logger.info("Schema validation OK: no mismatches found.") return if report["missing_tables"]: logger.error("Missing tables:") for table_name in report["missing_tables"]: logger.error("- %s", table_name) if report["missing_columns"]: logger.error("Missing columns:") for column_name in report["missing_columns"]: logger.error("- %s", column_name) if report["missing_indexes"]: logger.warning("Missing indexes:") for index_name in report["missing_indexes"]: logger.warning("- %s", index_name) def determine_exit_code(report: dict[str, list[str]], strict_indexes: bool) -> int: has_table_or_column_mismatch = bool(report["missing_tables"] or report["missing_columns"]) has_index_mismatch = bool(report["missing_indexes"]) if has_table_or_column_mismatch: return 1 if strict_indexes and has_index_mismatch: return 1 return 0 def main() -> int: args = parse_args() try: patch_database_url_for_local_dev() files = build_file_list(args) expected_tables, expected_indexes = parse_expected_schema(files) actual_tables, actual_columns, actual_indexes = query_actual_schema() report = compare_expected_vs_actual( expected_tables, expected_indexes, actual_tables, actual_columns, actual_indexes, ) if args.json: payload = { "status": "ok" if determine_exit_code(report, args.strict_indexes) == 0 else "mismatch", "strict_indexes": args.strict_indexes, "parsed_files": [str(path.relative_to(ROOT)) for path in files], "missing_tables": report["missing_tables"], "missing_columns": report["missing_columns"], "missing_indexes": report["missing_indexes"], } print(json.dumps(payload, ensure_ascii=True)) else: logger.info("Parsed migration files: %s", len(files)) print_report(report) return determine_exit_code(report, args.strict_indexes) except Exception as exc: if args.json: print(json.dumps({"status": "error", "message": str(exc)}, ensure_ascii=True)) else: logger.error("Validation failed: %s", exc) return 2 if __name__ == "__main__": raise SystemExit(main())