Graceful-shutdown patch, restored to one-change-per-line unified diff form.
(Text before the first "diff --git" header is ignored when the patch is applied.)

What the patch does
  controller.py  close_all() drops its `channel` argument and closes every
                 socket in socket_map; a failing close() is printed, not raised.
  main.py        adds a module-level `running` flag that gates the polling
                 loops in generate_stream, generate and the /task-info
                 websocket; shutdown_event sleeps 2 s after queue.stop() and
                 then sends SIGTERM to `mainpid`; the websocket loop also stops
                 when either side is no longer CONNECTED and then closes the
                 socket if it is still open.
  task.py        appends to / pops from event_list under event_lock; the
                 queue-position reset appears to move out of the
                 `with self.lock` body; the GC loop scans a copy of task_map
                 and sleeps 1 s instead of 0.1 s per pass.

Review changes made to the added (+) lines
  * close_all iterates over snapshots (list(...)) because it awaits inside the
    loops and the containers can be mutated by other coroutines meanwhile;
    `.values()` replaces `.items()` with an unused key; closedCount is renamed
    closed_count.
  * The GC loop removes keys with pop(key, None): the lock is released between
    collecting keys and deleting them, so a key may already be gone.
  * The mis-encoded separator in the "Task...finished" context line is
    restored to the full-width colon.

Open questions - confirm before merging
  * `mainpid`, `loop` and `self.event_lock` are not defined in any visible
    hunk; nothing visible ever sets `running = False`.
  * close_all() no longer accepts `channel`; callers passing it must change.
  * The queue-position reset now iterates queued_task_list without holding
    self.lock - verify that this is intended.
  * Indentation and blank context lines were inferred from the hunk line
    counts; check them against the target tree. The `index` hashes are those
    of the original patch and do not describe the edited post-image.

diff --git a/controller.py b/controller.py
index 29d4ae3..d0aba82 100644
--- a/controller.py
+++ b/controller.py
@@ -27,17 +27,16 @@ class TaskDataSocketController:
         else:
             return False
 
-    async def close_all(self, channel: str) -> int:
-        if channel in self.socket_map:
-            sended = 0
-            socket_list = self.socket_map[channel]
-            for socket in socket_list:
-                await socket.close()
-                sended += 1
-
-            return sended
-        else:
-            return 0
+    async def close_all(self) -> int:
+        closed_count = 0
+        for sockets in list(self.socket_map.values()):
+            for socket in list(sockets):
+                try:
+                    await socket.close()
+                except Exception as err:
+                    print(err)
+                closed_count += 1
+        return closed_count
 
     async def emit(self, channel: str, data = {}) -> int:
         if channel in self.socket_map:
diff --git a/main.py b/main.py
index 2f121c8..ac27f70 100644
--- a/main.py
+++ b/main.py
@@ -2,6 +2,7 @@ from asyncore import socket_map
 from concurrent.futures import thread
 import os
 import re
+import signal
 from types import LambdaType
 from async_timeout import asyncio
 from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
@@ -12,6 +13,7 @@ from fastapi.exceptions import HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from starlette.responses import FileResponse
+from starlette.websockets import WebSocketState
 from controller import TaskDataSocketController
 from hydra_node.config import init_config_model
 from hydra_node.models import EmbedderModel, torch_gc
@@ -32,6 +34,7 @@ from PIL.PngImagePlugin import PngInfo
 import json
 
 genlock = threading.Lock()
+running = True
 
 MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
 TOKEN = os.getenv("TOKEN", None)
@@ -97,6 +100,8 @@ def startup_event():
 def shutdown_event():
     logger.info("FastAPI Shutdown, exiting")
     queue.stop()
+    time.sleep(2)
+    os.kill(mainpid, signal.SIGTERM)
 
 class Masker(TypedDict):
     seed: int
