import asyncio
import base64
import gc
import io
import json
import os
import re
import signal
import socket
import threading
import time
import traceback
from typing import Dict, List, Union

import uvicorn
from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from fastapi.staticfiles import StaticFiles
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from pydantic import BaseModel
from starlette.responses import FileResponse
from starlette.websockets import WebSocketState
from typing_extensions import TypedDict

from controller import TaskDataSocketController
from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc
from hydra_node.sanitize import sanitize_input
from task import TaskData, TaskQueue, TaskQueueFullException
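
# Hydra Node HTTP server: a FastAPI app that queues image generation requests
# on a TaskQueue, reports task progress over HTTP and websockets, and returns
# results as base64-encoded PNGs.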

genlock = threading.Lock()
running = True

MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
TOKEN = os.getenv("TOKEN", None)

print(f"Starting Hydra Node HTTP TOKEN={TOKEN}")

# Initialize model and config
model, config, model_hash = init_config_model()
try:
    embedmodel = EmbedderModel()
except Exception as e:
    print("couldn't load embed model, suggestions won't work:", e)
    embedmodel = False
logger = config.logger

# Record the main process PID (the gunicorn master if available) so shutdown
# can signal it.
try:
    with open("gunicorn.pid", "r") as f:
        config.mainpid = int(f.read())
except FileNotFoundError:
    config.mainpid = os.getpid()
mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False

queue = TaskQueue(
    logger=logger,
    max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
    recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6)),
)

task_ws_controller = TaskDataSocketController()


def on_task_update(task_info: Dict):
    # Push task progress to any websocket subscribers of this task. This runs
    # in the queue's worker context and expects an event loop to be available.
    loop = asyncio.get_event_loop()
    loop.run_until_complete(task_ws_controller.emit(task_info["task_id"], task_info))


queue.on_update = on_task_update
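
# Bearer-token auth: when the TOKEN environment variable is set, every endpoint
# that depends on verify_token requires an "Authorization: Bearer <TOKEN>" header.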

def verify_token(req: Request):
    if TOKEN:
        valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer " + TOKEN
        if not valid:
            raise HTTPException(
                status_code=401,
                detail="Unauthorized",
            )
    return True

# Initialize FastAPI
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.on_event("startup")
def startup_event():
    logger.info("FastAPI Started, serving")


@app.on_event("shutdown")
def shutdown_event():
    logger.info("FastAPI Shutdown, exiting")
    queue.stop()
    time.sleep(2)
    os.kill(mainpid, signal.SIGTERM)
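
# Request/response schemas shared by the endpoints below.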

class Masker(TypedDict):
    seed: int
    mask: str


class Tags(TypedDict):
    tag: str
    count: int
    confidence: float

class GenerationRequest(BaseModel):
    prompt: str
    image: str = None
    n_samples: int = 1
    steps: int = 50
    sampler: str = "plms"
    fixed_code: bool = False
    ddim_eta: float = 0.0
    height: int = 512
    width: int = 512
    latent_channels: int = 4
    downsampling_factor: int = 8
    scale: float = 7.0
    dynamic_threshold: float = None
    seed: int = None
    temp: float = 1.0
    top_k: int = 256
    grid_size: int = 4
    advanced: bool = False
    stage_two_seed: int = None
    strength: float = 0.69
    noise: float = 0.667
    mitigate: bool = False
    module: str = None
    masks: List[Masker] = None
    uc: str = None

class TaskIdOutput(BaseModel):
    task_id: str


class TextRequest(BaseModel):
    prompt: str


class TagOutput(BaseModel):
    tags: List[Tags]


class TextOutput(BaseModel):
    is_safe: str
    corrected_text: str


class TaskIdRequest(BaseModel):
    task_id: str


class TaskDataOutput(BaseModel):
    status: str
    position: int
    current_step: int
    total_steps: int


class GenerationOutput(BaseModel):
    output: List[str]


class ErrorOutput(BaseModel):
    error: str
    # Optional machine-readable error code (e.g. "ERR::TASK_NOT_FOUND").
    code: str = None
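
# Image generation helpers. saveimage() persists PNGs under ./images with a
# filename derived from the prompt and seed; _generate_stream() and _generate()
# run inside the task queue, report per-step progress through the TaskData
# callback, and return the results as base64-encoded PNGs with metadata.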

def saveimage(image, request):
    os.makedirs("images", exist_ok=True)

    # Build a filesystem-safe filename from the prompt and seed.
    filename = request.prompt.replace('masterpiece, best quality, ', '')
    filename = re.sub(r'[/\\<>:"|]', '', filename)
    filename = filename[:128]
    filename += f' s-{request.seed}'
    filename = os.path.join("images", filename.strip())

    # Append a numeric suffix if a file with this name already exists.
    for n in range(1000000):
        suff = '.png'
        if n:
            suff = f'-{n}.png'
        if not os.path.exists(filename + suff):
            break

    try:
        with open(filename + suff, "wb") as f:
            f.write(image)
    except Exception as e:
        print("failed to save image:", e)

def _generate_stream(request: GenerationRequest, task_info: TaskData):
    try:
        task_info.update({
            "total_steps": request.steps + 1
        })

        def _on_step(step_num, total_steps):
            # Report sampler progress back to the task queue.
            task_info.update({
                "total_steps": total_steps + 1,
                "current_step": step_num
            })

        if request.advanced:
            if request.n_samples > 1:
                return ErrorOutput(error="advanced mode does not support n_samples > 1")

            images = model.sample_two_stages(request, callback=_on_step)
        else:
            images = model.sample(request, callback=_on_step)

        logger.info("Sample finished.")

        seed = request.seed

        images_encoded = []
        for x in range(len(images)):
            # Record the per-sample seed in the metadata: consecutive samples
            # use consecutive seeds starting from the requested one.
            if seed is not None:
                request.seed = seed
                seed += 1
            comment = json.dumps({
                "steps": request.steps,
                "sampler": request.sampler,
                "seed": request.seed,
                "strength": request.strength,
                "noise": request.noise,
                "scale": request.scale,
                "uc": request.uc,
            })
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)

            # Encode the sample as a PNG in memory.
            image = Image.fromarray(images[x])
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # Base64-encode the PNG bytes for the response.
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)

        task_info.update({
            "current_step": task_info.current_step + 1
        })

        del images

        logger.info("Images encoded.")

        return images_encoded
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            # GPU-related failure: free CUDA memory and return nothing.
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e
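
# Generation endpoints come in two flavours: the blocking ones (/generate-stream,
# /generate) enqueue a task and wait for it to finish before responding, while
# /start-generate-stream and /start-generate return a task_id immediately; the
# result is then fetched via /get-generate-stream-output or /get-generate-output,
# and progress can be followed through the /task-info endpoints.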

@app.post('/generate-stream')
async def generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # apply the generation cap
        output = sanitize_input(config, request)

        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")

        # Poll the queue until the task finishes or fails.
        images_encoded = []
        while running:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data.status == "finished":
                images_encoded = task_data.response
                break
            elif task_data.status == "error":
                return {"error": str(task_data.response)}
            await asyncio.sleep(0.1)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")

        # Emit the images as a server-sent-events payload.
        data = ""
        ptr = 0
        for x in images_encoded:
            ptr += 1
            data += "event: newImage\nid: {}\ndata:{}\n\n".format(ptr, x)
        return Response(content=data, media_type="text/event-stream")

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/start-generate-stream')
async def start_generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # apply the generation cap
        output = sanitize_input(config, request)

        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")

        return TaskIdOutput(task_id=task_id)

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/get-generate-stream-output')
async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id

        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data.status == "finished":
            images_encoded = task_data.response

            # Emit the images as a server-sent-events payload.
            data = ""
            ptr = 0
            for x in images_encoded:
                ptr += 1
                data += "event: newImage\nid: {}\ndata:{}\n\n".format(ptr, x)
            return Response(content=data, media_type="text/event-stream")

        elif task_data.status == "error":
            raise task_data.response
        else:
            return ErrorOutput(error="Task is not finished.")

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
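
# Non-streaming variant of the generation worker: the same pipeline as
# _generate_stream, but without advanced (two-stage) sampling or per-sample
# seed bookkeeping.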

def _generate(request: GenerationRequest, task_info: TaskData):
    try:
        task_info.update({
            "total_steps": request.steps + 1
        })

        def _on_step(step_num, total_steps):
            # Report sampler progress back to the task queue.
            task_info.update({
                "total_steps": total_steps + 1,
                "current_step": step_num
            })

        images = model.sample(request, callback=_on_step)

        images_encoded = []
        for x in range(len(images)):
            comment = json.dumps({
                "steps": request.steps,
                "sampler": request.sampler,
                "seed": request.seed,
                "strength": request.strength,
                "noise": request.noise,
                "scale": request.scale,
                "uc": request.uc,
            })
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)

            # Encode the sample as a PNG in memory.
            image = Image.fromarray(images[x])
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # Base64-encode the PNG bytes for the response.
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)

        task_info.update({
            "current_step": task_info.current_step + 1
        })

        del images

        return images_encoded

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            # GPU-related failure: free CUDA memory and return nothing.
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e

@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
async def generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # apply the generation cap
        output = sanitize_input(config, request)

        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")

        # Poll the queue until the task finishes or fails.
        images_encoded = []
        while running:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data.status == "finished":
                images_encoded = task_data.response
                break
            elif task_data.status == "error":
                return {"error": str(task_data.response)}
            await asyncio.sleep(0.1)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")

        return GenerationOutput(output=images_encoded)

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/start-generate', response_model=Union[TaskIdOutput, ErrorOutput])
async def start_generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # apply the generation cap
        output = sanitize_input(config, request)

        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")

        return TaskIdOutput(task_id=task_id)

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/get-generate-output')
async def generate_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id

        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data.status == "finished":
            images_encoded = task_data.response

            return GenerationOutput(output=images_encoded)

        elif task_data.status == "error":
            raise task_data.response
        else:
            return ErrorOutput(error="Task is not finished.")

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/generate-text', response_model=Union[TextOutput, ErrorOutput])
def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        is_safe, corrected_text = model.sample(request)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return TextOutput(is_safe=is_safe, corrected_text=corrected_text)

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))

@app.get('/predict-tags', response_model=Union[TagOutput, ErrorOutput])
async def predict_tags(prompt="", authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        # The embedder is optional; if it failed to load at startup, tag
        # suggestions are unavailable.
        if not embedmodel:
            return ErrorOutput(error="Tag suggestion model is not loaded.")

        tags = embedmodel.get_top_k(prompt)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))
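
# Task progress can be polled over HTTP (POST /task-info) or streamed over a
# websocket at the same path; the websocket receives the updates emitted by
# on_task_update above.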

@app.post('/task-info', response_model=Union[TaskDataOutput, ErrorOutput])
async def get_task_info(request: TaskIdRequest):
    task_data = queue.get_task_data(request.task_id)
    if task_data:
        return TaskDataOutput(
            status=task_data.status,
            position=task_data.position,
            current_step=task_data.current_step,
            total_steps=task_data.total_steps
        )
    else:
        return ErrorOutput(error="Cannot find current task.", code="ERR::TASK_NOT_FOUND")

@app.websocket('/task-info')
async def websocket_endpoint(websocket: WebSocket, task_id: str):
    await websocket.accept()

    try:
        task_data = queue.get_task_data(task_id)
        if not task_data:
            await websocket.send_json({
                "error": "ERR::TASK_NOT_FOUND",
                "message": "Cannot find current task."
            })
            await websocket.close()
            return

        # Send the current task info immediately; if the task is already done
        # there is nothing further to stream.
        await websocket.send_json(task_data.to_dict())
        if task_data.status in ("finished", "error"):
            await websocket.close()
            return

        # Register the socket so queue updates are pushed to it, then keep the
        # connection open and answer ping messages until the client disconnects.
        task_ws_controller.add(task_id, websocket)
        while running and websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_json({
                    "pong": "pong"
                })

        if websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
            await websocket.close()
    except WebSocketDisconnect:
        task_ws_controller.remove(task_id, websocket)
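
# Static front-end and entrypoint: index.html is served at the root, all other
# files from ./static, and start() runs the app under uvicorn on port 4315.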

@app.get('/')
def index():
    return FileResponse('static/index.html')


app.mount("/", StaticFiles(directory="static/"), name="static")


def start():
    uvicorn.run("main:app", host="0.0.0.0", port=4315, log_level="info")


if __name__ == "__main__":
    start()