修复多线程调度问题

master
落雨楓 2 years ago
parent f39f24545b
commit 330fdcd933

@ -27,17 +27,16 @@ class TaskDataSocketController:
else: else:
return False return False
async def close_all(self, channel: str) -> int: async def close_all(self) -> int:
if channel in self.socket_map: closedCount = 0
sended = 0 for (_, sockets) in self.socket_map.items():
socket_list = self.socket_map[channel] for socket in sockets:
for socket in socket_list: try:
await socket.close() await socket.close()
sended += 1 except Exception as err:
print(err)
return sended closedCount += 1
else: return closedCount
return 0
async def emit(self, channel: str, data = {}) -> int: async def emit(self, channel: str, data = {}) -> int:
if channel in self.socket_map: if channel in self.socket_map:

@ -2,6 +2,7 @@ from asyncore import socket_map
from concurrent.futures import thread from concurrent.futures import thread
import os import os
import re import re
import signal
from types import LambdaType from types import LambdaType
from async_timeout import asyncio from async_timeout import asyncio
from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
@ -12,6 +13,7 @@ from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse from starlette.responses import FileResponse
from starlette.websockets import WebSocketState
from controller import TaskDataSocketController from controller import TaskDataSocketController
from hydra_node.config import init_config_model from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc from hydra_node.models import EmbedderModel, torch_gc
@ -32,6 +34,7 @@ from PIL.PngImagePlugin import PngInfo
import json import json
genlock = threading.Lock() genlock = threading.Lock()
running = True
MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100)) MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
TOKEN = os.getenv("TOKEN", None) TOKEN = os.getenv("TOKEN", None)
@ -97,6 +100,8 @@ def startup_event():
def shutdown_event(): def shutdown_event():
logger.info("FastAPI Shutdown, exiting") logger.info("FastAPI Shutdown, exiting")
queue.stop() queue.stop()
time.sleep(2)
os.kill(mainpid, signal.SIGTERM)
class Masker(TypedDict): class Masker(TypedDict):
seed: int seed: int
@ -274,7 +279,7 @@ async def generate_stream(request: GenerationRequest, authorized: bool = Depends
raise err raise err
images_encoded = [] images_encoded = []
while True: while running:
task_data = queue.get_task_data(task_id) task_data = queue.get_task_data(task_id)
if not task_data: if not task_data:
raise Exception("Task not found") raise Exception("Task not found")
@ -452,7 +457,7 @@ async def generate(request: GenerationRequest, authorized: bool = Depends(verify
raise err raise err
images_encoded = [] images_encoded = []
while True: while running:
task_data = queue.get_task_data(task_id) task_data = queue.get_task_data(task_id)
if not task_data: if not task_data:
raise Exception("Task not found") raise Exception("Task not found")
@ -614,7 +619,6 @@ async def get_task_info(request: TaskIdRequest):
@app.websocket('/task-info') @app.websocket('/task-info')
async def websocket_endpoint(websocket: WebSocket, task_id: str): async def websocket_endpoint(websocket: WebSocket, task_id: str):
logger.info("websocket request task_id: %s" % task_id)
await websocket.accept() await websocket.accept()
try: try:
@ -634,12 +638,15 @@ async def websocket_endpoint(websocket: WebSocket, task_id: str):
return return
task_ws_controller.add(task_id, websocket) task_ws_controller.add(task_id, websocket)
while True: while running and websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
data = await websocket.receive_text() data = await websocket.receive_text()
if data == "ping": if data == "ping":
await websocket.send_json({ await websocket.send_json({
"pong": "pong" "pong": "pong"
}) })
if websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
except WebSocketDisconnect: except WebSocketDisconnect:
task_ws_controller.remove(task_id, websocket) task_ws_controller.remove(task_id, websocket)

@ -150,9 +150,10 @@ class TaskQueue:
self.event_list = [] self.event_list = []
self.event_trigger.set() self.event_trigger.set()
# in thread # in task thread
def _on_task_update(self, task: TaskData, updated_value: Dict): def _on_task_update(self, task: TaskData, updated_value: Dict):
self.event_list.append(task.to_dict()) with self.event_lock:
self.event_list.append(task.to_dict())
self.event_trigger.set() self.event_trigger.set()
def _start(self): def _start(self):
@ -207,14 +208,15 @@ class TaskQueue:
with self.lock: with self.lock:
self.queued_task_list.pop(0) self.queued_task_list.pop(0)
# reset queue position
for i in range(0, len(self.queued_task_list)): # reset queue position
self.queued_task_list[i].update({ for i in range(0, len(self.queued_task_list)):
"position": i self.queued_task_list[i].update({
}) "position": i
})
if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
self.is_busy = False if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
self.is_busy = False
self.logger.info("Task%s finished, queue size: %d" % (current_task.task_id, len(self.queued_task_list))) self.logger.info("Task%s finished, queue size: %d" % (current_task.task_id, len(self.queued_task_list)))
@ -226,17 +228,20 @@ class TaskQueue:
while self.core_running: while self.core_running:
current_time = time.time() current_time = time.time()
with self.lock: with self.lock:
will_remove_keys = [] task_map = self.task_map.copy()
for (key, task_data) in self.task_map.items():
if task_data.status == "finished" or task_data.status == "error": will_remove_keys = []
if task_data.updated_time + 120 < current_time: # delete task which finished 2 minutes ago for (key, task_data) in task_map.items():
will_remove_keys.append(key) if task_data.status == "finished" or task_data.status == "error":
if task_data.updated_time + 120 < current_time: # delete task which finished 2 minutes ago
for key in will_remove_keys: will_remove_keys.append(key)
del(self.task_map[key])
if len(will_remove_keys) > 0:
with self.lock:
for key in will_remove_keys:
del(self.task_map[key])
time.sleep(0.1) time.sleep(1)
self.logger.info("TaskQueue GC Stopped.")
# Event loop # Event loop
def _event_loop(self): def _event_loop(self):
@ -250,7 +255,9 @@ class TaskQueue:
self.on_update(event) self.on_update(event)
except Exception as err: except Exception as err:
print(err) print(err)
self.event_list.pop(0)
with self.event_lock:
self.event_list.pop(0)
self.event_trigger.clear() self.event_trigger.clear()
self.logger.info("TaskQueue Event Stopped.") loop.stop()

Loading…
Cancel
Save