修复多线程调度问题

master
落雨楓 2 years ago
parent f39f24545b
commit 330fdcd933

@ -27,17 +27,16 @@ class TaskDataSocketController:
else:
return False
async def close_all(self, channel: str) -> int:
    """Close every socket registered under ``channel``.

    Args:
        channel: Key into ``self.socket_map``.

    Returns:
        The number of sockets processed, or 0 when ``channel`` is not in
        ``self.socket_map``. A socket whose ``close()`` raised is still
        counted: the error is printed and the remaining sockets are
        closed anyway.
    """
    if channel not in self.socket_map:
        return 0

    # Iterate over a snapshot: every ``await`` below yields to the event
    # loop, and the live list may be mutated while this coroutine is
    # suspended, which would make a direct iteration skip entries.
    pending = list(self.socket_map[channel])

    processed = 0
    for ws in pending:
        try:
            await ws.close()
        except Exception as err:
            # Best effort: one failing socket must not keep the others open.
            print(err)
        processed += 1
    return processed
async def close_all(self) -> int:
    """Close every socket in every channel of ``self.socket_map``.

    Returns:
        The number of sockets processed. Errors raised by ``close()`` are
        printed and do not stop the sweep.

    NOTE(review): indentation was lost in the text this was recovered
    from; the counter is assumed to advance once per socket whether or
    not ``close()`` raised. Confirm against the repository.
    """
    # Flatten to a snapshot before the first ``await``: while this
    # coroutine is suspended the map or its lists may be mutated, and
    # iterating the live dict would then raise "dictionary changed size
    # during iteration".
    pending = [ws for sockets in self.socket_map.values() for ws in sockets]

    closed_count = 0
    for ws in pending:
        try:
            await ws.close()
        except Exception as err:
            # Best effort: keep closing the remaining sockets.
            print(err)
        closed_count += 1
    return closed_count
async def emit(self, channel: str, data = {}) -> int:
if channel in self.socket_map:

@ -2,6 +2,7 @@ from asyncore import socket_map
from concurrent.futures import thread
import os
import re
import signal
from types import LambdaType
from async_timeout import asyncio
from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
@ -12,6 +13,7 @@ from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse
from starlette.websockets import WebSocketState
from controller import TaskDataSocketController
from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc
@ -32,6 +34,7 @@ from PIL.PngImagePlugin import PngInfo
import json
genlock = threading.Lock()
running = True
MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
TOKEN = os.getenv("TOKEN", None)
@ -97,6 +100,8 @@ def startup_event():
def shutdown_event():
    """Stop the task queue, wait briefly, then send SIGTERM to ``mainpid``.

    Presumably registered as the application's shutdown handler (the
    registration is not visible here) -- TODO confirm.
    """
    logger.info("FastAPI Shutdown, exiting")
    queue.stop()

    # Pause before signalling; presumably this gives the queue's worker
    # threads time to wind down -- verify against TaskQueue.stop().
    grace_seconds = 2
    time.sleep(grace_seconds)

    os.kill(mainpid, signal.SIGTERM)
class Masker(TypedDict):
seed: int
@ -274,7 +279,7 @@ async def generate_stream(request: GenerationRequest, authorized: bool = Depends
raise err
images_encoded = []
while True:
while running:
task_data = queue.get_task_data(task_id)
if not task_data:
raise Exception("Task not found")
@ -452,7 +457,7 @@ async def generate(request: GenerationRequest, authorized: bool = Depends(verify
raise err
images_encoded = []
while True:
while running:
task_data = queue.get_task_data(task_id)
if not task_data:
raise Exception("Task not found")
@ -614,7 +619,6 @@ async def get_task_info(request: TaskIdRequest):
@app.websocket('/task-info')
async def websocket_endpoint(websocket: WebSocket, task_id: str):
logger.info("websocket request task_id: %s" % task_id)
await websocket.accept()
try:
@ -634,12 +638,15 @@ async def websocket_endpoint(websocket: WebSocket, task_id: str):
return
task_ws_controller.add(task_id, websocket)
while True:
while running and websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
data = await websocket.receive_text()
if data == "ping":
await websocket.send_json({
"pong": "pong"
})
if websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
except WebSocketDisconnect:
task_ws_controller.remove(task_id, websocket)

@ -150,9 +150,10 @@ class TaskQueue:
self.event_list = []
self.event_trigger.set()
# in thread
# in task thread
def _on_task_update(self, task: TaskData, updated_value: Dict):
    """Queue a snapshot of ``task`` as an event and signal the trigger.

    Args:
        task: The task that changed; its ``to_dict()`` result is queued.
        updated_value: The changed fields (not used by this method).

    The event is appended exactly once, while holding ``self.event_lock``.
    The span this replaces appended it twice, once without the lock.
    """
    with self.event_lock:
        self.event_list.append(task.to_dict())
    # Signal that at least one event is pending.
    self.event_trigger.set()
def _start(self):
@ -207,14 +208,15 @@ class TaskQueue:
with self.lock:
self.queued_task_list.pop(0)
# reset queue position
for i in range(0, len(self.queued_task_list)):
self.queued_task_list[i].update({
"position": i
})
if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
self.is_busy = False
# reset queue position
for i in range(0, len(self.queued_task_list)):
self.queued_task_list[i].update({
"position": i
})
if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
self.is_busy = False
self.logger.info("Task%s finished, queue size: %d" % (current_task.task_id, len(self.queued_task_list)))
@ -226,17 +228,20 @@ class TaskQueue:
while self.core_running:
current_time = time.time()
with self.lock:
will_remove_keys = []
for (key, task_data) in self.task_map.items():
if task_data.status == "finished" or task_data.status == "error":
if task_data.updated_time + 120 < current_time: # delete task which finished 2 minutes ago
will_remove_keys.append(key)
for key in will_remove_keys:
del(self.task_map[key])
task_map = self.task_map.copy()
will_remove_keys = []
for (key, task_data) in task_map.items():
if task_data.status == "finished" or task_data.status == "error":
if task_data.updated_time + 120 < current_time: # delete task which finished 2 minutes ago
will_remove_keys.append(key)
if len(will_remove_keys) > 0:
with self.lock:
for key in will_remove_keys:
del(self.task_map[key])
time.sleep(0.1)
self.logger.info("TaskQueue GC Stopped.")
time.sleep(1)
# Event loop
def _event_loop(self):
@ -250,7 +255,9 @@ class TaskQueue:
self.on_update(event)
except Exception as err:
print(err)
self.event_list.pop(0)
with self.event_lock:
self.event_list.pop(0)
self.event_trigger.clear()
self.logger.info("TaskQueue Event Stopped.")
loop.stop()

Loading…
Cancel
Save