#!/usr/bin/env python3
"""Signalling relay plus periodic label-aggregation broadcaster.

Clients connect over WebSocket and exchange JSON objects:

* ``{"type": "register", "id": ...}`` -- bind this connection to a peer id.
* ``{"type": "signal", "target": ...}`` -- forwarded verbatim to ``target``.
* ``{"type": "labels", "id": ..., "labels": [...], "freqs": ..., "transitions": ...}``
  -- relayed to every other client as ``broadcast_labels``; when ``id`` is the
  sending connection's own registered id, the labels are also stored for the
  periodic aggregate broadcast.

When ``broadcast`` is enabled, every second all clients receive the list of
every registered peer's latest labels; a peer whose labels have not been
updated for more than ``STALE_TIMEOUT`` seconds contributes an empty list.
"""

import asyncio
import json
import logging
import time

import websockets

# Feature / verbosity switches (edit in place; there is no CLI).
broadcast = False  # run broadcast_labels_periodically() alongside the server
warning = False    # log malformed / unroutable messages and failed sends
info = False       # log relays, disconnects and periodic payloads
debug = False      # log stale clears and broadcast counts

# Configured after the flags so that ``debug = True`` actually lowers the level;
# at INFO, the logging.debug() calls below would be discarded.
logging.basicConfig(level=logging.DEBUG if debug else logging.INFO)

# peer_id -> websocket currently registered under that id
CLIENTS: dict[str, websockets.WebSocketServerProtocol] = {}
# peer_id -> latest labels list
CLIENT_LABELS: dict[str, list[str]] = {}
# peer_id -> time.monotonic() timestamp of the last labels update
CLIENT_LABELS_TS: dict[str, float] = {}

# After this many seconds without an update, a peer's labels are cleared.
STALE_TIMEOUT = 5.0


def _owns(peer_id, ws) -> bool:
    """Return True if *ws* is the connection currently registered as *peer_id*."""
    return peer_id is not None and CLIENTS.get(peer_id) is ws


def _unregister(peer_id, ws) -> bool:
    """Remove all state for *peer_id*, but only if *ws* still owns that id.

    A peer may re-register the same id from a new connection; when the older
    connection closes later it must not wipe out the newer registration.
    Returns True if anything was removed.
    """
    if not _owns(peer_id, ws):
        return False
    del CLIENTS[peer_id]
    CLIENT_LABELS.pop(peer_id, None)
    CLIENT_LABELS_TS.pop(peer_id, None)
    return True


async def _send_all(sockets, payload: str) -> int:
    """Send *payload* to every socket in *sockets* concurrently.

    Failures (typically ConnectionClosed from a peer that is going away) are
    logged when ``warning`` is set but never propagate, so one dead recipient
    cannot tear down the connection of the client that triggered the send.
    Returns the number of successful sends.
    """
    results = await asyncio.gather(
        *(ws.send(payload) for ws in sockets), return_exceptions=True
    )
    failures = [r for r in results if isinstance(r, BaseException)]
    if warning:
        for exc in failures:
            logging.warning("Send failed: %r", exc)
    return len(results) - len(failures)


async def _forward_signal(target, raw) -> None:
    """Forward a raw ``signal`` message to the peer registered as *target*."""
    target_ws = CLIENTS.get(target)
    if target_ws is None:
        if warning:
            logging.warning("Target %s not found", target)
        return
    try:
        await target_ws.send(raw)
    except websockets.ConnectionClosed:
        # The target is disconnecting; its own handler performs the cleanup.
        # Swallowing here keeps the *sender's* connection alive.
        if warning:
            logging.warning("Target %s disconnected", target)


