@@ -1,13 +1,10 @@
-from asyncore import socket_map
-from concurrent.futures import thread
 import os
 import random
 import re
 import string
 import sys
-from types import LambdaType
-from aiohttp import request
+from async_timeout import asyncio
-from fastapi import FastAPI, Request, Depends
+from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
 from numpy import number
 from pydantic import BaseModel
 from fastapi.responses import HTMLResponse, PlainTextResponse, Response
@@ -15,18 +12,18 @@ from fastapi.exceptions import HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from starlette.responses import FileResponse
+from controller import TaskDataSocketController
 from hydra_node.config import init_config_model
 from hydra_node.models import EmbedderModel, torch_gc
-from typing import Optional, List, Any
+from typing import Dict, List, Union
-from typing_extensions import TypedDict
 import socket
 from hydra_node.sanitize import sanitize_input
+from task import TaskData, TaskQueue, TaskQueueFullException
 import uvicorn
-from typing import Union, Dict
 import time
 import gc
 import io
 import signal
 import base64
 import traceback
 import threading
@@ -56,169 +53,21 @@ mainpid = config.mainpid
 hostname = socket.gethostname()
 sent_first_message = False
 
-class TaskQueueFullException(Exception):
-    pass
-
-class TaskData(TypedDict):
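-    # One queue entry: `callback` runs on the queue worker thread; its return
-    # value (or the Exception it raised) is stored back in `response`.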
-    task_id: str
-    position: int
-    status: str
-    updated_time: float
-    callback: LambdaType
-    current_step: int
-    total_steps: int
-    payload: Any
-    response: Any
-
-class TaskQueue:
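-    # Class-level defaults; __init__ overrides the sizes per instance. is_busy
-    # applies hysteresis: refuse new tasks at max_queue_size and accept again
-    # only once the backlog drops below recovery_queue_size.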
-    max_queue_size = 10
-    recovery_queue_size = 6
-    is_busy = False
-
-    loop_thread = None
-    gc_loop_thread = None
-    running = False
-    gc_running = False
-
-    lock = threading.Lock()
-
-    queued_task_list: List[TaskData] = []
-    task_map: Dict[str, TaskData] = {}
-
-    def __init__(self, max_queue_size=10, recovery_queue_size=6) -> None:
-        self.max_queue_size = max_queue_size
-        self.recovery_queue_size = recovery_queue_size
-
-        self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
-        self.gc_running = True
-        self.gc_loop_thread.start()
-
-        logger.info("Task queue created")
-
-    def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
-        if self.is_busy:
-            raise TaskQueueFullException("Task queue is full")
-        if len(self.queued_task_list) >= self.max_queue_size: # mark busy
-            self.is_busy = True
-            raise TaskQueueFullException("Task queue is full")
-
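-        # Random 16-character id sampled without replacement from letters and digits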
-        task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
-        task = TaskData(
-            task_id=task_id,
-            position=0,
-            status="queued",
-            updated_time=time.time(),
-            callback=callback,
-            current_step=0,
-            total_steps=0,
-            payload=payload
-        )
-        with self.lock:
-            self.queued_task_list.append(task)
-            task["position"] = len(self.queued_task_list) - 1
-        logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
-        # create index
-        with self.lock:
-            self.task_map[task_id] = task
-
-        self._start()
-
-        return task_id
-
-    def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
-        if task_id in self.task_map:
-            return self.task_map[task_id]
-        else:
-            return False
-
-    def delete_task_data(self, task_id: str) -> bool:
-        if task_id in self.task_map:
-            with self.lock:
-                del self.task_map[task_id]
-            return True
-        else:
-            return False
-
-    def stop(self):
-        self.running = False
-        self.gc_running = False
-        self.queued_task_list = []
-
-    def _start(self):
-        with self.lock:
-            if not self.running:
-                self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
-                self.running = True
-                self.loop_thread.start()
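-
-    # Worker loop: runs queued tasks FIFO on a dedicated thread and exits once
-    # the queue drains; _start() relaunches it on the next add_task().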
-    def _loop(self):
-        while self.running and len(self.queued_task_list) > 0:
-            current_task = self.queued_task_list[0]
-            logger.info("Start task:%s." % current_task["task_id"])
-            try:
-                current_task["status"] = "running"
-                current_task["updated_time"] = time.time()
-
-                # run task
-                res = current_task["callback"](current_task)
-
-                # call gc
-                gc.collect()
-                torch_gc()
-
-                current_task["status"] = "finished"
-                current_task["updated_time"] = time.time()
-                current_task["response"] = res
-                current_task["current_step"] = current_task["total_steps"]
-            except Exception as e:
-                current_task["status"] = "error"
-                current_task["updated_time"] = time.time()
-                current_task["response"] = e
-                gc.collect()
-                e_s = str(e)
-                if "CUDA out of memory" in e_s or \
-                        "an illegal memory access" in e_s or "CUDA" in e_s:
-                    traceback.print_exc()
-                    logger.error(str(e))
-                    logger.error("GPU error in task.")
-                    torch_gc()
-                    current_task["response"] = Exception("GPU error in task.")
-                    # os.kill(mainpid, signal.SIGTERM)
-
-            with self.lock:
-                self.queued_task_list.pop(0)
-                # reset queue position
-                for i in range(0, len(self.queued_task_list)):
-                    self.queued_task_list[i]["position"] = i
-
-            if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
-                self.is_busy = False
-
-            logger.info("Task:%s finished, queue size: %d" % (current_task["task_id"], len(self.queued_task_list)))
-
-        with self.lock:
-            self.running = False
-
-    # Background thread that removes finished tasks
-    def _gc_loop(self):
-        while self.gc_running:
-            current_time = time.time()
-            with self.lock:
-                will_remove_keys = []
-                for (key, task_data) in self.task_map.items():
-                    if task_data["status"] == "finished" or task_data["status"] == "error":
-                        if task_data["updated_time"] + 120 < current_time: # delete tasks that finished 2 minutes ago
-                            will_remove_keys.append(key)
-
-                for key in will_remove_keys:
-                    del self.task_map[key]
-
-            time.sleep(1)
-
+queue = TaskQueue(
+    logger=logger,
+    max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
+    recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6))
+)
+
+task_ws_controller = TaskDataSocketController()
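+
+# Forward queue progress updates to any websocket subscribers for that task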
+def on_task_update(task_info: Dict):
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(task_ws_controller.emit(task_info["task_id"], task_info))
+
+queue.on_update = on_task_update
 
 
 def verify_token(req: Request):
     if TOKEN:
         valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer "+TOKEN
@@ -335,13 +184,17 @@ def saveimage(image, request):
     except Exception as e:
         print("failed to save image:", e)
 
-def _generate_stream(request: GenerationRequest, task_info: TaskData = TaskData()):
+def _generate_stream(request: GenerationRequest, task_info: TaskData):
     try:
-        task_info["total_steps"] = request.steps + 1
+        task_info.update({
+            "total_steps": request.steps + 1
+        })
 
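         # Progress hook passed to the sampler; called as sampling advances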
         def _on_step(step_num, total_steps):
-            task_info["total_steps"] = total_steps + 1
-            task_info["current_step"] = step_num
+            task_info.update({
+                "total_steps": total_steps + 1,
+                "current_step": step_num
+            })
 
         if request.advanced:
             if request.n_samples > 1:
@@ -378,7 +231,9 @@ def _generate_stream(request: GenerationRequest, task_info: TaskData = TaskData(
             image = base64.b64encode(image).decode("ascii")
             images_encoded.append(image)
 
-        task_info["current_step"] += 1
+        task_info.update({
+            "current_step": task_info.current_step + 1
+        })
 
         del images
 
@@ -423,11 +278,11 @@ async def generate_stream(request: GenerationRequest, authorized: bool = Depends
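         # Poll the queue every 100 ms until the task finishes or errors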
         task_data = queue.get_task_data(task_id)
         if not task_data:
             raise Exception("Task not found")
-        if task_data["status"] == "finished":
-            images_encoded = task_data["response"]
+        if task_data.status == "finished":
+            images_encoded = task_data.response
             break
-        elif task_data["status"] == "error":
-            return {"error": str(task_data["response"])}
+        elif task_data.status == "error":
+            return {"error": str(task_data.response)}
         await asyncio.sleep(0.1)
 
     process_time = time.perf_counter() - t
@@ -493,8 +348,8 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depe
     task_data = queue.get_task_data(task_id)
     if not task_data:
         return ErrorOutput(error="Task not found.")
-    if task_data["status"] == "finished":
-        images_encoded = task_data["response"]
+    if task_data.status == "finished":
+        images_encoded = task_data.response
 
         data = ""
         ptr = 0
@@ -503,8 +358,8 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depe
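             # Frame each image as a server-sent event ("newImage") with an incrementing id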
             data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
         return Response(content=data, media_type="text/event-stream")
 
-    elif task_data["status"] == "error":
-        raise task_data["response"]
+    elif task_data.status == "error":
+        raise task_data.response
     else:
         return ErrorOutput(error="Task is not finished.")
 
@@ -522,11 +377,15 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depe
 
 def _generate(request: GenerationRequest, task_info: TaskData):
     try:
-        task_info["total_steps"] = request.steps + 1
+        task_info.update({
+            "total_steps": request.steps + 1
+        })
 
         def _on_step(step_num, total_steps):
-            task_info["total_steps"] = total_steps + 1
-            task_info["current_step"] = step_num
+            task_info.update({
+                "total_steps": total_steps + 1,
+                "current_step": step_num
+            })
 
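         # Run the sampler synchronously; _on_step reports progress into task_info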
         images = model.sample(request, callback=_on_step)
@@ -551,7 +410,9 @@ def _generate(request: GenerationRequest, task_info: TaskData):
             image = base64.b64encode(image).decode("ascii")
             images_encoded.append(image)
 
-        task_info["current_step"] += 1
+        task_info.update({
+            "current_step": task_info.current_step + 1
+        })
 
         del images
 
@@ -595,11 +456,11 @@ async def generate(request: GenerationRequest, authorized: bool = Depends(verify
         task_data = queue.get_task_data(task_id)
         if not task_data:
             raise Exception("Task not found")
-        if task_data["status"] == "finished":
-            images_encoded = task_data["response"]
+        if task_data.status == "finished":
+            images_encoded = task_data.response
             break
-        elif task_data["status"] == "error":
-            return {"error": str(task_data["response"])}
+        elif task_data.status == "error":
+            return {"error": str(task_data.response)}
         await asyncio.sleep(0.1)
 
 
@@ -660,13 +521,13 @@ async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depe
     task_data = queue.get_task_data(task_id)
     if not task_data:
         return ErrorOutput(error="Task not found.")
-    if task_data["status"] == "finished":
-        images_encoded = task_data["response"]
+    if task_data.status == "finished":
+        images_encoded = task_data.response
 
         return GenerationOutput(output=images_encoded)
 
-    elif task_data["status"] == "error":
-        raise task_data["response"]
+    elif task_data.status == "error":
+        raise task_data.response
     else:
         return ErrorOutput(error="Task is not finished.")
 
@@ -743,13 +604,44 @@ async def get_task_info(request: TaskIdRequest):
     task_data = queue.get_task_data(request.task_id)
     if task_data:
         return TaskDataOutput(
-            status=task_data["status"],
-            position=task_data["position"],
-            current_step=task_data["current_step"],
-            total_steps=task_data["total_steps"]
+            status=task_data.status,
+            position=task_data.position,
+            current_step=task_data.current_step,
+            total_steps=task_data.total_steps
         )
     else:
-        return ErrorOutput(error="Cannot find current task.")
+        return ErrorOutput(error="Cannot find current task.", code="ERR::TASK_NOT_FOUND")
+
+@app.websocket('/task-info')
+async def websocket_endpoint(websocket: WebSocket, task_id: str):
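+    # Live status channel: validate the task, then stream its updates to the client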
+    logger.info("websocket request task_id: %s" % task_id)
+    await websocket.accept()
+
+    try:
+        task_data = queue.get_task_data(task_id)
+        if not task_data:
+            await websocket.send_json({
+                "error": "ERR::TASK_NOT_FOUND",
+                "message": "Cannot find current task."
+            })
+            await websocket.close()
+            return
+
+        # Send current task info
+        await websocket.send_json(task_data.to_dict())
+        if task_data.status == "finished" or task_data.status == "error":
+            await websocket.close()
+            return
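+
+        # Register for push updates, then keep the socket open by answering client pings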
+        task_ws_controller.add(task_id, websocket)
+        while True:
+            data = await websocket.receive_text()
+            if data == "ping":
+                await websocket.send_json({
+                    "pong": "pong"
+                })
+    except WebSocketDisconnect:
+        task_ws_controller.remove(task_id, websocket)
 
 @app.get('/')
 def index():