From f39f24545bc30f8b451c0d5097eac37206190695 Mon Sep 17 00:00:00 2001
From: Lex Lim
Date: Sat, 22 Oct 2022 09:09:28 +0800
Subject: [PATCH] Add websocket support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 controller.py |  52 +++++++++
 main.py       | 283 ++++++++++++++++----------------------------------
 task.py       | 256 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 395 insertions(+), 196 deletions(-)
 create mode 100644 controller.py
 create mode 100644 task.py

diff --git a/controller.py b/controller.py
new file mode 100644
index 0000000..29d4ae3
--- /dev/null
+++ b/controller.py
@@ -0,0 +1,52 @@
+from typing import Dict, List, Optional
+
+from fastapi import WebSocket
+
+
+class TaskDataSocketController:
+    socket_map: Dict[str, List[WebSocket]] = {}  # channel -> connected sockets (class-level, shared across instances)
+
+    def add(self, channel: str, socket: WebSocket):
+        if channel not in self.socket_map:
+            self.socket_map[channel] = []
+        self.socket_map[channel].append(socket)
+
+    def remove(self, channel: str, socket: WebSocket) -> bool:
+        if channel in self.socket_map:
+            socket_list = self.socket_map[channel]
+            new_socket_list = []
+            for socket_item in socket_list:
+                if socket_item != socket:
+                    new_socket_list.append(socket_item)
+
+            if len(new_socket_list) == 0:
+                del self.socket_map[channel]
+            else:
+                self.socket_map[channel] = new_socket_list
+            return True
+        else:
+            return False
+
+    async def close_all(self, channel: str) -> int:
+        if channel in self.socket_map:
+            closed = 0
+            socket_list = self.socket_map[channel]
+            for socket in socket_list:
+                await socket.close()
+                closed += 1
+
+            return closed
+        else:
+            return 0
+
+    async def emit(self, channel: str, data: Optional[Dict] = None) -> int:
+        if channel in self.socket_map:
+            sent = 0
+            socket_list = self.socket_map[channel]
+            for socket in socket_list:
+                await socket.send_json(data or {})
+                sent += 1
+
+            return sent
+        else:
+            return 0
diff --git a/main.py b/main.py
index 3a93c17..2f121c8 100644
--- a/main.py
+++ b/main.py
@@ -1,13 +1,9 @@
 from concurrent.futures import thread
 import os
-import random
 import re
-import string
-import sys
 from types import LambdaType
-from aiohttp import request
 from async_timeout import asyncio
-from fastapi import FastAPI, Request, Depends
+from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
 from numpy import number
 from pydantic import BaseModel
 from fastapi.responses import HTMLResponse, PlainTextResponse, Response
@@ -15,18 +11,18 @@ from fastapi.exceptions import HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from starlette.responses import FileResponse
+from controller import TaskDataSocketController
 from hydra_node.config import init_config_model
 from hydra_node.models import EmbedderModel, torch_gc
-from typing import Optional, List, Any
+from typing import Dict, List, Union
 from typing_extensions import TypedDict
 import socket
 from hydra_node.sanitize import sanitize_input
+from task import TaskData, TaskQueue, TaskQueueFullException
 import uvicorn
-from typing import Union, Dict
 import time
 import gc
 import io
-import signal
 import base64
 import traceback
 import threading
@@ -56,169 +52,21 @@ mainpid = config.mainpid
 hostname = socket.gethostname()
 sent_first_message = False
 
-class TaskQueueFullException(Exception):
-    pass
-
-class TaskData(TypedDict):
-    task_id: str
-    position: str
-    status: str
-    updated_time: float
-    callback: LambdaType
-    current_step: int
-    total_steps: int
-    payload: Any
-    response: Any
-
-class TaskQueue:
-    max_queue_size = 10
-    recovery_queue_size = 6
-    is_busy = False
-
-    loop_thread = None
-    gc_loop_thread = None
-    running = False
-    gc_running = False
-
-    lock = threading.Lock()
-
-    queued_task_list: List[TaskData] = []
-    task_map: Dict[str, TaskData] = {}
-
-    def __init__(self, max_queue_size = 10, recovery_queue_size = 6) -> None:
-        self.max_queue_size = max_queue_size
-        self.max_queue_size = recovery_queue_size
-
-        self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
-        self.gc_running = True
-        self.gc_loop_thread.start()
-
-        logger.info("Task queue created")
-
-    def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
-        if self.is_busy:
-            raise TaskQueueFullException("Task queue is full")
-        if len(self.queued_task_list) >= self.max_queue_size: # mark busy
-            self.is_busy = True
-            raise TaskQueueFullException("Task queue is full")
-
-        task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
-        task = TaskData(
-            task_id=task_id,
-            position=0,
-            status="queued",
-            updated_time=time.time(),
-            callback=callback,
-            current_step=0,
-            total_steps=0,
-            payload=payload
-        )
-        with self.lock:
-            self.queued_task_list.append(task)
-            task["position"] = len(self.queued_task_list) - 1
-            logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
-        # create index
-        with self.lock:
-            self.task_map[task_id] = task
-
-        self._start()
-
-        return task_id
-
-    def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
-        if task_id in self.task_map:
-            return self.task_map[task_id]
-        else:
-            return False
-
-    def delete_task_data(self, task_id: str) -> bool:
-        if task_id in self.task_map:
-            with self.lock:
-                del(self.task_map[task_id])
-        else:
-            return False
-
-    def stop(self):
-        self.running = False
-        self.gc_running = False
-        self.queued_task_list = []
-
-    def _start(self):
-        with self.lock:
-            if self.running == False:
-                self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
-                self.running = True
-                self.loop_thread.start()
-
-    def _loop(self):
-        while self.running and len(self.queued_task_list) > 0:
-            current_task = self.queued_task_list[0]
-            logger.info("Start task:%s." % current_task["task_id"])
-            try:
-                current_task["status"] = "running"
-                current_task["updated_time"] = time.time()
-
-                # run task
-                res = current_task["callback"](current_task)
-
-                # call gc
-                gc.collect()
-                torch_gc()
-
-                current_task["status"] = "finished"
-                current_task["updated_time"] = time.time()
-                current_task["response"] = res
-                current_task["current_step"] = current_task["total_steps"]
-            except Exception as e:
-                current_task["status"] = "error"
-                current_task["updated_time"] = time.time()
-                current_task["response"] = e
-                gc.collect()
-                e_s = str(e)
-                if "CUDA out of memory" in e_s or \
-                        "an illegal memory access" in e_s or "CUDA" in e_s:
-                    traceback.print_exc()
-                    logger.error(str(e))
-                    logger.error("GPU error in task.")
-                    torch_gc()
-                    current_task["response"] = Exception("GPU error in task.")
-                    # os.kill(mainpid, signal.SIGTERM)
-
-            with self.lock:
-                self.queued_task_list.pop(0)
-                # reset queue position
-                for i in range(0, len(self.queued_task_list)):
-                    self.queued_task_list[i]["position"] = i
-
-            if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
-                self.is_busy = False
-
-            logger.info("Task:%s finished, queue size: %d" % (current_task["task_id"], len(self.queued_task_list)))
-
-        with self.lock:
-            self.running = False
-
-    # Task to remove finished task
-    def _gc_loop(self):
-        while self.gc_running:
-            current_time = time.time()
-            with self.lock:
-                will_remove_keys = []
-                for (key, task_data) in self.task_map.items():
-                    if task_data["status"] == "finished" or task_data["status"] == "error":
-                        if task_data["updated_time"] + 120 < current_time: # delete task which finished 2 minutes ago
-                            will_remove_keys.append(key)
-
-                for key in will_remove_keys:
-                    del(self.task_map[key])
-
-            time.sleep(1)
-
 queue = TaskQueue(
+    logger=logger,
     max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
     recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6))
 )
+task_ws_controller = TaskDataSocketController()
+
+def on_task_update(task_info: Dict):  # runs on the TaskQueueEvent thread, which owns a dedicated asyncio loop (see task.py)
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(task_ws_controller.emit(task_info["task_id"], task_info))
+
+queue.on_update = on_task_update
+
+
 def verify_token(req: Request):
     if TOKEN:
         valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer "+TOKEN
@@ -335,13 +183,17 @@ def saveimage(image, request):
     except Exception as e:
         print("failed to save image:", e)
 
-def _generate_stream(request: GenerationRequest, task_info: TaskData = TaskData()):
+def _generate_stream(request: GenerationRequest, task_info: TaskData):
     try:
-        task_info["total_steps"] = request.steps + 1
+        task_info.update({
+            "total_steps": request.steps + 1
+        })
 
         def _on_step(step_num, total_steps):
-            task_info["total_steps"] = total_steps + 1
-            task_info["current_step"] = step_num
+            task_info.update({
+                "total_steps": total_steps + 1,
+                "current_step": step_num
+            })
 
         if request.advanced:
             if request.n_samples > 1:
@@ -378,7 +230,9 @@
         image = base64.b64encode(image).decode("ascii")
         images_encoded.append(image)
 
-        task_info["current_step"] += 1
+        task_info.update({
+            "current_step": task_info.current_step + 1
+        })
 
     del images
 
@@ -423,11 +277,11 @@ async def generate_stream(request: GenerationRequest, authorized: bool = Depends
         task_data = queue.get_task_data(task_id)
         if not task_data:
            raise Exception("Task not found")
-        if task_data["status"] == "finished":
-            images_encoded = task_data["response"]
+        if task_data.status == "finished":
+            images_encoded = task_data.response
             break
-        elif task_data["status"] == "error":
-            return {"error": str(task_data["response"])}
+        elif task_data.status == "error":
+            return {"error": str(task_data.response)}
         await asyncio.sleep(0.1)
 
     process_time = time.perf_counter() - t
@@ -493,8 +347,8 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depe
     task_data = queue.get_task_data(task_id)
     if not task_data:
         return ErrorOutput(error="Task not found.")
-    if task_data["status"] == "finished":
-        images_encoded = task_data["response"]
+    if task_data.status == "finished":
+        images_encoded = task_data.response
 
         data = ""
         ptr = 0
@@ -503,8 +357,8 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depe
             data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
 
         return Response(content=data, media_type="text/event-stream")
-    elif task_data["status"] == "error":
-        raise task_data["response"]
+    elif task_data.status == "error":
+        raise task_data.response
     else:
         return ErrorOutput(error="Task is not finished.")
 
@@ -522,11 +376,15 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depe
 def _generate(request: GenerationRequest, task_info: TaskData):
     try:
-        task_info["total_steps"] = request.steps + 1
+        task_info.update({
+            "total_steps": request.steps + 1
+        })
 
         def _on_step(step_num, total_steps):
-            task_info["total_steps"] = total_steps + 1
-            task_info["current_step"] = step_num
+            task_info.update({
+                "total_steps": total_steps + 1,
+                "current_step": step_num
+            })
 
         images = model.sample(request, callback=_on_step)
 
@@ -551,7 +409,9 @@ def _generate(request: GenerationRequest, task_info: TaskData):
         image = base64.b64encode(image).decode("ascii")
         images_encoded.append(image)
 
-        task_info["current_step"] += 1
+        task_info.update({
+            "current_step": task_info.current_step + 1
+        })
 
     del images
 
@@ -595,11 +455,11 @@ async def generate(request: GenerationRequest, authorized: bool = Depends(verify
         task_data = queue.get_task_data(task_id)
         if not task_data:
             raise Exception("Task not found")
-        if task_data["status"] == "finished":
-            images_encoded = task_data["response"]
+        if task_data.status == "finished":
+            images_encoded = task_data.response
             break
-        elif task_data["status"] == "error":
-            return {"error": str(task_data["response"])}
+        elif task_data.status == "error":
+            return {"error": str(task_data.response)}
 
         await asyncio.sleep(0.1)
 
@@ -660,13 +520,13 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depe
     task_data = queue.get_task_data(task_id)
     if not task_data:
         return ErrorOutput(error="Task not found.")
-    if task_data["status"] == "finished":
-        images_encoded = task_data["response"]
+    if task_data.status == "finished":
+        images_encoded = task_data.response
 
         return GenerationOutput(output=images_encoded)
 
-    elif task_data["status"] == "error":
-        raise task_data["response"]
+    elif task_data.status == "error":
+        raise task_data.response
     else:
         return ErrorOutput(error="Task is not finished.")
 
@@ -743,13 +603,44 @@ async def get_task_info(request: TaskIdRequest):
     task_data = queue.get_task_data(request.task_id)
     if task_data:
         return TaskDataOutput(
-            status=task_data["status"],
-            position=task_data["position"],
-            current_step=task_data["current_step"],
-            total_steps=task_data["total_steps"]
+            status=task_data.status,
+            position=task_data.position,
+            current_step=task_data.current_step,
+            total_steps=task_data.total_steps
         )
     else:
-        return ErrorOutput(error="Cannot find current task.")
+        return ErrorOutput(error="Cannot find current task.", code="ERR::TASK_NOT_FOUND")
+
+@app.websocket('/task-info')
+async def websocket_endpoint(websocket: WebSocket, task_id: str):
+    logger.info("websocket request task_id: %s" % task_id)
+    await websocket.accept()
+
+    try:
+        task_data = queue.get_task_data(task_id)
+        if not task_data:
+            await websocket.send_json({
+                "error": "ERR::TASK_NOT_FOUND",
+                "message": "Cannot find current task."
+            })
+            await websocket.close()
+            return
+
+        # Send current task info
+        await websocket.send_json(task_data.to_dict())
+        if task_data.status == "finished" or task_data.status == "error":
+            await websocket.close()
+            return
+
+        task_ws_controller.add(task_id, websocket)
+        while True:
+            data = await websocket.receive_text()
+            if data == "ping":
+                await websocket.send_json({
+                    "pong": "pong"
+                })
+    except WebSocketDisconnect:
+        task_ws_controller.remove(task_id, websocket)
 
 @app.get('/')
 def index():
diff --git a/task.py b/task.py
new file mode 100644
index 0000000..f4981c1
--- /dev/null
+++ b/task.py
@@ -0,0 +1,256 @@
+import gc
+import random
+import string
+import threading
+import time
+import traceback
+from types import LambdaType
+from typing import Dict, Optional, List, Any, Union
+from typing_extensions import TypedDict
+
+import asyncio
+
+from hydra_node.models import torch_gc
+
+class TaskQueueFullException(Exception):
+    pass
+
+class TaskDataType(TypedDict):
+    task_id: str
+    position: int
+    status: str
+    updated_time: float
+    callback: LambdaType
+    current_step: int
+    total_steps: int
+    payload: Any
+    response: Any
+
+class TaskData:
+    _task_queue = None
+
+    task_id: str = ""
+    position: int = 0
+    status: str = "preparing"
+    updated_time: float = 0
+    callback: LambdaType = lambda x: x
+    current_step: int = 0
+    total_steps: int = 0
+    payload: Any = {}
+    response: Any = None
+
+    def __init__(self, taskQueue, initData: Dict = {}):
+        self._task_queue = taskQueue
+        self.update(initData, skipEvent=True)
+
+    def update(self, data: Dict, skipEvent: bool = False):
+        for (key, value) in data.items():
+            if key[0] != "_" and hasattr(self, key):
+                setattr(self, key, value)
+            else:
+                raise Exception("Property %s not found on TaskData" % (key))
+
+        if skipEvent == False:
+            self._task_queue._on_task_update(self, data)
+
+    def to_dict(self):
+        allowedItems = ["task_id", "position", "status", "updated_time", "current_step", "total_steps"]
+        retDict = {}
+        for itemKey in allowedItems:
+            retDict[itemKey] = getattr(self, itemKey)
+
+        return retDict
+
+class TaskQueue:
+    logger = None
+
+    max_queue_size = 10
+    recovery_queue_size = 6
+    is_busy = False
+
+    loop_thread = None
+    gc_loop_thread = None
+    event_loop_thread = None
+    running = False
+    core_running = False
+
+    lock = threading.Lock()
+    event_lock = threading.Lock()
+
+    event_trigger = threading.Event()
+    event_list = []
+
+    queued_task_list: List[TaskData] = []
+    task_map: Dict[str, TaskData] = {}
+
+    on_update = staticmethod(lambda queue_data: None)  # default no-op; staticmethod so instance access does not bind self (main.py replaces this)
+
+    def __init__(self, logger, max_queue_size = 10, recovery_queue_size = 6) -> None:
+        self.logger = logger
+
+        self.max_queue_size = max_queue_size
+        self.recovery_queue_size = recovery_queue_size
+
+        self.core_running = True
+        self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
+        self.gc_loop_thread.start()
+        self.event_loop_thread = threading.Thread(name="TaskQueueEvent", target=self._event_loop)
+        self.event_loop_thread.start()
+
+        logger.info("Task queue created")
+
+    def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
+        if self.is_busy:
+            raise TaskQueueFullException("Task queue is full")
+        if len(self.queued_task_list) >= self.max_queue_size: # mark busy
+            self.is_busy = True
+            raise TaskQueueFullException("Task queue is full")
+
+        task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
+        task = TaskData(self, {
+            "task_id": task_id,
+            "position": 0,
+            "status": "queued",
+            "updated_time": time.time(),
+            "callback": callback,
+            "current_step": 0,
+            "total_steps": 0,
+            "payload": payload
+        })
+        with self.lock:
+            self.queued_task_list.append(task)
+            task.position = len(self.queued_task_list) - 1
+            self.logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
+        # create index
+        with self.lock:
+            self.task_map[task_id] = task
+
+        self._start()
+
+        return task_id
+
+    def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
+        if task_id in self.task_map:
+            return self.task_map[task_id]
+        else:
+            return False
+
+    def delete_task_data(self, task_id: str) -> bool:
+        if task_id in self.task_map:
+            with self.lock:
+                del self.task_map[task_id]
+            return True
+        else:
+            return False
+
+    def stop(self):
+        self.running = False
+        self.core_running = False
+        self.queued_task_list = []
+        self.event_list = []
+        self.event_trigger.set()
+
+    # called from the worker thread
+    def _on_task_update(self, task: TaskData, updated_value: Dict):
+        self.event_list.append(task.to_dict())
+        self.event_trigger.set()
+
+    def _start(self):
+        with self.lock:
+            if self.running == False:
+                self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
+                self.running = True
+                self.loop_thread.start()
+
+    def _loop(self):
+        while self.running and len(self.queued_task_list) > 0:
+            current_task = self.queued_task_list[0]
+            self.logger.info("Start task:%s." % current_task.task_id)
+            try:
+                current_task.update({
+                    "status": "running",
+                    "updated_time": time.time(),
+                })
+
+                # run task
+                res = current_task.callback(current_task)
+
+                # call gc
+                gc.collect()
+                torch_gc()
+
+                current_task.update({
+                    "status": "finished",
+                    "updated_time": time.time(),
+                    "response": res,
+                    "current_step": current_task.total_steps
+                })
+            except Exception as e:
+                newState = {
+                    "status": "error",
+                    "updated_time": time.time(),
+                    "response": e
+                }
+
+                gc.collect()
+                e_s = str(e)
+                if "CUDA out of memory" in e_s or \
+                        "an illegal memory access" in e_s or "CUDA" in e_s:
+                    traceback.print_exc()
+                    self.logger.error(str(e))
+                    self.logger.error("GPU error in task.")
+                    torch_gc()
+                    newState["response"] = Exception("GPU error in task.")
+                    # os.kill(mainpid, signal.SIGTERM)
+
+                current_task.update(newState)
+
+            with self.lock:
+                self.queued_task_list.pop(0)
+                # reset queue position
+                for i in range(0, len(self.queued_task_list)):
+                    self.queued_task_list[i].update({
+                        "position": i
+                    })
+
+            if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
+                self.is_busy = False
+
+            self.logger.info("Task:%s finished, queue size: %d" % (current_task.task_id, len(self.queued_task_list)))
+
+        with self.lock:
+            self.running = False
+
+    # Task to remove finished task
+    def _gc_loop(self):
+        while self.core_running:
+            current_time = time.time()
+            with self.lock:
+                will_remove_keys = []
+                for (key, task_data) in self.task_map.items():
+                    if task_data.status == "finished" or task_data.status == "error":
+                        if task_data.updated_time + 120 < current_time: # delete task which finished 2 minutes ago
+                            will_remove_keys.append(key)
+
+                for key in will_remove_keys:
+                    del self.task_map[key]
+
+            time.sleep(0.1)
+        self.logger.info("TaskQueue GC Stopped.")
+
+    # Event loop
+    def _event_loop(self):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)  # give this thread its own loop so on_update handlers can run coroutines
+        while self.core_running:
+            self.event_trigger.wait()
+            while len(self.event_list) > 0:
+                event = self.event_list[0]
+                try:
+                    self.on_update(event)
+                except Exception as err:
+                    self.logger.error("on_update handler failed: %s" % err)
+                self.event_list.pop(0)
+            self.event_trigger.clear()
+
+        self.logger.info("TaskQueue Event Stopped.")
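
For reference, a minimal client for the new /task-info websocket endpoint could look like the sketch below. It is an illustration only, not part of this patch: the host and port, the placeholder task id, and the third-party "websockets" package are assumptions. It prints the TaskData.to_dict() frames that TaskDataSocketController.emit pushes on every queue update, handles the ERR::TASK_NOT_FOUND payload, and exits once the task reaches a terminal state.

# Hypothetical client sketch (assumes: pip install websockets; server at ws://localhost:80;
# a real task id obtained from the generate endpoints).
import asyncio
import json

import websockets  # third-party client library, not part of this patch


async def watch_task(task_id: str):
    uri = "ws://localhost:80/task-info?task_id=%s" % task_id
    async with websockets.connect(uri) as ws:
        # The endpoint sends the current task state immediately after accepting.
        while True:
            frame = json.loads(await ws.recv())
            if "error" in frame:
                print("server error:", frame["message"])
                return
            print("position=%(position)s status=%(status)s step=%(current_step)s/%(total_steps)s" % frame)
            if frame["status"] in ("finished", "error"):
                return  # the server closes the socket after terminal states


if __name__ == "__main__":
    asyncio.run(watch_task("your-task-id-here"))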