async def _relay_labels(ws, peer_id, msg: dict, raw) -> None:
    """Store a ``labels`` update (if it is our own) and relay it to all others.

    *ws* / *peer_id* identify the sending connection; the message's own
    ``id`` field is what appears as ``from`` and is excluded from the relay.
    """
    labels = msg.get("labels")
    sender_id = msg.get("id")
    if info:
        logging.info("Got labels from %s: %s", sender_id, labels)
    if not sender_id or not isinstance(labels, list):
        if warning:
            logging.warning("Malformed labels payload from %s: %s", sender_id, raw)
        return

    # Keep the latest labels for the periodic aggregate broadcast. Only a
    # connection's own, still-owned registration is updated, so one client
    # cannot overwrite another's entry and entries are always removed again
    # on disconnect.
    if sender_id == peer_id and _owns(peer_id, ws):
        CLIENT_LABELS[peer_id] = labels
        CLIENT_LABELS_TS[peer_id] = time.monotonic()

    payload = json.dumps({
        "type": "broadcast_labels",
        "from": sender_id,
        "labels": labels,
        "freqs": msg.get("freqs"),
        "transitions": msg.get("transitions"),
    })
    # Snapshot the recipients: other handlers may (un)register while the
    # sends are being awaited, which would break a live dict iteration.
    recipients = [
        other_ws for other_id, other_ws in CLIENTS.items() if other_id != sender_id
    ]
    count = await _send_all(recipients, payload)
    if info:
        logging.info(
            "Re-broadcasted labels to %d other clients (excluding %s)", count, sender_id
        )


async def handler(ws: websockets.WebSocketServerProtocol):
    """Serve one client connection until it closes.

    ``peer_id`` is the id this connection registered under (None until a
    valid ``register`` arrives). Only ``register`` changes it, so the
    ``finally`` cleanup always targets this connection's own entry.
    """
    peer_id = None
    try:
        async for raw in ws:
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                if warning:
                    logging.warning("Bad JSON from %s: %s", peer_id, raw)
                continue
            if not isinstance(msg, dict):
                # Valid JSON but not an object (e.g. a list): msg.get would raise.
                if warning:
                    logging.warning("Non-object message from %s: %s", peer_id, raw)
                continue

            mtype = msg.get("type")
            if mtype == "register":
                new_id = msg.get("id")
                if not new_id:
                    # The finally clause still cleans up any earlier registration.
                    await ws.close(code=4000, reason="Missing id")
                    return
                if peer_id is not None and peer_id != new_id:
                    # Re-registering under a different id: drop the old entry.
                    _unregister(peer_id, ws)
                CLIENTS[new_id] = ws
                peer_id = new_id
                CLIENT_LABELS[peer_id] = []
                CLIENT_LABELS_TS[peer_id] = time.monotonic()
                logging.info("Registered %s (%d clients)", peer_id, len(CLIENTS))
            elif mtype == "signal":
                await _forward_signal(msg.get("target"), raw)
            elif mtype == "labels":
                await _relay_labels(ws, peer_id, msg, raw)
            elif warning:
                logging.warning("Unknown message type from %s: %s", peer_id, mtype)
    except websockets.ConnectionClosed:
        pass
    finally:
        if peer_id:
            _unregister(peer_id, ws)
            if info:
                logging.info("Disconnected %s (%d clients left)", peer_id, len(CLIENTS))


async def broadcast_labels_periodically():
    """Once a second, send every client the latest labels of all peers.

    Labels not refreshed within ``STALE_TIMEOUT`` seconds are reset to ``[]``
    first. The timestamp is kept, so a stale peer is re-cleared each tick
    until it sends a new ``labels`` update or disconnects.
    """
    while True:
        now = time.monotonic()
        # 1) clear stale labels
        for pid, last_ts in list(CLIENT_LABELS_TS.items()):
            if now - last_ts > STALE_TIMEOUT:
                CLIENT_LABELS[pid] = []
                if debug:
                    logging.debug("Cleared labels for %s due to timeout", pid)

        # 2) broadcast the nested lists to everyone
        if CLIENTS:
            payload = json.dumps({
                "type": "broadcast_labels",
                "data": list(CLIENT_LABELS.values()),
            })
            recipients = list(CLIENTS.values())
            await _send_all(recipients, payload)
            if debug:
                logging.debug("Broadcasted labels to %d clients", len(recipients))
            if info:
                logging.info("Broadcasted labels >>>>>> %s", payload)

        await asyncio.sleep(1)


async def main(host: str = "0.0.0.0", port: int = 8080):
    """Run the WebSocket server (and optional broadcaster) forever.

    The broadcaster task is cancelled when the server shuts down, e.g. when
    ``asyncio.run`` cancels this coroutine on Ctrl-C.
    """
    broadcaster = (
        asyncio.create_task(broadcast_labels_periodically()) if broadcast else None
    )
    try:
        async with websockets.serve(handler, host, port):
            if info:
                logging.info("Signalling server listening on :%d", port)
            await asyncio.Future()  # run forever
    finally:
        if broadcaster is not None:
            broadcaster.cancel()


if __name__ == "__main__":
    asyncio.run(main())