@@ -274,7 +279,7 @@ async def generate_stream(request: GenerationRequest, authorized: bool = Depends
             raise err
 
         images_encoded = []
-        while True:
+        while running:
             task_data = queue.get_task_data(task_id)
             if not task_data:
                 raise Exception("Task not found")
@@ -452,7 +457,7 @@ async def generate(request: GenerationRequest, authorized: bool = Depends(verify
             raise err
 
         images_encoded = []
-        while True:
+        while running:
             task_data = queue.get_task_data(task_id)
             if not task_data:
                 raise Exception("Task not found")
@@ -614,7 +619,6 @@ async def get_task_info(request: TaskIdRequest):
 
 @app.websocket('/task-info')
 async def websocket_endpoint(websocket: WebSocket, task_id: str):
-    logger.info("websocket request task_id: %s" % task_id)
     await websocket.accept()
 
     try:
@@ -634,12 +638,15 @@ async def websocket_endpoint(websocket: WebSocket, task_id: str):
             return
 
         task_ws_controller.add(task_id, websocket)
-        while True:
+        while running and websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
             data = await websocket.receive_text()
             if data == "ping":
                 await websocket.send_json({
                     "pong": "pong"
                 })
+
+        if websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
+            await websocket.close()
     except WebSocketDisconnect:
         task_ws_controller.remove(task_id, websocket)
 
diff --git a/task.py b/task.py
index f4981c1..fc2ad12 100644
--- a/task.py
+++ b/task.py
@@ -150,9 +150,10 @@ class TaskQueue:
         self.event_list = []
         self.event_trigger.set()
 
-    # in thread
+    # in task thread
     def _on_task_update(self, task: TaskData, updated_value: Dict):
-        self.event_list.append(task.to_dict())
+        with self.event_lock:
+            self.event_list.append(task.to_dict())
         self.event_trigger.set()
 
     def _start(self):
@@ -207,14 +208,15 @@ class TaskQueue:
 
                 with self.lock:
                     self.queued_task_list.pop(0)
-                    # reset queue position
-                    for i in range(0, len(self.queued_task_list)):
-                        self.queued_task_list[i].update({
-                            "position": i
-                        })
-
-                    if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
-                        self.is_busy = False
+
+                # reset queue position
+                for i in range(0, len(self.queued_task_list)):
+                    self.queued_task_list[i].update({
+                        "position": i
+                    })
+
+                if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
+                    self.is_busy = False
 
                 self.logger.info("Task:%s finished, queue size: %d" % (current_task.task_id, len(self.queued_task_list)))
 
@@ -226,17 +228,20 @@ class TaskQueue:
         while self.core_running:
             current_time = time.time()
             with self.lock:
-                will_remove_keys = []
-                for (key, task_data) in self.task_map.items():
-                    if task_data.status == "finished" or task_data.status == "error":
-                        if task_data.updated_time + 120 < current_time: # delete task which finished 2 minutes ago
-                            will_remove_keys.append(key)
-
-                for key in will_remove_keys:
-                    del(self.task_map[key])
+                task_map = self.task_map.copy()
+
+            will_remove_keys = []
+            for (key, task_data) in task_map.items():
+                if task_data.status == "finished" or task_data.status == "error":
+                    if task_data.updated_time + 120 < current_time: # delete task which finished 2 minutes ago
+                        will_remove_keys.append(key)
+
+            if will_remove_keys:
+                with self.lock:
+                    for key in will_remove_keys:
+                        self.task_map.pop(key, None)
 
-            time.sleep(0.1)
-        self.logger.info("TaskQueue GC Stopped.")
+            time.sleep(1)
 
     # Event loop
     def _event_loop(self):
@@ -250,7 +255,9 @@ class TaskQueue:
                     self.on_update(event)
                 except Exception as err:
                     print(err)
-                self.event_list.pop(0)
+
+                with self.event_lock:
+                    self.event_list.pop(0)
 
             self.event_trigger.clear()
-        self.logger.info("TaskQueue Event Stopped.")
+        loop.stop()