import os
import random
import re
import string
import sys
import asyncio
from types import LambdaType

from fastapi import FastAPI, Request, Depends
from pydantic import BaseModel
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse
from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc
from hydra_node.sanitize import sanitize_input
from typing import Optional, List, Any, Union, Dict
from typing_extensions import TypedDict
import socket
import uvicorn
import time
import gc
import io
import signal
import base64
import traceback
import threading
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json

genlock = threading.Lock()

MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
TOKEN = os.getenv("TOKEN", None)
print(f"Starting Hydra Node HTTP TOKEN={TOKEN}")

# Initialize model and config
model, config, model_hash = init_config_model()
try:
    embedmodel = EmbedderModel()
except Exception as e:
    print("couldn't load embed model, suggestions won't work:", e)
    embedmodel = False
logger = config.logger
try:
    config.mainpid = int(open("gunicorn.pid", "r").read())
except FileNotFoundError:
    config.mainpid = os.getpid()
mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False

class TaskQueueFullException(Exception):
    pass


class TaskData(TypedDict):
    task_id: str
    position: int
    status: str
    updated_time: float
    callback: LambdaType
    current_step: int
    total_steps: int
    payload: Any
    response: Any

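# TaskQueue runs generation jobs one at a time on a single worker thread.
# add_task refuses new work once max_queue_size tasks are queued, and the queue
# stays "busy" until the backlog drains below recovery_queue_size; a separate
# GC thread drops finished or failed task records after two minutes.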
class TaskQueue:
    max_queue_size = 10
    recovery_queue_size = 6
    is_busy = False

    loop_thread = None
    gc_loop_thread = None
    running = False
    gc_running = False

    lock = threading.Lock()

    queued_task_list: List[TaskData] = []
    task_map: Dict[str, TaskData] = {}

    def __init__(self, max_queue_size=10, recovery_queue_size=6) -> None:
        self.max_queue_size = max_queue_size
        self.recovery_queue_size = recovery_queue_size

        self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
        self.gc_running = True
        self.gc_loop_thread.start()

        logger.info("Task queue created")

    def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
        if self.is_busy:
            raise TaskQueueFullException("Task queue is full")
        if len(self.queued_task_list) >= self.max_queue_size:  # mark busy
            self.is_busy = True
            raise TaskQueueFullException("Task queue is full")

        task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
        task = TaskData(
            task_id=task_id,
            position=0,
            status="queued",
            updated_time=time.time(),
            callback=callback,
            current_step=0,
            total_steps=0,
            payload=payload
        )
        with self.lock:
            self.queued_task_list.append(task)
            task["position"] = len(self.queued_task_list) - 1
            logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
        # create index
        with self.lock:
            self.task_map[task_id] = task

        self._start()

        return task_id

    def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
        if task_id in self.task_map:
            return self.task_map[task_id]
        else:
            return False

    def delete_task_data(self, task_id: str) -> bool:
        if task_id in self.task_map:
            with self.lock:
                del self.task_map[task_id]
            return True
        else:
            return False

    def stop(self):
        self.running = False
        self.gc_running = False
        self.queued_task_list = []

    def _start(self):
        with self.lock:
            if not self.running:
                self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
                self.running = True
                self.loop_thread.start()

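    # _loop is the worker: it takes the task at the head of the queue, runs its
    # callback, stores the result (or the exception) on the task record, then
    # renumbers the remaining tasks' positions so /task-info can report queue
    # placement.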
    def _loop(self):
        while self.running and len(self.queued_task_list) > 0:
            current_task = self.queued_task_list[0]
            logger.info("Start task:%s." % current_task["task_id"])
            try:
                current_task["status"] = "running"
                current_task["updated_time"] = time.time()

                # run task
                res = current_task["callback"](current_task)

                # call gc
                gc.collect()
                torch_gc()

                current_task["status"] = "finished"
                current_task["updated_time"] = time.time()
                current_task["response"] = res
                current_task["current_step"] = current_task["total_steps"]
            except Exception as e:
                current_task["status"] = "error"
                current_task["updated_time"] = time.time()
                current_task["response"] = e
                gc.collect()
                e_s = str(e)
                if "CUDA out of memory" in e_s or \
                        "an illegal memory access" in e_s or "CUDA" in e_s:
                    traceback.print_exc()
                    logger.error(str(e))
                    logger.error("GPU error in task.")
                    torch_gc()
                    current_task["response"] = Exception("GPU error in task.")
                    # os.kill(mainpid, signal.SIGTERM)

            with self.lock:
                self.queued_task_list.pop(0)
                # reset queue positions
                for i in range(0, len(self.queued_task_list)):
                    self.queued_task_list[i]["position"] = i

            if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size:  # mark not busy
                self.is_busy = False

            logger.info("Task:%s finished, queue size: %d" % (current_task["task_id"], len(self.queued_task_list)))

        with self.lock:
            self.running = False

    # GC loop that removes finished tasks
    def _gc_loop(self):
        while self.gc_running:
            current_time = time.time()
            with self.lock:
                will_remove_keys = []
                for (key, task_data) in self.task_map.items():
                    if task_data["status"] == "finished" or task_data["status"] == "error":
                        if task_data["updated_time"] + 120 < current_time:  # delete tasks that finished more than 2 minutes ago
                            will_remove_keys.append(key)

                for key in will_remove_keys:
                    del self.task_map[key]

            time.sleep(1)

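# A sketch of how the endpoints below drive the queue (names mirror the
# handlers in this file): submit with
#     task_id = queue.add_task(lambda t: _generate(request, t), request)
# then poll queue.get_task_data(task_id) until its "status" becomes
# "finished" (read "response") or "error".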
queue = TaskQueue(
    max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
    recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6))
)


def verify_token(req: Request):
    if TOKEN:
        valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer " + TOKEN
        if not valid:
            raise HTTPException(
                status_code=401,
                detail="Unauthorized"
            )
    return True

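# verify_token is wired into the endpoints below as a FastAPI dependency. When
# the TOKEN environment variable is set, requests must carry
# "Authorization: Bearer <TOKEN>"; otherwise every request is accepted.
# A hypothetical client call (port taken from start() at the bottom of this file):
#     curl -H "Authorization: Bearer $TOKEN" -H "Content-Type: application/json" \
#          -d '{"prompt": "a painting"}' http://localhost:4315/generate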
# Initialize fastapi
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
def startup_event():
    logger.info("FastAPI Started, serving")


@app.on_event("shutdown")
def shutdown_event():
    logger.info("FastAPI Shutdown, exiting")
    queue.stop()

class Masker(TypedDict):
    seed: int
    mask: str


class Tags(TypedDict):
    tag: str
    count: int
    confidence: float

class GenerationRequest(BaseModel):
    prompt: str
    image: str = None
    n_samples: int = 1
    steps: int = 50
    sampler: str = "plms"
    fixed_code: bool = False
    ddim_eta: float = 0.0
    height: int = 512
    width: int = 512
    latent_channels: int = 4
    downsampling_factor: int = 8
    scale: float = 7.0
    dynamic_threshold: float = None
    seed: int = None
    temp: float = 1.0
    top_k: int = 256
    grid_size: int = 4
    advanced: bool = False
    stage_two_seed: int = None
    strength: float = 0.69
    noise: float = 0.667
    mitigate: bool = False
    module: str = None
    masks: List[Masker] = None
    uc: str = None


class TaskIdOutput(BaseModel):
    task_id: str


class TextRequest(BaseModel):
    prompt: str


class TagOutput(BaseModel):
    tags: List[Tags]


class TextOutput(BaseModel):
    is_safe: str
    corrected_text: str


class TaskIdRequest(BaseModel):
    task_id: str


class TaskDataOutput(BaseModel):
    status: str
    position: int
    current_step: int
    total_steps: int


class GenerationOutput(BaseModel):
    output: List[str]


class ErrorOutput(BaseModel):
    error: str

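# saveimage writes the PNG bytes under ./images, deriving the filename from the
# prompt (quality-tag prefix and filesystem-unsafe characters stripped,
# truncated to 128 characters) plus the seed, and appending -1, -2, ... until
# an unused path is found.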
def saveimage(image, request):
    os.makedirs("images", exist_ok=True)

    filename = request.prompt.replace('masterpiece, best quality, ', '')
    filename = re.sub(r'[/\\<>:"|]', '', filename)
    filename = filename[:128]
    filename += f' s-{request.seed}'
    filename = os.path.join("images", filename.strip())

    for n in range(1000000):
        suff = '.png'
        if n:
            suff = f'-{n}.png'
        if not os.path.exists(filename + suff):
            break

    try:
        with open(filename + suff, "wb") as f:
            f.write(image)
    except Exception as e:
        print("failed to save image:", e)

def _generate_stream(request: GenerationRequest, task_info: TaskData = TaskData()):
    try:
        task_info["total_steps"] = request.steps + 1

        def _on_step(step_num, total_steps):
            task_info["total_steps"] = total_steps + 1
            task_info["current_step"] = step_num

        if request.advanced:
            if request.n_samples > 1:
                return ErrorOutput(error="advanced mode does not support n_samples > 1")

            images = model.sample_two_stages(request, callback=_on_step)
        else:
            images = model.sample(request, callback=_on_step)

        logger.info("Sample finished.")

        seed = request.seed

        images_encoded = []
        for x in range(len(images)):
            if seed is not None:
                request.seed = seed
                seed += 1
            comment = json.dumps({"steps": request.steps, "sampler": request.sampler, "seed": request.seed, "strength": request.strength, "noise": request.noise, "scale": request.scale, "uc": request.uc})
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)
            image = Image.fromarray(images[x])
            # save pillow image with bytesIO
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # get base64 of image
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)

            task_info["current_step"] += 1

        del images

        logger.info("Images encoded.")

        return images_encoded
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e

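# The streaming endpoints below return the finished images as a single
# text/event-stream body: one "event: newImage" record per image, with a
# 1-based id and the base64-encoded PNG in the data field.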
@app.post('/generate-stream')
async def generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the sample-count limit
        output = sanitize_input(config, request)

        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err

        images_encoded = []
        while True:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data["status"] == "finished":
                images_encoded = task_data["response"]
                break
            elif task_data["status"] == "error":
                return {"error": str(task_data["response"])}
            await asyncio.sleep(0.1)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        data = ""
        ptr = 0
        for x in images_encoded:
            ptr += 1
            data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
        return Response(content=data, media_type="text/event-stream")
        # return GenerationOutput(output=images)

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/start-generate-stream')
async def start_generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the sample-count limit
        output = sanitize_input(config, request)

        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err

        return TaskIdOutput(task_id=task_id)

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/get-generate-stream-output')
async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id

        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data["status"] == "finished":
            images_encoded = task_data["response"]

            data = ""
            ptr = 0
            for x in images_encoded:
                ptr += 1
                data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
            return Response(content=data, media_type="text/event-stream")

        elif task_data["status"] == "error":
            raise task_data["response"]
        else:
            return ErrorOutput(error="Task is not finished.")

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

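# _generate mirrors _generate_stream but does not advance the seed between
# samples; both return a list of base64-encoded PNGs that the endpoints below
# either wrap in SSE records or return as a JSON GenerationOutput.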
def _generate(request: GenerationRequest, task_info: TaskData):
    try:
        task_info["total_steps"] = request.steps + 1

        def _on_step(step_num, total_steps):
            task_info["total_steps"] = total_steps + 1
            task_info["current_step"] = step_num

        images = model.sample(request, callback=_on_step)

        images_encoded = []
        for x in range(len(images)):
            comment = json.dumps({"steps": request.steps, "sampler": request.sampler, "seed": request.seed, "strength": request.strength, "noise": request.noise, "scale": request.scale, "uc": request.uc})
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)
            image = Image.fromarray(images[x])
            # save pillow image with bytesIO
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # get base64 of image
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)

            task_info["current_step"] += 1

        del images

        return images_encoded

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e

@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
async def generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the sample-count limit
        output = sanitize_input(config, request)

        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err

        images_encoded = []
        while True:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data["status"] == "finished":
                images_encoded = task_data["response"]
                break
            elif task_data["status"] == "error":
                return {"error": str(task_data["response"])}
            await asyncio.sleep(0.1)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return GenerationOutput(output=images_encoded)

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/start-generate', response_model=Union[TaskIdOutput, ErrorOutput])
async def start_generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the sample-count limit
        output = sanitize_input(config, request)

        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err

        return TaskIdOutput(task_id=task_id)

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/get-generate-output')
async def generate_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id

        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data["status"] == "finished":
            images_encoded = task_data["response"]

            return GenerationOutput(output=images_encoded)

        elif task_data["status"] == "error":
            raise task_data["response"]
        else:
            return ErrorOutput(error="Task is not finished.")

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/generate-text', response_model=Union[TextOutput, ErrorOutput])
def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        is_safe, corrected_text = model.sample(request)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return TextOutput(is_safe=is_safe, corrected_text=corrected_text)

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))

@app.get('/predict-tags', response_model=Union[TagOutput, ErrorOutput])
async def predict_tags(prompt="", authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        # output = sanitize_input(config, request)
        # if output[0]:
        #     request = output[1]
        # else:
        #     return ErrorOutput(error=output[1])

        tags = embedmodel.get_top_k(prompt)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])

    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))

@app.post('/task-info', response_model=Union[TaskDataOutput, ErrorOutput])
async def get_task_info(request: TaskIdRequest):
    task_data = queue.get_task_data(request.task_id)
    if task_data:
        return TaskDataOutput(
            status=task_data["status"],
            position=task_data["position"],
            current_step=task_data["current_step"],
            total_steps=task_data["total_steps"]
        )
    else:
        return ErrorOutput(error="Cannot find current task.")

@app.get('/')
def index():
    return FileResponse('static/index.html')


app.mount("/", StaticFiles(directory="static/"), name="static")


def start():
    uvicorn.run("main:app", host="0.0.0.0", port=4315, log_level="info")


if __name__ == "__main__":
    start()