import os
import random
import re
import string
import asyncio
from fastapi import FastAPI, Request, Depends
from pydantic import BaseModel
from fastapi.responses import Response
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse
from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc
from typing import Optional, List, Any, Union, Dict, Callable
from typing_extensions import TypedDict
import socket
from hydra_node.sanitize import sanitize_input
import uvicorn
import time
import gc
import io
import signal
import base64
import traceback
import threading
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json

genlock = threading.Lock()
MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
TOKEN = os.getenv("TOKEN", None)
print(f"Starting Hydra Node HTTP TOKEN={TOKEN}")

# Initialize model and config
model, config, model_hash = init_config_model()
try:
    embedmodel = EmbedderModel()
except Exception as e:
    print("couldn't load embed model, suggestions won't work:", e)
    embedmodel = False
logger = config.logger
try:
    config.mainpid = int(open("gunicorn.pid", "r").read())
except FileNotFoundError:
    config.mainpid = os.getpid()
mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False


class TaskQueueFullException(Exception):
    pass


class TaskData(TypedDict):
    task_id: str
    position: int
    status: str
    updated_time: float
    callback: Callable
    current_step: int
    total_steps: int
    payload: Any
    response: Any


class TaskQueue:
    """Simple FIFO queue that runs generation tasks on a single background thread."""

    max_queue_size = 10
    recovery_queue_size = 6
    is_busy = False
    loop_thread = None
    gc_loop_thread = None
    running = False
    gc_running = False
    lock = threading.Lock()
    queued_task_list: List[TaskData] = []
    task_map: Dict[str, TaskData] = {}

    def __init__(self, max_queue_size=10, recovery_queue_size=6) -> None:
        self.max_queue_size = max_queue_size
        self.recovery_queue_size = recovery_queue_size
        self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
        self.gc_running = True
        self.gc_loop_thread.start()
        logger.info("Task queue created")

    def add_task(self, callback: Callable, payload: Any = {}) -> str:
        if self.is_busy:
            raise TaskQueueFullException("Task queue is full")
        if len(self.queued_task_list) >= self.max_queue_size:
            # mark busy
            self.is_busy = True
            raise TaskQueueFullException("Task queue is full")
        task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
        task = TaskData(
            task_id=task_id,
            position=0,
            status="queued",
            updated_time=time.time(),
            callback=callback,
            current_step=0,
            total_steps=0,
            payload=payload,
            response=None
        )
        with self.lock:
            self.queued_task_list.append(task)
            task["position"] = len(self.queued_task_list) - 1
        logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
        # create index
        with self.lock:
            self.task_map[task_id] = task
        self._start()
        return task_id

    def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
        if task_id in self.task_map:
            return self.task_map[task_id]
        else:
            return False

    def delete_task_data(self, task_id: str) -> bool:
        if task_id in self.task_map:
            with self.lock:
                del self.task_map[task_id]
            return True
        else:
            return False

    def stop(self):
        self.running = False
        self.gc_running = False
        self.queued_task_list = []
    def _start(self):
        with self.lock:
            if not self.running:
                self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
                self.running = True
                self.loop_thread.start()

    def _loop(self):
        while self.running and len(self.queued_task_list) > 0:
            current_task = self.queued_task_list[0]
            logger.info("Start task:%s." % current_task["task_id"])
            try:
                current_task["status"] = "running"
                current_task["updated_time"] = time.time()
                # run task
                res = current_task["callback"](current_task)
                # call gc
                gc.collect()
                torch_gc()
                current_task["status"] = "finished"
                current_task["updated_time"] = time.time()
                current_task["response"] = res
                current_task["current_step"] = current_task["total_steps"]
            except Exception as e:
                current_task["status"] = "error"
                current_task["updated_time"] = time.time()
                current_task["response"] = e
                gc.collect()
                e_s = str(e)
                if "CUDA out of memory" in e_s or \
                        "an illegal memory access" in e_s or "CUDA" in e_s:
                    traceback.print_exc()
                    logger.error(str(e))
                    logger.error("GPU error in task.")
                    torch_gc()
                    current_task["response"] = Exception("GPU error in task.")
                    # os.kill(mainpid, signal.SIGTERM)
            with self.lock:
                self.queued_task_list.pop(0)
                # reset queue positions
                for i in range(0, len(self.queued_task_list)):
                    self.queued_task_list[i]["position"] = i
                if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size:
                    # mark not busy
                    self.is_busy = False
            logger.info("Task:%s finished, queue size: %d" % (current_task["task_id"], len(self.queued_task_list)))
        with self.lock:
            self.running = False

    # Background loop that removes finished tasks
    def _gc_loop(self):
        while self.gc_running:
            current_time = time.time()
            with self.lock:
                will_remove_keys = []
                for (key, task_data) in self.task_map.items():
                    if task_data["status"] == "finished" or task_data["status"] == "error":
                        if task_data["updated_time"] + 120 < current_time:
                            # drop tasks that finished more than 2 minutes ago
                            will_remove_keys.append(key)
                for key in will_remove_keys:
                    del self.task_map[key]
            time.sleep(1)


queue = TaskQueue(
    max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
    recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6))
)


def verify_token(req: Request):
    if TOKEN:
        valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer " + TOKEN
        if not valid:
            raise HTTPException(
                status_code=401,
                detail="Unauthorized"
            )
    return True


# Initialize fastapi
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
def startup_event():
    logger.info("FastAPI Started, serving")


@app.on_event("shutdown")
def shutdown_event():
    logger.info("FastAPI Shutdown, exiting")
    queue.stop()


class Masker(TypedDict):
    seed: int
    mask: str


class Tags(TypedDict):
    tag: str
    count: int
    confidence: float


class GenerationRequest(BaseModel):
    prompt: str
    image: Optional[str] = None
    n_samples: int = 1
    steps: int = 50
    sampler: str = "plms"
    fixed_code: bool = False
    ddim_eta: float = 0.0
    height: int = 512
    width: int = 512
    latent_channels: int = 4
    downsampling_factor: int = 8
    scale: float = 7.0
    dynamic_threshold: Optional[float] = None
    seed: Optional[int] = None
    temp: float = 1.0
    top_k: int = 256
    grid_size: int = 4
    advanced: bool = False
    stage_two_seed: Optional[int] = None
    strength: float = 0.69
    noise: float = 0.667
    mitigate: bool = False
    module: Optional[str] = None
    masks: Optional[List[Masker]] = None
    uc: Optional[str] = None


class TaskIdOutput(BaseModel):
    task_id: str


class TextRequest(BaseModel):
    prompt: str


class TagOutput(BaseModel):
    tags: List[Tags]


class TextOutput(BaseModel):
    is_safe: str
    corrected_text: str
class TaskIdRequest(BaseModel):
    task_id: str


class TaskDataOutput(BaseModel):
    status: str
    position: int
    current_step: int
    total_steps: int


class GenerationOutput(BaseModel):
    output: List[str]


class ErrorOutput(BaseModel):
    error: str


def saveimage(image, request):
    os.makedirs("images", exist_ok=True)
    filename = request.prompt.replace('masterpiece, best quality, ', '')
    filename = re.sub(r'[/\\<>:"|]', '', filename)
    filename = filename[:128]
    filename += f' s-{request.seed}'
    filename = os.path.join("images", filename.strip())
    # find the first free filename: name.png, name-1.png, name-2.png, ...
    for n in range(1000000):
        suff = '.png'
        if n:
            suff = f'-{n}.png'
        if not os.path.exists(filename + suff):
            break
    try:
        with open(filename + suff, "wb") as f:
            f.write(image)
    except Exception as e:
        print("failed to save image:", e)


def _generate_stream(request: GenerationRequest, task_info: Optional[TaskData] = None):
    if task_info is None:
        # avoid sharing a mutable default argument between calls
        task_info = TaskData()
    try:
        task_info["total_steps"] = request.steps + 1

        def _on_step(step_num, total_steps):
            task_info["total_steps"] = total_steps + 1
            task_info["current_step"] = step_num

        if request.advanced:
            if request.n_samples > 1:
                return ErrorOutput(error="advanced mode does not support n_samples > 1")
            images = model.sample_two_stages(request, callback=_on_step)
        else:
            images = model.sample(request, callback=_on_step)
        logger.info("Sample finished.")

        seed = request.seed
        images_encoded = []
        for x in range(len(images)):
            if seed is not None:
                request.seed = seed
                seed += 1
            comment = json.dumps({
                "steps": request.steps,
                "sampler": request.sampler,
                "seed": request.seed,
                "strength": request.strength,
                "noise": request.noise,
                "scale": request.scale,
                "uc": request.uc
            })
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)
            image = Image.fromarray(images[x])
            # save pillow image with BytesIO
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # get base64 of image
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)
            task_info["current_step"] += 1
        del images
        logger.info("Images encoded.")
        return images_encoded
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e


@app.post('/generate-stream')
async def generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # cap n_samples at the configured maximum
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err
        images_encoded = []
        # poll the queue until the task finishes or errors out
        while True:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data["status"] == "finished":
                images_encoded = task_data["response"]
                break
            elif task_data["status"] == "error":
                return {"error": str(task_data["response"])}
            await asyncio.sleep(0.1)
        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        data = ""
        ptr = 0
        for x in images_encoded:
            ptr += 1
            data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
        return Response(content=data, media_type="text/event-stream")
        #return GenerationOutput(output=images)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/start-generate-stream')
async def start_generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # cap n_samples at the configured maximum
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err
        return TaskIdOutput(task_id=task_id)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/get-generate-stream-output')
async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id
        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data["status"] == "finished":
            images_encoded = task_data["response"]
            data = ""
            ptr = 0
            for x in images_encoded:
                ptr += 1
                data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
            return Response(content=data, media_type="text/event-stream")
        elif task_data["status"] == "error":
            raise task_data["response"]
        else:
            return ErrorOutput(error="Task is not finished.")
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


def _generate(request: GenerationRequest, task_info: TaskData):
    try:
        task_info["total_steps"] = request.steps + 1

        def _on_step(step_num, total_steps):
            task_info["total_steps"] = total_steps + 1
            task_info["current_step"] = step_num

        images = model.sample(request, callback=_on_step)
        images_encoded = []
        for x in range(len(images)):
            comment = json.dumps({
                "steps": request.steps,
                "sampler": request.sampler,
                "seed": request.seed,
                "strength": request.strength,
                "noise": request.noise,
                "scale": request.scale,
                "uc": request.uc
            })
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)
            image = Image.fromarray(images[x])
            # save pillow image with BytesIO
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # get base64 of image
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)
            task_info["current_step"] += 1
        del images
        return images_encoded
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e


@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
async def generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # cap n_samples at the configured maximum
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err
        images_encoded = []
        # poll the queue until the task finishes or errors out
        while True:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data["status"] == "finished":
                images_encoded = task_data["response"]
                break
            elif task_data["status"] == "error":
                return {"error": str(task_data["response"])}
            await asyncio.sleep(0.1)
        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return GenerationOutput(output=images_encoded)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/start-generate', response_model=Union[TaskIdOutput, ErrorOutput])
async def start_generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # cap n_samples at the configured maximum
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err
        return TaskIdOutput(task_id=task_id)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/get-generate-output')
async def generate_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id
        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data["status"] == "finished":
            images_encoded = task_data["response"]
            return GenerationOutput(output=images_encoded)
        elif task_data["status"] == "error":
            raise task_data["response"]
        else:
            return ErrorOutput(error="Task is not finished.")
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/generate-text', response_model=Union[TextOutput, ErrorOutput])
def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        is_safe, corrected_text = model.sample(request)
        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return TextOutput(is_safe=is_safe, corrected_text=corrected_text)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))


@app.get('/predict-tags', response_model=Union[TagOutput, ErrorOutput])
async def predict_tags(prompt="", authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        #output = sanitize_input(config, request)
        #if output[0]:
        #    request = output[1]
        #else:
        #    return ErrorOutput(error=output[1])
        if not embedmodel:
            # the embed model failed to load at startup, so suggestions are unavailable
            return ErrorOutput(error="Tag suggestions are unavailable: embed model is not loaded.")
        tags = embedmodel.get_top_k(prompt)
        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))


@app.post('/task-info', response_model=Union[TaskDataOutput, ErrorOutput])
async def get_task_info(request: TaskIdRequest):
    task_data = queue.get_task_data(request.task_id)
    if task_data:
        return TaskDataOutput(
            status=task_data["status"],
            position=task_data["position"],
            current_step=task_data["current_step"],
            total_steps=task_data["total_steps"]
        )
    else:
        return ErrorOutput(error="Cannot find current task.")


@app.get('/')
def index():
    return FileResponse('static/index.html')


app.mount("/", StaticFiles(directory="static/"), name="static")


def start():
    uvicorn.run("main:app", host="0.0.0.0", port=4315, log_level="info")


if __name__ == "__main__":
    start()
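
# --- Illustrative client sketch (an assumption about intended usage, not part of the server) ---
# The asynchronous endpoints are meant to be used in three steps: POST /start-generate to queue
# a task, POST /task-info to poll its progress, and POST /get-generate-output to fetch the
# base64-encoded PNGs once the status is "finished". The helper below is a minimal, hypothetical
# example of that flow; it is never called by this module and uses the third-party `requests`
# package, which is not a dependency of the server itself.
def _example_client_flow(base_url: str = "http://localhost:4315", token: Optional[str] = None):
    import requests  # imported lazily so the server does not require it

    headers = {"Authorization": f"Bearer {token}"} if token else {}
    # 1. queue a generation task
    started = requests.post(f"{base_url}/start-generate",
                            json={"prompt": "1girl", "steps": 28}, headers=headers).json()
    task_id = started["task_id"]
    # 2. poll until the task finishes or errors out
    while True:
        info = requests.post(f"{base_url}/task-info",
                             json={"task_id": task_id}, headers=headers).json()
        if info.get("status") in ("finished", "error") or "error" in info:
            break
        time.sleep(1)
    # 3. fetch the finished images (a list of base64-encoded PNGs)
    return requests.post(f"{base_url}/get-generate-output",
                         json={"task_id": task_id}, headers=headers).json()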