import os
import random
import re
import string
import sys
import asyncio
from types import LambdaType
from fastapi import FastAPI, Request, Depends
from pydantic import BaseModel
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse
from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc
from typing import Optional, List, Any
from typing_extensions import TypedDict
import socket
from hydra_node.sanitize import sanitize_input
import uvicorn
from typing import Union, Dict
import time
import gc
import io
import signal
import base64
import traceback
import threading
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json
genlock = threading.Lock()
MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
TOKEN = os.getenv("TOKEN", None)
print(f"Starting Hydra Node HTTP TOKEN={TOKEN}")
#Initialize model and config
model, config, model_hash = init_config_model()
try:
    embedmodel = EmbedderModel()
except Exception as e:
    print("couldn't load embed model, suggestions won't work:", e)
    embedmodel = False
logger = config.logger
try:
    with open("gunicorn.pid", "r") as pidfile:
        config.mainpid = int(pidfile.read())
except FileNotFoundError:
    config.mainpid = os.getpid()
mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False
class TaskQueueFullException(Exception):
    pass

class TaskData(TypedDict, total=False):
    # total=False: "response" is only present once a task has finished or errored
    task_id: str
    position: int
    status: str
    updated_time: float
    callback: LambdaType
    current_step: int
    total_steps: int
    payload: Any
    response: Any
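# TaskQueue runs generation jobs one at a time on a worker thread:
#  - add_task() appends a TaskData entry, indexes it by a random 16-char id and
#    (re)starts the worker loop; it raises TaskQueueFullException once the queue
#    reaches max_queue_size and stays "busy" until it drains below recovery_queue_size.
#  - _loop() pops tasks FIFO, invokes their callback and records the result (or the
#    exception) on the task entry.
#  - _gc_loop() runs on a separate thread and drops finished/errored tasks two
#    minutes after their last update.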
class TaskQueue:
    max_queue_size = 10
    recovery_queue_size = 6
    is_busy = False
    loop_thread = None
    gc_loop_thread = None
    running = False
    gc_running = False
    lock = threading.Lock()
    queued_task_list: List[TaskData] = []
    task_map: Dict[str, TaskData] = {}

    def __init__(self, max_queue_size = 10, recovery_queue_size = 6) -> None:
        self.max_queue_size = max_queue_size
        self.recovery_queue_size = recovery_queue_size
        self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
        self.gc_running = True
        self.gc_loop_thread.start()
        logger.info("Task queue created")

    def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
        if self.is_busy:
            raise TaskQueueFullException("Task queue is full")
        if len(self.queued_task_list) >= self.max_queue_size:  # mark busy
            self.is_busy = True
            raise TaskQueueFullException("Task queue is full")
        task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
        task = TaskData(
            task_id=task_id,
            position=0,
            status="queued",
            updated_time=time.time(),
            callback=callback,
            current_step=0,
            total_steps=0,
            payload=payload
        )
        with self.lock:
            self.queued_task_list.append(task)
            task["position"] = len(self.queued_task_list) - 1
        logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
        # create index
        with self.lock:
            self.task_map[task_id] = task
        self._start()
        return task_id

    def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
        if task_id in self.task_map:
            return self.task_map[task_id]
        else:
            return False

    def delete_task_data(self, task_id: str) -> bool:
        if task_id in self.task_map:
            with self.lock:
                del self.task_map[task_id]
            return True
        else:
            return False

    def stop(self):
        self.running = False
        self.gc_running = False
        self.queued_task_list = []

    def _start(self):
        with self.lock:
            if not self.running:
                self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
                self.running = True
                self.loop_thread.start()

    def _loop(self):
        while self.running and len(self.queued_task_list) > 0:
            current_task = self.queued_task_list[0]
            logger.info("Start task %s." % current_task["task_id"])
            try:
                current_task["status"] = "running"
                current_task["updated_time"] = time.time()
                # run task
                res = current_task["callback"](current_task)
                # call gc
                gc.collect()
                torch_gc()
                current_task["status"] = "finished"
                current_task["updated_time"] = time.time()
                current_task["response"] = res
                current_task["current_step"] = current_task["total_steps"]
            except Exception as e:
                current_task["status"] = "error"
                current_task["updated_time"] = time.time()
                current_task["response"] = e
                gc.collect()
                e_s = str(e)
                if "CUDA out of memory" in e_s or \
                        "an illegal memory access" in e_s or "CUDA" in e_s:
                    traceback.print_exc()
                    logger.error(str(e))
                    logger.error("GPU error in task.")
                    torch_gc()
                    current_task["response"] = Exception("GPU error in task.")
                    # os.kill(mainpid, signal.SIGTERM)
            with self.lock:
                self.queued_task_list.pop(0)
                # reset queue positions
                for i in range(0, len(self.queued_task_list)):
                    self.queued_task_list[i]["position"] = i
                if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size:  # mark not busy
                    self.is_busy = False
            logger.info("Task %s finished, queue size: %d" % (current_task["task_id"], len(self.queued_task_list)))
        with self.lock:
            self.running = False

    # Background task to remove finished tasks
    def _gc_loop(self):
        while self.gc_running:
            current_time = time.time()
            with self.lock:
                will_remove_keys = []
                for (key, task_data) in self.task_map.items():
                    if task_data["status"] == "finished" or task_data["status"] == "error":
                        if task_data["updated_time"] + 120 < current_time:  # delete tasks which finished 2 minutes ago
                            will_remove_keys.append(key)
                for key in will_remove_keys:
                    del self.task_map[key]
            time.sleep(1)
queue = TaskQueue(
    max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
    recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6))
)
def verify_token(req: Request):
    if TOKEN:
        valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer " + TOKEN
        if not valid:
            raise HTTPException(
                status_code=401,
                detail="Unauthorized"
            )
    return True
#Initialize fastapi
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
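# Endpoint overview:
#   POST /generate-stream          - queue a job, wait for it, return images as an SSE body
#   POST /generate                 - same, but returns a JSON GenerationOutput
#   POST /start-generate-stream    - queue a job and return its task_id immediately
#   POST /start-generate           - same, for the JSON variant
#   POST /get-generate-stream-output, /get-generate-output - fetch a finished task's images
#   POST /task-info                - poll a task's status and progress
#   POST /generate-text, GET /predict-tags - auxiliary text and tag-suggestion endpoints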
@app.on_event("startup")
def startup_event():
    logger.info("FastAPI Started, serving")

@app.on_event("shutdown")
def shutdown_event():
    logger.info("FastAPI Shutdown, exiting")
    queue.stop()
class Masker(TypedDict):
    seed: int
    mask: str

class Tags(TypedDict):
    tag: str
    count: int
    confidence: float
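# GenerationRequest mirrors the sampler parameters handed straight to model.sample();
# "advanced" switches to the two-stage sampler, which currently only supports
# n_samples == 1.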
class GenerationRequest(BaseModel):
    prompt: str
    image: str = None
    n_samples: int = 1
    steps: int = 50
    sampler: str = "plms"
    fixed_code: bool = False
    ddim_eta: float = 0.0
    height: int = 512
    width: int = 512
    latent_channels: int = 4
    downsampling_factor: int = 8
    scale: float = 7.0
    dynamic_threshold: float = None
    seed: int = None
    temp: float = 1.0
    top_k: int = 256
    grid_size: int = 4
    advanced: bool = False
    stage_two_seed: int = None
    strength: float = 0.69
    noise: float = 0.667
    mitigate: bool = False
    module: str = None
    masks: List[Masker] = None
    uc: str = None
class TaskIdOutput(BaseModel):
    task_id: str

class TextRequest(BaseModel):
    prompt: str

class TagOutput(BaseModel):
    tags: List[Tags]

class TextOutput(BaseModel):
    is_safe: str
    corrected_text: str

class TaskIdRequest(BaseModel):
    task_id: str

class TaskDataOutput(BaseModel):
    status: str
    position: int
    current_step: int
    total_steps: int

class GenerationOutput(BaseModel):
    output: List[str]

class ErrorOutput(BaseModel):
    error: str
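# saveimage() writes the PNG bytes under ./images, naming files after the prompt
# (with the "masterpiece, best quality, " prefix and filesystem-unsafe characters
# stripped, truncated to 128 chars) plus the seed, adding a -N suffix to avoid
# overwriting existing files.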
def saveimage(image, request):
    os.makedirs("images", exist_ok=True)
    filename = request.prompt.replace('masterpiece, best quality, ', '')
    filename = re.sub(r'[/\\<>:"|]', '', filename)
    filename = filename[:128]
    filename += f' s-{request.seed}'
    filename = os.path.join("images", filename.strip())
    for n in range(1000000):
        suff = '.png'
        if n:
            suff = f'-{n}.png'
        if not os.path.exists(filename + suff):
            break
    try:
        with open(filename + suff, "wb") as f:
            f.write(image)
    except Exception as e:
        print("failed to save image:", e)
def _generate_stream(request: GenerationRequest, task_info: TaskData = TaskData()):
    try:
        task_info["total_steps"] = request.steps + 1

        def _on_step(step_num, total_steps):
            task_info["total_steps"] = total_steps + 1
            task_info["current_step"] = step_num

        if request.advanced:
            if request.n_samples > 1:
                return ErrorOutput(error="advanced mode does not support n_samples > 1")
            images = model.sample_two_stages(request, callback=_on_step)
        else:
            images = model.sample(request, callback=_on_step)
        logger.info("Sample finished.")
        seed = request.seed
        images_encoded = []
        for x in range(len(images)):
            if seed is not None:
                request.seed = seed
                seed += 1
            comment = json.dumps({"steps": request.steps, "sampler": request.sampler, "seed": request.seed, "strength": request.strength, "noise": request.noise, "scale": request.scale, "uc": request.uc})
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)
            image = Image.fromarray(images[x])
            # save pillow image with BytesIO
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # get base64 of image
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)
            task_info["current_step"] += 1
        del images
        logger.info("Images encoded.")
        return images_encoded
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e
@app.post('/generate-stream')
async def generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # cap the number of samples per request
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda task: _generate_stream(request, task), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please try again in a few minutes.")
        except Exception as err:
            raise err
        images_encoded = []
        while True:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data["status"] == "finished":
                images_encoded = task_data["response"]
                break
            elif task_data["status"] == "error":
                return {"error": str(task_data["response"])}
            await asyncio.sleep(0.1)
        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        data = ""
        ptr = 0
        for x in images_encoded:
            ptr += 1
            data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
        return Response(content=data, media_type="text/event-stream")
        # return GenerationOutput(output=images)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
@app.post('/start-generate-stream')
async def start_generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # cap the number of samples per request
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda task: _generate_stream(request, task), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please try again in a few minutes.")
        except Exception as err:
            raise err
        return TaskIdOutput(task_id=task_id)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
@app.post('/get-generate-stream-output')
async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id
        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data["status"] == "finished":
            images_encoded = task_data["response"]
            data = ""
            ptr = 0
            for x in images_encoded:
                ptr += 1
                data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
            return Response(content=data, media_type="text/event-stream")
        elif task_data["status"] == "error":
            raise task_data["response"]
        else:
            return ErrorOutput(error="Task is not finished.")
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
def _generate(request: GenerationRequest, task_info: TaskData):
    try:
        task_info["total_steps"] = request.steps + 1

        def _on_step(step_num, total_steps):
            task_info["total_steps"] = total_steps + 1
            task_info["current_step"] = step_num

        images = model.sample(request, callback=_on_step)
        images_encoded = []
        for x in range(len(images)):
            comment = json.dumps({"steps": request.steps, "sampler": request.sampler, "seed": request.seed, "strength": request.strength, "noise": request.noise, "scale": request.scale, "uc": request.uc})
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)
            image = Image.fromarray(images[x])
            # save pillow image with BytesIO
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # get base64 of image
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)
            task_info["current_step"] += 1
        del images
        return images_encoded
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e
@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
async def generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # cap the number of samples per request
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda task: _generate(request, task), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please try again in a few minutes.")
        except Exception as err:
            raise err
        images_encoded = []
        while True:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data["status"] == "finished":
                images_encoded = task_data["response"]
                break
            elif task_data["status"] == "error":
                return {"error": str(task_data["response"])}
            await asyncio.sleep(0.1)
        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return GenerationOutput(output=images_encoded)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
@app.post('/start-generate', response_model=Union[TaskIdOutput, ErrorOutput])
async def start_generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # cap the number of samples per request
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda task: _generate(request, task), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please try again in a few minutes.")
        except Exception as err:
            raise err
        return TaskIdOutput(task_id=task_id)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
@app.post('/get-generate-output')
async def get_generate_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id
        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data["status"] == "finished":
            images_encoded = task_data["response"]
            return GenerationOutput(output=images_encoded)
        elif task_data["status"] == "error":
            raise task_data["response"]
        else:
            return ErrorOutput(error="Task is not finished.")
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
@app.post('/generate-text', response_model=Union[TextOutput, ErrorOutput])
def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        is_safe, corrected_text = model.sample(request)
        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return TextOutput(is_safe=is_safe, corrected_text=corrected_text)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))
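# Tag suggestions depend on the optional EmbedderModel; if it failed to load at
# startup (embedmodel is False), this endpoint returns an ErrorOutput instead.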
@app.get('/predict-tags', response_model=Union[TagOutput, ErrorOutput])
async def predict_tags(prompt="", authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        # output = sanitize_input(config, request)
        # if output[0]:
        #     request = output[1]
        # else:
        #     return ErrorOutput(error=output[1])
        tags = embedmodel.get_top_k(prompt)
        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))
@app.post('/task-info', response_model=Union[TaskDataOutput, ErrorOutput])
async def get_task_info(request: TaskIdRequest):
    task_data = queue.get_task_data(request.task_id)
    if task_data:
        return TaskDataOutput(
            status=task_data["status"],
            position=task_data["position"],
            current_step=task_data["current_step"],
            total_steps=task_data["total_steps"]
        )
    else:
        return ErrorOutput(error="Cannot find current task.")
@app.get('/')
def index():
    return FileResponse('static/index.html')

app.mount("/", StaticFiles(directory="static/"), name="static")

def start():
    uvicorn.run("main:app", host="0.0.0.0", port=4315, log_level="info")

if __name__ == "__main__":
    start()
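# Run directly (sketch): `python main.py` serves on 0.0.0.0:4315 via uvicorn. When
# run under gunicorn instead, the worker reads the master pid from gunicorn.pid so
# fatal GPU errors could signal it (the os.kill calls are currently commented out).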