Add WebSocket support

master
落雨楓 2 years ago
parent e23b2b2f95
commit f39f24545b

@@ -0,0 +1,52 @@
from typing import Dict, List, Optional

from fastapi import WebSocket


class TaskDataSocketController:
    """Registry of WebSocket subscribers, keyed by channel (task id)."""

    socket_map: Dict[str, List[WebSocket]] = {}

    def add(self, channel: str, socket: WebSocket):
        if channel not in self.socket_map:
            self.socket_map[channel] = []
        self.socket_map[channel].append(socket)

    def remove(self, channel: str, socket: WebSocket) -> bool:
        if channel in self.socket_map:
            socket_list = self.socket_map[channel]
            new_socket_list = []
            for socket_item in socket_list:
                if socket_item != socket:
                    new_socket_list.append(socket_item)
            if len(new_socket_list) == 0:
                del self.socket_map[channel]
            else:
                self.socket_map[channel] = new_socket_list
            return True
        else:
            return False

    async def close_all(self, channel: str) -> int:
        """Close every socket on a channel; returns the number closed."""
        if channel in self.socket_map:
            sended = 0
            socket_list = self.socket_map[channel]
            for socket in socket_list:
                await socket.close()
                sended += 1
            return sended
        else:
            return 0

    async def emit(self, channel: str, data: Optional[dict] = None) -> int:
        """Send a JSON payload to every socket on a channel; returns the send count."""
        # data=None instead of data={} avoids a shared mutable default argument.
        if data is None:
            data = {}
        if channel in self.socket_map:
            sended = 0
            socket_list = self.socket_map[channel]
            for socket in socket_list:
                await socket.send_json(data)
                sended += 1
            return sended
        else:
            return 0
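Taken on its own, the new controller (imported below as controller.py) is a plain channel-to-subscribers registry keyed by task id. A minimal sketch of the intended call pattern; the handler shape and payload here are illustrative, not part of the commit:

from fastapi import WebSocket
from controller import TaskDataSocketController

controller = TaskDataSocketController()

async def handle_subscriber(ws: WebSocket, task_id: str):
    await ws.accept()
    controller.add(task_id, ws)                           # subscribe this socket
    await controller.emit(task_id, {"status": "queued"})  # push JSON to every subscriber
    controller.remove(task_id, ws)                        # unsubscribe on disconnect
    await controller.close_all(task_id)                   # or tear down the whole channel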

@@ -1,13 +1,10 @@
from asyncore import socket_map
from concurrent.futures import thread
import os
import random
import re
import string
import sys
from types import LambdaType
from aiohttp import request
from async_timeout import asyncio
-from fastapi import FastAPI, Request, Depends
+from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
from numpy import number
from pydantic import BaseModel
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
@@ -15,18 +12,18 @@ from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse
+from controller import TaskDataSocketController
from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc
-from typing import Optional, List, Any
+from typing import Dict, List, Union
-from typing_extensions import TypedDict
import socket
from hydra_node.sanitize import sanitize_input
+from task import TaskData, TaskQueue, TaskQueueFullException
import uvicorn
-from typing import Union, Dict
import time
import gc
import io
import signal
import base64
import traceback
import threading
@@ -56,169 +53,21 @@ mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False
-class TaskQueueFullException(Exception):
-    pass
-
-class TaskData(TypedDict):
-    task_id: str
-    position: str
-    status: str
-    updated_time: float
-    callback: LambdaType
-    current_step: int
-    total_steps: int
-    payload: Any
-    response: Any
-
-class TaskQueue:
-    max_queue_size = 10
-    recovery_queue_size = 6
-    is_busy = False
-    loop_thread = None
-    gc_loop_thread = None
-    running = False
-    gc_running = False
-    lock = threading.Lock()
-    queued_task_list: List[TaskData] = []
-    task_map: Dict[str, TaskData] = {}
-
-    def __init__(self, max_queue_size = 10, recovery_queue_size = 6) -> None:
-        self.max_queue_size = max_queue_size
-        self.max_queue_size = recovery_queue_size
-        self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
-        self.gc_running = True
-        self.gc_loop_thread.start()
-        logger.info("Task queue created")
-
-    def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
-        if self.is_busy:
-            raise TaskQueueFullException("Task queue is full")
-        if len(self.queued_task_list) >= self.max_queue_size: # mark busy
-            self.is_busy = True
-            raise TaskQueueFullException("Task queue is full")
-        task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
-        task = TaskData(
-            task_id=task_id,
-            position=0,
-            status="queued",
-            updated_time=time.time(),
-            callback=callback,
-            current_step=0,
-            total_steps=0,
-            payload=payload
-        )
-        with self.lock:
-            self.queued_task_list.append(task)
-            task["position"] = len(self.queued_task_list) - 1
-        logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
-        # create index
-        with self.lock:
-            self.task_map[task_id] = task
-        self._start()
-        return task_id
-
-    def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
-        if task_id in self.task_map:
-            return self.task_map[task_id]
-        else:
-            return False
-
-    def delete_task_data(self, task_id: str) -> bool:
-        if task_id in self.task_map:
-            with self.lock:
-                del(self.task_map[task_id])
-        else:
-            return False
-
-    def stop(self):
-        self.running = False
-        self.gc_running = False
-        self.queued_task_list = []
-
-    def _start(self):
-        with self.lock:
-            if self.running == False:
-                self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
-                self.running = True
-                self.loop_thread.start()
-
-    def _loop(self):
-        while self.running and len(self.queued_task_list) > 0:
-            current_task = self.queued_task_list[0]
-            logger.info("Start task%s." % current_task["task_id"])
-            try:
-                current_task["status"] = "running"
-                current_task["updated_time"] = time.time()
-                # run task
-                res = current_task["callback"](current_task)
-                # call gc
-                gc.collect()
-                torch_gc()
-                current_task["status"] = "finished"
-                current_task["updated_time"] = time.time()
-                current_task["response"] = res
-                current_task["current_step"] = current_task["total_steps"]
-            except Exception as e:
-                current_task["status"] = "error"
-                current_task["updated_time"] = time.time()
-                current_task["response"] = e
-                gc.collect()
-                e_s = str(e)
-                if "CUDA out of memory" in e_s or \
-                        "an illegal memory access" in e_s or "CUDA" in e_s:
-                    traceback.print_exc()
-                    logger.error(str(e))
-                    logger.error("GPU error in task.")
-                    torch_gc()
-                    current_task["response"] = Exception("GPU error in task.")
-                    # os.kill(mainpid, signal.SIGTERM)
-            with self.lock:
-                self.queued_task_list.pop(0)
-                # reset queue position
-                for i in range(0, len(self.queued_task_list)):
-                    self.queued_task_list[i]["position"] = i
-                if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
-                    self.is_busy = False
-            logger.info("Task%s finished, queue size: %d" % (current_task["task_id"], len(self.queued_task_list)))
-        with self.lock:
-            self.running = False
-
-    # Task to remove finished task
-    def _gc_loop(self):
-        while self.gc_running:
-            current_time = time.time()
-            with self.lock:
-                will_remove_keys = []
-                for (key, task_data) in self.task_map.items():
-                    if task_data["status"] == "finished" or task_data["status"] == "error":
-                        if task_data["updated_time"] + 120 < current_time: # delete task which finished 2 minutes ago
-                            will_remove_keys.append(key)
-                for key in will_remove_keys:
-                    del(self.task_map[key])
-            time.sleep(1)
+queue = TaskQueue(
+    logger=logger,
+    max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
+    recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6))
+)
+task_ws_controller = TaskDataSocketController()
+
+def on_task_update(task_info: Dict):
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(task_ws_controller.emit(task_info["task_id"], task_info))
+
+queue.on_update = on_task_update
def verify_token(req: Request):
    if TOKEN:
        valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer "+TOKEN
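Note that queue.on_update fires on the queue's event thread (see task.py below), not on FastAPI's event loop: _event_loop installs a private asyncio loop on that thread, which is what makes asyncio.get_event_loop().run_until_complete(...) valid inside on_task_update. The same pattern, reduced to a standalone sketch:

import asyncio
import threading

async def emit(payload):
    # stand-in for TaskDataSocketController.emit
    print("delivered:", payload)

def event_thread():
    # give this worker thread its own event loop...
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # ...so synchronous code on it can drive coroutines to completion
    loop.run_until_complete(emit({"task_id": "demo", "status": "running"}))

threading.Thread(target=event_thread, name="TaskQueueEvent").start()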
@@ -335,13 +184,17 @@ def saveimage(image, request):
    except Exception as e:
        print("failed to save image:", e)

-def _generate_stream(request: GenerationRequest, task_info: TaskData = TaskData()):
+def _generate_stream(request: GenerationRequest, task_info: TaskData):
    try:
-        task_info["total_steps"] = request.steps + 1
+        task_info.update({
+            "total_steps": request.steps + 1
+        })

        def _on_step(step_num, total_steps):
-            task_info["total_steps"] = total_steps + 1
-            task_info["current_step"] = step_num
+            task_info.update({
+                "total_steps": total_steps + 1,
+                "current_step": step_num
+            })

        if request.advanced:
            if request.n_samples > 1:
@@ -378,7 +231,9 @@ def _generate_stream(request: GenerationRequest, task_info: TaskData = TaskData()):
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)
-        task_info["current_step"] += 1
+        task_info.update({
+            "current_step": task_info.current_step + 1
+        })

        del images
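The hunks above are mechanical fallout of TaskData changing from a TypedDict into a class (see task.py below): writes now go through TaskData.update(), which notifies the queue and ultimately fans the new state out to subscribed WebSockets. The before/after shape of a step update:

# before: a silent dict write that nothing observes
task_info["current_step"] += 1

# after: routed through TaskData.update, which calls
# TaskQueue._on_task_update and queues a WebSocket emit
task_info.update({"current_step": task_info.current_step + 1})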
@@ -423,11 +278,11 @@ async def generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
        task_data = queue.get_task_data(task_id)
        if not task_data:
            raise Exception("Task not found")
-        if task_data["status"] == "finished":
-            images_encoded = task_data["response"]
+        if task_data.status == "finished":
+            images_encoded = task_data.response
            break
-        elif task_data["status"] == "error":
-            return {"error": str(task_data["response"])}
+        elif task_data.status == "error":
+            return {"error": str(task_data.response)}
        await asyncio.sleep(0.1)

    process_time = time.perf_counter() - t
@@ -493,8 +348,8 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    task_data = queue.get_task_data(task_id)
    if not task_data:
        return ErrorOutput(error="Task not found.")
-    if task_data["status"] == "finished":
-        images_encoded = task_data["response"]
+    if task_data.status == "finished":
+        images_encoded = task_data.response

        data = ""
        ptr = 0
@@ -503,8 +358,8 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
            data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
        return Response(content=data, media_type="text/event-stream")
-    elif task_data["status"] == "error":
-        raise task_data["response"]
+    elif task_data.status == "error":
+        raise task_data.response
    else:
        return ErrorOutput(error="Task is not finished.")
@@ -522,11 +377,15 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
def _generate(request: GenerationRequest, task_info: TaskData):
    try:
-        task_info["total_steps"] = request.steps + 1
+        task_info.update({
+            "total_steps": request.steps + 1
+        })

        def _on_step(step_num, total_steps):
-            task_info["total_steps"] = total_steps + 1
-            task_info["current_step"] = step_num
+            task_info.update({
+                "total_steps": total_steps + 1,
+                "current_step": step_num
+            })

        images = model.sample(request, callback=_on_step)
@@ -551,7 +410,9 @@ def _generate(request: GenerationRequest, task_info: TaskData):
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)
-        task_info["current_step"] += 1
+        task_info.update({
+            "current_step": task_info.current_step + 1
+        })

        del images
@@ -595,11 +456,11 @@ async def generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
        task_data = queue.get_task_data(task_id)
        if not task_data:
            raise Exception("Task not found")
-        if task_data["status"] == "finished":
-            images_encoded = task_data["response"]
+        if task_data.status == "finished":
+            images_encoded = task_data.response
            break
-        elif task_data["status"] == "error":
-            return {"error": str(task_data["response"])}
+        elif task_data.status == "error":
+            return {"error": str(task_data.response)}
        await asyncio.sleep(0.1)
@@ -660,13 +521,13 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    task_data = queue.get_task_data(task_id)
    if not task_data:
        return ErrorOutput(error="Task not found.")
-    if task_data["status"] == "finished":
-        images_encoded = task_data["response"]
+    if task_data.status == "finished":
+        images_encoded = task_data.response
        return GenerationOutput(output=images_encoded)
-    elif task_data["status"] == "error":
-        raise task_data["response"]
+    elif task_data.status == "error":
+        raise task_data.response
    else:
        return ErrorOutput(error="Task is not finished.")
@@ -743,13 +604,44 @@ async def get_task_info(request: TaskIdRequest):
    task_data = queue.get_task_data(request.task_id)
    if task_data:
        return TaskDataOutput(
-            status=task_data["status"],
-            position=task_data["position"],
-            current_step=task_data["current_step"],
-            total_steps=task_data["total_steps"]
+            status=task_data.status,
+            position=task_data.position,
+            current_step=task_data.current_step,
+            total_steps=task_data.total_steps
        )
    else:
-        return ErrorOutput(error="Cannot find current task.")
+        return ErrorOutput(error="Cannot find current task.", code="ERR::TASK_NOT_FOUND")

+@app.websocket('/task-info')
+async def websocket_endpoint(websocket: WebSocket, task_id: str):
+    logger.info("websocket request task_id: %s" % task_id)
+    await websocket.accept()
+    try:
+        task_data = queue.get_task_data(task_id)
+        if not task_data:
+            await websocket.send_json({
+                "error": "ERR::TASK_NOT_FOUND",
+                "message": "Cannot find current task."
+            })
+            await websocket.close()
+            return
+        # Send current task info
+        await websocket.send_json(task_data.to_dict())
+        if task_data.status == "finished" or task_data.status == "error":
+            await websocket.close()
+            return
+        task_ws_controller.add(task_id, websocket)
+        while True:
+            data = await websocket.receive_text()
+            if data == "ping":
+                await websocket.send_json({
+                    "pong": "pong"
+                })
+    except WebSocketDisconnect:
+        task_ws_controller.remove(task_id, websocket)

@app.get('/')
def index():
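A client subscribes by opening the socket with the task id as a query parameter; the server pushes the current state immediately, then every subsequent update until the task finishes. A minimal client sketch (host, port, and task id are illustrative; uses the third-party websockets package):

import asyncio
import json

import websockets  # pip install websockets

async def watch_task(task_id: str):
    uri = "ws://localhost:8000/task-info?task_id=%s" % task_id
    async with websockets.connect(uri) as ws:
        await ws.send("ping")  # optional; the server answers {"pong": "pong"}
        while True:
            msg = json.loads(await ws.recv())
            print(msg)
            if msg.get("status") in ("finished", "error") or "error" in msg:
                break

asyncio.run(watch_task("abcDEF0123456789"))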

@@ -0,0 +1,256 @@
import asyncio
import gc
import random
import string
import threading
import time
import traceback
from types import LambdaType
from typing import Dict, Optional, List, Any, Union

from typing_extensions import TypedDict

from hydra_node.models import torch_gc


class TaskQueueFullException(Exception):
    pass


class TaskDataType(TypedDict):
    task_id: str
    position: int
    status: str
    updated_time: float
    callback: LambdaType
    current_step: int
    total_steps: int
    payload: Any
    response: Any


class TaskData:
    _task_queue = None
    task_id: str = ""
    position: int = 0
    status: str = "preparing"
    updated_time: float = 0
    callback: LambdaType = lambda x: x
    current_step: int = 0
    total_steps: int = 0
    payload: Any = {}
    response: Any = None

    def __init__(self, taskQueue, initData: Optional[Dict] = None):
        # initData=None avoids a shared mutable default argument.
        self._task_queue = taskQueue
        self.update(initData or {}, skipEvent=True)

    def update(self, data: Dict, skipEvent: bool = False):
        for (key, value) in data.items():
            if key[0] != "_" and hasattr(self, key):
                setattr(self, key, value)
            else:
                raise Exception("Property %s not found on TaskData" % (key))
        if not skipEvent:
            self._task_queue._on_task_update(self, data)

    def to_dict(self):
        # Only expose public, JSON-serializable fields.
        allowedItems = ["task_id", "position", "status", "updated_time", "current_step", "total_steps"]
        retDict = {}
        for itemKey in allowedItems:
            retDict[itemKey] = getattr(self, itemKey)
        return retDict


class TaskQueue:
    logger = None
    max_queue_size = 10
    recovery_queue_size = 6
    is_busy = False
    loop_thread = None
    gc_loop_thread = None
    event_loop_thread = None
    running = False
    core_running = False
    lock = threading.Lock()
    event_lock = threading.Lock()
    event_trigger = threading.Event()
    event_list = []
    queued_task_list: List[TaskData] = []
    task_map: Dict[str, TaskData] = {}
    # staticmethod so the default is not bound as a method when called
    # via self.on_update(event); main.py replaces it with a plain function.
    on_update = staticmethod(lambda queue_data: None)

    def __init__(self, logger, max_queue_size = 10, recovery_queue_size = 6) -> None:
        self.logger = logger
        self.max_queue_size = max_queue_size
        self.recovery_queue_size = recovery_queue_size  # queue reopens below this size
        self.core_running = True
        self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
        self.gc_loop_thread.start()
        self.event_loop_thread = threading.Thread(name="TaskQueueEvent", target=self._event_loop)
        self.event_loop_thread.start()
        logger.info("Task queue created")

    def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
        if self.is_busy:
            raise TaskQueueFullException("Task queue is full")
        if len(self.queued_task_list) >= self.max_queue_size: # mark busy
            self.is_busy = True
            raise TaskQueueFullException("Task queue is full")
        task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
        task = TaskData(self, {
            "task_id": task_id,
            "position": 0,
            "status": "queued",
            "updated_time": time.time(),
            "callback": callback,
            "current_step": 0,
            "total_steps": 0,
            "payload": payload
        })
        with self.lock:
            self.queued_task_list.append(task)
            task.position = len(self.queued_task_list) - 1
        self.logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
        # create index
        with self.lock:
            self.task_map[task_id] = task
        self._start()
        return task_id

    def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
        if task_id in self.task_map:
            return self.task_map[task_id]
        else:
            return False

    def delete_task_data(self, task_id: str) -> bool:
        if task_id in self.task_map:
            with self.lock:
                del self.task_map[task_id]
            return True
        else:
            return False

    def stop(self):
        self.running = False
        self.core_running = False
        self.queued_task_list = []
        self.event_list = []
        self.event_trigger.set()  # wake the event thread so it can exit

    # in thread
    def _on_task_update(self, task: TaskData, updated_value: Dict):
        self.event_list.append(task.to_dict())
        self.event_trigger.set()

    def _start(self):
        with self.lock:
            if not self.running:
                self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
                self.running = True
                self.loop_thread.start()

    def _loop(self):
        while self.running and len(self.queued_task_list) > 0:
            current_task = self.queued_task_list[0]
            self.logger.info("Start task %s." % current_task.task_id)
            try:
                current_task.update({
                    "status": "running",
                    "updated_time": time.time(),
                })
                # run task
                res = current_task.callback(current_task)
                # call gc
                gc.collect()
                torch_gc()
                current_task.update({
                    "status": "finished",
                    "updated_time": time.time(),
                    "response": res,
                    "current_step": current_task.total_steps
                })
            except Exception as e:
                newState = {
                    "status": "error",
                    "updated_time": time.time(),
                    "response": e
                }
                gc.collect()
                e_s = str(e)
                if "CUDA out of memory" in e_s or \
                        "an illegal memory access" in e_s or "CUDA" in e_s:
                    traceback.print_exc()
                    self.logger.error(str(e))
                    self.logger.error("GPU error in task.")
                    torch_gc()
                    newState["response"] = Exception("GPU error in task.")
                    # os.kill(mainpid, signal.SIGTERM)
                current_task.update(newState)
            with self.lock:
                self.queued_task_list.pop(0)
                # reset queue position
                for i in range(0, len(self.queued_task_list)):
                    self.queued_task_list[i].update({
                        "position": i
                    })
                if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
                    self.is_busy = False
            self.logger.info("Task %s finished, queue size: %d" % (current_task.task_id, len(self.queued_task_list)))
        with self.lock:
            self.running = False

    # Task to remove finished task
    def _gc_loop(self):
        while self.core_running:
            current_time = time.time()
            with self.lock:
                will_remove_keys = []
                for (key, task_data) in self.task_map.items():
                    if task_data.status == "finished" or task_data.status == "error":
                        if task_data.updated_time + 120 < current_time: # delete task which finished 2 minutes ago
                            will_remove_keys.append(key)
                for key in will_remove_keys:
                    del self.task_map[key]
            time.sleep(0.1)
        self.logger.info("TaskQueue GC Stopped.")

    # Event loop
    def _event_loop(self):
        # This thread owns a private asyncio loop, so callbacks that need to
        # drive coroutines (e.g. WebSocket emits) can use run_until_complete.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        while self.core_running:
            self.event_trigger.wait()
            while len(self.event_list) > 0:
                event = self.event_list[0]
                try:
                    self.on_update(event)
                except Exception as err:
                    print(err)
                self.event_list.pop(0)
            self.event_trigger.clear()
        self.logger.info("TaskQueue Event Stopped.")