|
|
|
@ -2,6 +2,7 @@ from asyncore import socket_map
|
|
|
|
|
from concurrent.futures import thread
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import signal
|
|
|
|
|
from types import LambdaType
|
|
|
|
|
from async_timeout import asyncio
|
|
|
|
|
from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
|
|
|
|
@ -12,6 +13,7 @@ from fastapi.exceptions import HTTPException
|
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
|
from starlette.responses import FileResponse
|
|
|
|
|
from starlette.websockets import WebSocketState
|
|
|
|
|
from controller import TaskDataSocketController
|
|
|
|
|
from hydra_node.config import init_config_model
|
|
|
|
|
from hydra_node.models import EmbedderModel, torch_gc
|
|
|
|
@ -32,6 +34,7 @@ from PIL.PngImagePlugin import PngInfo
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
genlock = threading.Lock()
|
|
|
|
|
running = True
|
|
|
|
|
|
|
|
|
|
MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
|
|
|
|
|
TOKEN = os.getenv("TOKEN", None)
|
|
|
|
@ -97,6 +100,8 @@ def startup_event():
|
|
|
|
|
def shutdown_event():
|
|
|
|
|
logger.info("FastAPI Shutdown, exiting")
|
|
|
|
|
queue.stop()
|
|
|
|
|
time.sleep(2)
|
|
|
|
|
os.kill(mainpid, signal.SIGTERM)
|
|
|
|
|
|
|
|
|
|
class Masker(TypedDict):
|
|
|
|
|
seed: int
|
|
|
|
@ -274,7 +279,7 @@ async def generate_stream(request: GenerationRequest, authorized: bool = Depends
|
|
|
|
|
raise err
|
|
|
|
|
|
|
|
|
|
images_encoded = []
|
|
|
|
|
while True:
|
|
|
|
|
while running:
|
|
|
|
|
task_data = queue.get_task_data(task_id)
|
|
|
|
|
if not task_data:
|
|
|
|
|
raise Exception("Task not found")
|
|
|
|
@ -452,7 +457,7 @@ async def generate(request: GenerationRequest, authorized: bool = Depends(verify
|
|
|
|
|
raise err
|
|
|
|
|
|
|
|
|
|
images_encoded = []
|
|
|
|
|
while True:
|
|
|
|
|
while running:
|
|
|
|
|
task_data = queue.get_task_data(task_id)
|
|
|
|
|
if not task_data:
|
|
|
|
|
raise Exception("Task not found")
|
|
|
|
@ -614,7 +619,6 @@ async def get_task_info(request: TaskIdRequest):
|
|
|
|
|
|
|
|
|
|
@app.websocket('/task-info')
|
|
|
|
|
async def websocket_endpoint(websocket: WebSocket, task_id: str):
|
|
|
|
|
logger.info("websocket request task_id: %s" % task_id)
|
|
|
|
|
await websocket.accept()
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
@ -634,12 +638,15 @@ async def websocket_endpoint(websocket: WebSocket, task_id: str):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
task_ws_controller.add(task_id, websocket)
|
|
|
|
|
while True:
|
|
|
|
|
while running and websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
|
|
|
|
|
data = await websocket.receive_text()
|
|
|
|
|
if data == "ping":
|
|
|
|
|
await websocket.send_json({
|
|
|
|
|
"pong": "pong"
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
|
|
|
|
|
await websocket.close()
|
|
|
|
|
except WebSocketDisconnect:
|
|
|
|
|
task_ws_controller.remove(task_id, websocket)
|
|
|
|
|
|
|
|
|
|