bmc_hub/app/modules/telefoni/backend/websocket.py

68 lines
2.4 KiB
Python
Raw Permalink Normal View History

import asyncio
import json
import logging
from typing import Dict, Set
from fastapi import WebSocket
# Module-level logger; TelefoniConnectionManager uses it for connect/disconnect/send diagnostics.
logger = logging.getLogger(__name__)
class TelefoniConnectionManager:
    """Track open telephony WebSockets per user and push JSON events to them.

    Connections are grouped by ``user_id``; a user may have several sockets
    registered at once. All reads/writes of the registry happen under a single
    ``asyncio.Lock`` so coroutines interleaving at ``await`` points see a
    consistent mapping. Socket I/O (``accept``/``send_text``) and logging are
    done outside the lock so a slow client cannot stall registry updates.

    Invariant: a ``user_id`` key exists in the registry only while it has at
    least one socket, so ``active_users()`` never reports empty users.
    """

    def __init__(self) -> None:
        # Guards _connections. asyncio.Lock is NOT reentrant: code holding it
        # must not await another method of this class that acquires it.
        # NOTE(review): on Python < 3.10 asyncio.Lock binds to the event loop
        # current at construction; if this is instantiated at import time,
        # confirm the runtime is 3.10+.
        self._lock = asyncio.Lock()
        # user_id -> set of registered sockets (never an empty set; see class doc).
        self._connections: Dict[int, Set[WebSocket]] = {}

    async def connect(self, user_id: int, websocket: WebSocket) -> None:
        """Accept the WebSocket handshake, then register *websocket* under *user_id*."""
        await websocket.accept()
        async with self._lock:
            user_sockets = self._connections.setdefault(user_id, set())
            user_sockets.add(websocket)
            count = len(user_sockets)
        logger.info("📞 WS manager: user_id=%s now has %s connection(s)", user_id, count)

    async def disconnect(self, user_id: int, websocket: WebSocket) -> None:
        """Unregister *websocket* from *user_id*.

        No-op (and no log line) when the user has no registered sockets. This
        method does not close the socket itself — presumably the caller has
        already done so or the client went away; verify at call sites.
        """
        async with self._lock:
            user_sockets = self._connections.get(user_id)
            if not user_sockets:
                return
            user_sockets.discard(websocket)
            remaining = len(user_sockets)
            if remaining == 0:
                # Maintain the no-empty-sets invariant.
                del self._connections[user_id]
        if remaining == 0:
            logger.info("📞 WS manager: user_id=%s disconnected (0 connections)", user_id)
        else:
            logger.info("📞 WS manager: user_id=%s now has %s connection(s)", user_id, remaining)

    async def active_users(self) -> list[int]:
        """Return the ids of users with at least one registered socket, sorted ascending."""
        async with self._lock:
            return sorted(self._connections.keys())

    async def connection_count_for_user(self, user_id: int) -> int:
        """Return how many sockets are registered for *user_id* (0 if none)."""
        async with self._lock:
            return len(self._connections.get(user_id, set()))

    async def send_to_user(self, user_id: int, event: str, payload: dict) -> None:
        """Send ``{"event": event, "data": payload}`` as JSON text to every socket of *user_id*.

        Delivery is best-effort: failures are logged, never raised. A socket
        whose ``send_text`` raises is treated as dead and removed from the
        registry; if that leaves the user with no sockets, the user entry is
        dropped as well (same invariant ``disconnect`` keeps).
        """
        # default=str: any value json cannot serialize natively is passed through str().
        message = json.dumps({"event": event, "data": payload}, default=str)

        # Snapshot the targets under the lock, then send without holding it.
        async with self._lock:
            targets = list(self._connections.get(user_id, set()))

        if not targets:
            # Must run outside the lock: active_users() acquires it and
            # asyncio.Lock is not reentrant.
            active = await self.active_users()
            logger.info("⚠️ WS send skipped: no active connections for user_id=%s (active users=%s)", user_id, active)
            return

        dead: list[WebSocket] = []
        for ws in targets:
            try:
                await ws.send_text(message)
            except Exception as e:
                # Best-effort boundary: any send failure marks the socket dead.
                logger.warning("⚠️ WS send failed for user %s: %s", user_id, e)
                dead.append(ws)

        if not dead:
            return

        async with self._lock:
            user_sockets = self._connections.get(user_id)
            if user_sockets is None:
                # The user fully disconnected while we were sending.
                return
            user_sockets.difference_update(dead)
            user_emptied = not user_sockets
            if user_emptied:
                # Previously the empty set was left behind, so active_users()
                # kept reporting a user with zero connections.
                del self._connections[user_id]
        if user_emptied:
            logger.info("📞 WS manager: user_id=%s disconnected (0 connections)", user_id)
# Module-level shared instance of the connection manager, created at import time.
# NOTE(review): presumably other modules import this singleton for WS routes and
# event pushes — verify against callers.
manager = TelefoniConnectionManager()