import logging
from typing import Any, ClassVar, Dict, List

from fastapi import WebSocket

logger = logging.getLogger(__name__)

# Sentinel meaning "no payload supplied" for emit(). It lets emit() default to a
# fresh ``{}`` per call (no shared mutable default) while an explicit ``None``
# payload is still forwarded unchanged.
_NO_DATA: Any = object()


class TaskDataSocketController:
    """Registry of WebSocket connections grouped by channel name.

    Sockets are registered per channel with ``add``, unregistered with
    ``remove``, and can be broadcast to (``emit``) or closed in bulk
    (``close_all``).
    """

    # NOTE(review): this dict lives on the class, so every instance of the
    # controller shares one registry. All methods mutate it in place; never
    # rebind ``self.socket_map`` or an instance attribute would shadow it.
    socket_map: ClassVar[Dict[str, List[WebSocket]]] = {}

    def add(self, channel: str, socket: WebSocket):
        """Register ``socket`` as a subscriber of ``channel``."""
        self.socket_map.setdefault(channel, []).append(socket)

    def remove(self, channel: str, socket: WebSocket) -> bool:
        """Unregister every occurrence of ``socket`` from ``channel``.

        The channel entry is deleted once it has no sockets left.

        Returns:
            True if ``channel`` was registered (even if ``socket`` was not in
            it), False if the channel is unknown.
        """
        sockets = self.socket_map.get(channel)
        if sockets is None:
            return False
        remaining = [item for item in sockets if item != socket]
        if remaining:
            # Store a new list instead of mutating in place, so a concurrent
            # emit() iterating the old list is unaffected.
            self.socket_map[channel] = remaining
        else:
            del self.socket_map[channel]
        return True

    async def close_all(self) -> int:
        """Close every registered socket and clear the registry.

        Failures to close an individual socket are logged and skipped so the
        remaining sockets are still closed.

        Returns:
            The number of sockets closed without error.
        """
        # Snapshot and clear before any await. Closing yields to the event
        # loop, where disconnect handlers may call remove()/add(). Iterating
        # the live dict would then raise "dictionary changed size during
        # iteration". Clearing also stops emit() from targeting closed sockets.
        sockets = [
            socket
            for channel_sockets in self.socket_map.values()
            for socket in channel_sockets
        ]
        self.socket_map.clear()

        closed = 0
        for socket in sockets:
            try:
                await socket.close()
            except Exception:
                logger.exception("Failed to close websocket")
                continue
            closed += 1
        return closed

    async def emit(self, channel: str, data: Any = _NO_DATA) -> int:
        """Send ``data`` as JSON to every socket subscribed to ``channel``.

        Args:
            channel: Channel whose subscribers receive the message.
            data: JSON-serializable payload; defaults to an empty object ``{}``.

        Returns:
            The number of sockets the payload was sent to successfully
            (0 if the channel has no subscribers).

        A send failure on one socket (e.g. a client that already disconnected)
        is logged and does not prevent delivery to the other subscribers.
        """
        if data is _NO_DATA:
            data = {}
        sockets = self.socket_map.get(channel)
        if not sockets:
            return 0

        sent = 0
        # Iterate a snapshot: add() may append to this list while we await.
        for socket in list(sockets):
            try:
                await socket.send_json(data)
            except Exception:
                logger.exception("Failed to send data on channel %r", channel)
                continue
            sent += 1
        return sent