#!/usr/bin/env python3
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import psycopg2
|
|
from psycopg2.extras import RealDictCursor
|
|
|
|
|
|
ROOT = Path(__file__).resolve().parents[1]
|
|
if str(ROOT) not in sys.path:
|
|
sys.path.insert(0, str(ROOT))
|
|
|
|
from app.core.config import settings
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Migration files are named with a numeric prefix, e.g. "001_init.sql".
NUMBERED_SQL_RE = re.compile(r"^\d+.*\.sql$")
# CREATE TABLE [IF NOT EXISTS] <name> ( ... -- captures the table name;
# the trailing "(" anchors extract_create_table_block's scan position.
CREATE_TABLE_RE = re.compile(
    r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*\(",
    re.IGNORECASE,
)
# ALTER TABLE [IF EXISTS] <table> ADD COLUMN [IF NOT EXISTS] <column>
# -- captures (table, column).
ADD_COLUMN_RE = re.compile(
    r"ALTER\s+TABLE\s+(?:IF\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)\s+ADD\s+COLUMN\s+(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)",
    re.IGNORECASE,
)
# CREATE [UNIQUE] INDEX [IF NOT EXISTS] <index> ON <table>
# -- captures (index, table); only the index name is used downstream.
CREATE_INDEX_RE = re.compile(
    r"CREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)\s+ON\s+([A-Za-z_][A-Za-z0-9_]*)",
    re.IGNORECASE,
)
# Lines inside a CREATE TABLE body that are table-level constraints (or parts
# of CHECK expressions), not column definitions -- skipped by the column parser.
SKIP_COLUMN_LINE_RE = re.compile(
    r"^(?:CONSTRAINT|PRIMARY\s+KEY|FOREIGN\s+KEY|UNIQUE|CHECK|CASE|WHEN|ELSE|END)\b",
    re.IGNORECASE,
)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the validator's command-line options.

    Options: --json (machine-readable output), --strict-indexes (missing
    indexes fail the run), --module (repeatable extra migration directory).
    """
    parser = argparse.ArgumentParser(
        description="Validate BMC Hub SQL migrations against current PostgreSQL schema."
    )
    parser.add_argument("--json", action="store_true", help="Output report in JSON format.")
    parser.add_argument(
        "--strict-indexes",
        action="store_true",
        default=False,
        help="Treat missing indexes as failure (default: False).",
    )
    parser.add_argument(
        "--module",
        action="append",
        default=[],
        help="Also parse numbered SQL files from a module migration directory, relative to repo root.",
    )
    return parser.parse_args()
|
|
|
|
|
|
def patch_database_url_for_local_dev() -> None:
    """Rewrite a docker-compose DATABASE_URL so the script works on the host.

    When the URL targets host "postgres" (the compose service name), point it
    at localhost and swap port 5432 for 5433 instead.  No-op otherwise.
    """
    url = settings.DATABASE_URL
    if "@postgres" not in url:
        return
    logger.info("Patching DATABASE_URL for local run")
    settings.DATABASE_URL = url.replace("@postgres", "@localhost").replace(":5432", ":5433")
|
|
|
|
|
|
def collect_numbered_sql(directory: Path) -> list[Path]:
    """Return the numbered ``*.sql`` files in *directory*, ordered by prefix.

    Files sort by their leading integer first, filename second; a file whose
    name somehow has no leading digits sorts as prefix 0.
    """

    def _sort_key(path: Path) -> tuple[int, str]:
        match = re.match(r"^(\d+)", path.name)
        return (int(match.group(1)) if match else 0, path.name)

    candidates = (p for p in directory.glob("*.sql") if NUMBERED_SQL_RE.match(p.name))
    return sorted(candidates, key=_sort_key)
|
|
|
|
|
|
def build_file_list(args: argparse.Namespace) -> list[Path]:
    """Gather migration files: repo-root migrations first, then each --module dir.

    Raises:
        FileNotFoundError: when a --module path does not exist or is not a directory.

    Duplicate paths are dropped while keeping first-seen order.
    """
    collected: list[Path] = list(collect_numbered_sql(ROOT / "migrations"))

    for module_dir in args.module:
        candidate = (ROOT / module_dir).resolve()
        if not (candidate.exists() and candidate.is_dir()):
            raise FileNotFoundError(f"Module migration directory not found: {module_dir}")
        collected.extend(collect_numbered_sql(candidate))

    # Order-preserving de-duplication.
    seen: set[Path] = set()
    deduped: list[Path] = []
    for path in collected:
        if path in seen:
            continue
        seen.add(path)
        deduped.append(path)
    return deduped
|
|
|
|
|
|
def strip_sql_comments(sql: str) -> str:
    """Remove ``/* ... */`` block comments and ``--`` line comments from *sql*.

    NOTE(review): comment markers inside string literals are stripped too —
    the regexes do not track quoting; acceptable for these migration files.
    """
    without_blocks = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
    return re.sub(r"--[^\n]*", "", without_blocks)
|
|
|
|
|
|
def extract_create_table_block(sql: str, table_name: str, start_pos: int) -> str:
    """Return create-table body between the first opening '(' and matching ')'.

    Scans forward from *start_pos*, tracking paren depth so nested parens
    (e.g. ``numeric(10,2)``) stay inside the body.  Returns "" when no opening
    paren is found or the parens never balance.  *table_name* is not consulted;
    it is kept for the caller's interface.
    """
    start = sql.find("(", start_pos)
    if start < 0:
        return ""

    depth = 0
    pos = start
    limit = len(sql)
    while pos < limit:
        ch = sql[pos]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return sql[start + 1:pos]
        pos += 1
    return ""
|
|
|
|
|
|
def parse_columns_from_create_block(block: str) -> set[str]:
    """Best-effort extraction of column names from a CREATE TABLE body.

    Heuristic: a line defines a column when its second token reduces to a
    known SQL type name; table-level constraint lines are skipped via
    SKIP_COLUMN_LINE_RE.  Lines with an unrecognised type are ignored, so
    their columns are never reported missing (safe but incomplete).
    """
    columns: set[str] = set()
    # Recognised base type names (lower-case, letters only).  Includes the
    # common PostgreSQL aliases -- int/int4/int8, float/float8, char, bytea,
    # timestamptz, interval -- which previously slipped through unrecognised
    # and caused their columns to be silently excluded from validation.
    known_types = {
        "serial", "bigserial", "smallint", "int", "integer", "bigint",
        "numeric", "decimal", "real", "float", "double",
        "varchar", "char", "character", "text",
        "boolean", "bool", "bytea",
        "date", "timestamp", "timestamptz", "time", "interval",
        "json", "jsonb", "uuid",
    }
    for raw_line in block.splitlines():
        line = raw_line.strip().rstrip(",")
        if not line:
            continue
        if SKIP_COLUMN_LINE_RE.match(line):
            continue

        # "name varchar(255) NOT NULL" -> ["name", "varchar", "255)", ...]
        tokens = line.replace("(", " ").split()
        if len(tokens) < 2:
            continue

        # Reduce the type token to letters only so "varchar(255)", "int4",
        # "float8" and "text[]" compare as "varchar", "int", "float", "text".
        second = tokens[1].strip().lower()
        second_base = re.sub(r"[^a-z]", "", second)
        if second_base and second_base not in known_types:
            continue

        match = re.match(r"^\"?([A-Za-z_][A-Za-z0-9_]*)\"?\s+", line)
        if match:
            columns.add(match.group(1))
    return columns
|
|
|
|
|
|
def parse_expected_schema(files: list[Path]) -> tuple[dict[str, set[str]], set[str]]:
    """Build the expected schema from migration SQL files.

    Returns:
        (tables, indexes) where ``tables`` maps table name -> expected column
        names and ``indexes`` is the set of expected index names.
    """
    tables: dict[str, set[str]] = {}
    indexes: set[str] = set()

    for path in files:
        sql = strip_sql_comments(path.read_text(encoding="utf-8"))

        for hit in CREATE_TABLE_RE.finditer(sql):
            name = hit.group(1)
            # hit.end() - 1 points at the opening "(" matched by the regex.
            body = extract_create_table_block(sql, name, hit.end() - 1)
            tables.setdefault(name, set()).update(parse_columns_from_create_block(body))

        for hit in ADD_COLUMN_RE.finditer(sql):
            tables.setdefault(hit.group(1), set()).add(hit.group(2))

        for hit in CREATE_INDEX_RE.finditer(sql):
            indexes.add(hit.group(1))

    return tables, indexes
|
|
|
|
|
|
def query_actual_schema() -> tuple[set[str], dict[str, set[str]], set[str]]:
    """Fetch the live database's 'public' schema: tables, columns, indexes.

    Returns:
        (tables, columns, indexes) where ``columns`` maps table name -> set of
        column names.  Connects via settings.DATABASE_URL; connection is
        always closed.
    """
    conn = psycopg2.connect(settings.DATABASE_URL)
    conn.autocommit = True

    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(
                """
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public'
                  AND table_type = 'BASE TABLE'
                ORDER BY table_name
                """
            )
            tables = {row["table_name"] for row in cur.fetchall()}

            cur.execute(
                """
                SELECT table_name, column_name
                FROM information_schema.columns
                WHERE table_schema = 'public'
                ORDER BY table_name, ordinal_position
                """
            )
            columns: dict[str, set[str]] = {}
            for row in cur.fetchall():
                columns.setdefault(row["table_name"], set()).add(row["column_name"])

            cur.execute(
                """
                SELECT indexname
                FROM pg_indexes
                WHERE schemaname = 'public'
                ORDER BY indexname
                """
            )
            indexes = {row["indexname"] for row in cur.fetchall()}

        return tables, columns, indexes
    finally:
        conn.close()
|
|
|
|
|
|
def compare_expected_vs_actual(
    expected_tables: dict[str, set[str]],
    expected_indexes: set[str],
    actual_tables: set[str],
    actual_columns: dict[str, set[str]],
    actual_indexes: set[str],
) -> dict[str, list[str]]:
    """Diff the expected schema against the live one, reporting only absences.

    Returns a dict with sorted lists under "missing_tables",
    "missing_columns" ("table.column" strings) and "missing_indexes".
    Columns of a missing table are not re-reported individually.
    """
    missing_tables = [t for t in sorted(expected_tables) if t not in actual_tables]

    missing_columns: list[str] = []
    for table in sorted(expected_tables):
        if table not in actual_tables:
            continue  # whole table already reported above
        present = actual_columns.get(table, set())
        missing_columns.extend(
            f"{table}.{col}"
            for col in sorted(expected_tables.get(table, set()))
            if col not in present
        )

    missing_indexes = [i for i in sorted(expected_indexes) if i not in actual_indexes]

    return {
        "missing_tables": missing_tables,
        "missing_columns": missing_columns,
        "missing_indexes": missing_indexes,
    }
|
|
|
|
|
|
def print_report(report: dict[str, list[str]]) -> None:
    """Log the mismatch report: tables/columns at ERROR, indexes at WARNING."""
    if not any(report[key] for key in ("missing_tables", "missing_columns", "missing_indexes")):
        logger.info("Schema validation OK: no mismatches found.")
        return

    if report["missing_tables"]:
        logger.error("Missing tables:")
        for name in report["missing_tables"]:
            logger.error("- %s", name)

    if report["missing_columns"]:
        logger.error("Missing columns:")
        for name in report["missing_columns"]:
            logger.error("- %s", name)

    if report["missing_indexes"]:
        logger.warning("Missing indexes:")
        for name in report["missing_indexes"]:
            logger.warning("- %s", name)
|
|
|
|
|
|
def determine_exit_code(report: dict[str, list[str]], strict_indexes: bool) -> int:
    """Map a mismatch report to an exit status.

    Missing tables or columns always fail (1); missing indexes fail only in
    strict mode; otherwise success (0).
    """
    if report["missing_tables"] or report["missing_columns"]:
        return 1
    if strict_indexes and report["missing_indexes"]:
        return 1
    return 0
|
|
|
|
|
|
def main() -> int:
    """Entry point: parse args, diff migrations vs. live schema, report.

    Exit codes: 0 = schemas match, 1 = mismatch, 2 = unexpected failure.
    Output goes to JSON on stdout with --json, otherwise to the logger.
    """
    args = parse_args()

    try:
        patch_database_url_for_local_dev()
        files = build_file_list(args)
        expected_tables, expected_indexes = parse_expected_schema(files)
        actual_tables, actual_columns, actual_indexes = query_actual_schema()

        report = compare_expected_vs_actual(
            expected_tables,
            expected_indexes,
            actual_tables,
            actual_columns,
            actual_indexes,
        )

        exit_code = determine_exit_code(report, args.strict_indexes)

        if args.json:
            payload = {
                "status": "ok" if exit_code == 0 else "mismatch",
                "strict_indexes": args.strict_indexes,
                "parsed_files": [str(path.relative_to(ROOT)) for path in files],
                "missing_tables": report["missing_tables"],
                "missing_columns": report["missing_columns"],
                "missing_indexes": report["missing_indexes"],
            }
            print(json.dumps(payload, ensure_ascii=True))
        else:
            logger.info("Parsed migration files: %s", len(files))
            print_report(report)

        return exit_code
    except Exception as exc:
        # Keep error reporting in the same output mode the caller asked for.
        if args.json:
            print(json.dumps({"status": "error", "message": str(exc)}, ensure_ascii=True))
        else:
            logger.error("Validation failed: %s", exc)
        return 2
|
|
|
|
|
|
# Script entry point: propagate main()'s status (0 ok, 1 mismatch, 2 error).
if __name__ == "__main__":
    raise SystemExit(main())
|