# Hydra Node HTTP/WebSocket server: exposes Stable Diffusion image generation
# behind a FastAPI app backed by a task queue.
import asyncio
import base64
import gc
import io
import json
import os
import re
import signal
import socket
import threading
import time
import traceback
from typing import Dict, List, Union

import uvicorn
from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from starlette.responses import FileResponse
from starlette.websockets import WebSocketState
from typing_extensions import TypedDict
from PIL import Image
from PIL.PngImagePlugin import PngInfo

from controller import TaskDataSocketController
from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc
from hydra_node.sanitize import sanitize_input
from task import TaskData, TaskQueue, TaskQueueFullException

genlock = threading.Lock()
running = True

MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
TOKEN = os.getenv("TOKEN", None)
print(f"Starting Hydra Node HTTP TOKEN={TOKEN}")

# Initialize model and config
model, config, model_hash = init_config_model()
try:
    embedmodel = EmbedderModel()
except Exception as e:
    print("couldn't load embed model, suggestions won't work:", e)
    embedmodel = False
logger = config.logger
try:
    config.mainpid = int(open("gunicorn.pid", "r").read())
except FileNotFoundError:
    config.mainpid = os.getpid()
mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False

queue = TaskQueue(
    logger=logger,
    max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
    recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6)),
)
task_ws_controller = TaskDataSocketController()


def on_task_update(task_info: Dict):
    # Forward queue progress updates to any websocket clients watching this task.
    loop = asyncio.get_event_loop()
    loop.run_until_complete(task_ws_controller.emit(task_info["task_id"], task_info))


queue.on_update = on_task_update


def verify_token(req: Request):
    if TOKEN:
        valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer " + TOKEN
        if not valid:
            raise HTTPException(
                status_code=401,
                detail="Unauthorized"
            )
    return True


# Initialize fastapi
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
def startup_event():
    logger.info("FastAPI Started, serving")


@app.on_event("shutdown")
def shutdown_event():
    logger.info("FastAPI Shutdown, exiting")
    queue.stop()
    time.sleep(2)
    os.kill(mainpid, signal.SIGTERM)


class Masker(TypedDict):
    seed: int
    mask: str


class Tags(TypedDict):
    tag: str
    count: int
    confidence: float


class GenerationRequest(BaseModel):
    prompt: str
    image: str = None
    n_samples: int = 1
    steps: int = 50
    sampler: str = "plms"
    fixed_code: bool = False
    ddim_eta: float = 0.0
    height: int = 512
    width: int = 512
    latent_channels: int = 4
    downsampling_factor: int = 8
    scale: float = 7.0
    dynamic_threshold: float = None
    seed: int = None
    temp: float = 1.0
    top_k: int = 256
    grid_size: int = 4
    advanced: bool = False
    stage_two_seed: int = None
    strength: float = 0.69
    noise: float = 0.667
    mitigate: bool = False
    module: str = None
    masks: List[Masker] = None
    uc: str = None


class TaskIdOutput(BaseModel):
    task_id: str


class TextRequest(BaseModel):
    prompt: str


class TagOutput(BaseModel):
    tags: List[Tags]

class TextOutput(BaseModel):
    is_safe: str
    corrected_text: str


class TaskIdRequest(BaseModel):
    task_id: str


class TaskDataOutput(BaseModel):
    status: str
    position: int
    current_step: int
    total_steps: int


class GenerationOutput(BaseModel):
    output: List[str]


class ErrorOutput(BaseModel):
    error: str
    # Optional machine-readable error code (e.g. "ERR::TASK_NOT_FOUND").
    code: str = None


def saveimage(image, request):
    os.makedirs("images", exist_ok=True)
    filename = request.prompt.replace('masterpiece, best quality, ', '')
    filename = re.sub(r'[/\\<>:"|]', '', filename)
    filename = filename[:128]
    filename += f' s-{request.seed}'
    filename = os.path.join("images", filename.strip())
    # Find an unused "-n" suffix so repeated prompts don't overwrite earlier files.
    for n in range(1000000):
        suff = '.png'
        if n:
            suff = f'-{n}.png'
        if not os.path.exists(filename + suff):
            break
    try:
        with open(filename + suff, "wb") as f:
            f.write(image)
    except Exception as e:
        print("failed to save image:", e)


def _generate_stream(request: GenerationRequest, task_info: TaskData):
    try:
        task_info.update({
            "total_steps": request.steps + 1
        })

        def _on_step(step_num, total_steps):
            task_info.update({
                "total_steps": total_steps + 1,
                "current_step": step_num
            })

        if request.advanced:
            if request.n_samples > 1:
                return ErrorOutput(error="advanced mode does not support n_samples > 1")
            images = model.sample_two_stages(request, callback=_on_step)
        else:
            images = model.sample(request, callback=_on_step)
        logger.info("Sample finished.")

        seed = request.seed
        images_encoded = []
        for x in range(len(images)):
            if seed is not None:
                request.seed = seed
                seed += 1
            comment = json.dumps({"steps": request.steps, "sampler": request.sampler, "seed": request.seed,
                                  "strength": request.strength, "noise": request.noise, "scale": request.scale,
                                  "uc": request.uc})
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)
            image = Image.fromarray(images[x])
            # Save the Pillow image into an in-memory PNG buffer
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # Base64-encode the PNG bytes for the response
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)
            task_info.update({
                "current_step": task_info.current_step + 1
            })
        del images
        logger.info("Images encoded.")

        return images_encoded
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e


@app.post('/generate-stream')
async def generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the generation cap
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err

        # Poll the queue until the task finishes or errors out.
        images_encoded = []
        while running:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data.status == "finished":
                images_encoded = task_data.response
                break
            elif task_data.status == "error":
                return {"error": str(task_data.response)}
            await asyncio.sleep(0.1)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
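        # Assemble a text/event-stream body by hand: one "newImage" event per base64-encoded PNG.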
        data = ""
        ptr = 0
        for x in images_encoded:
            ptr += 1
            data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
        return Response(content=data, media_type="text/event-stream")
        #return GenerationOutput(output=images)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/start-generate-stream')
async def start_generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the generation cap
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err

        return TaskIdOutput(task_id=task_id)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/get-generate-stream-output')
async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id
        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data.status == "finished":
            images_encoded = task_data.response
            data = ""
            ptr = 0
            for x in images_encoded:
                ptr += 1
                data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
            return Response(content=data, media_type="text/event-stream")
        elif task_data.status == "error":
            raise task_data.response
        else:
            return ErrorOutput(error="Task is not finished.")
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


def _generate(request: GenerationRequest, task_info: TaskData):
    try:
        task_info.update({
            "total_steps": request.steps + 1
        })

        def _on_step(step_num, total_steps):
            task_info.update({
                "total_steps": total_steps + 1,
                "current_step": step_num
            })

        images = model.sample(request, callback=_on_step)

        images_encoded = []
        for x in range(len(images)):
            comment = json.dumps({"steps": request.steps, "sampler": request.sampler, "seed": request.seed,
                                  "strength": request.strength, "noise": request.noise, "scale": request.scale,
                                  "uc": request.uc})
            metadata = PngInfo()
            metadata.add_text("Title", "AI generated image")
            metadata.add_text("Description", request.prompt)
            metadata.add_text("Software", "NovelAI")
            metadata.add_text("Source", "Stable Diffusion " + model_hash)
            metadata.add_text("Comment", comment)
            image = Image.fromarray(images[x])
            # Save the Pillow image into an in-memory PNG buffer
            output = io.BytesIO()
            image.save(output, format='PNG', pnginfo=metadata)
            image = output.getvalue()
            if config.savefiles:
                saveimage(image, request)
            # Base64-encode the PNG bytes for the response
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)
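            # Count this encoded image as one more completed step for progress reporting.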
            task_info.update({
                "current_step": task_info.current_step + 1
            })
        del images

        return images_encoded
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        else:
            raise e


@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
async def generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the generation cap
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err

        # Poll the queue until the task finishes or errors out.
        images_encoded = []
        while running:
            task_data = queue.get_task_data(task_id)
            if not task_data:
                raise Exception("Task not found")
            if task_data.status == "finished":
                images_encoded = task_data.response
                break
            elif task_data.status == "error":
                return {"error": str(task_data.response)}
            await asyncio.sleep(0.1)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return GenerationOutput(output=images_encoded)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/start-generate', response_model=Union[TaskIdOutput, ErrorOutput])
async def start_generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
    try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the generation cap
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please wait a few minutes.")
        except Exception as err:
            raise err

        return TaskIdOutput(task_id=task_id)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}


@app.post('/get-generate-output')
async def get_generate_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
    try:
        task_id = request.task_id
        task_data = queue.get_task_data(task_id)
        if not task_data:
            return ErrorOutput(error="Task not found.")
        if task_data.status == "finished":
            images_encoded = task_data.response
            return GenerationOutput(output=images_encoded)
        elif task_data.status == "error":
            raise task_data.response
        else:
            return ErrorOutput(error="Task is not finished.")
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
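

# A minimal usage sketch (assuming the server runs on the default port 4315 set in start()):
#   1. POST /start-generate with a GenerationRequest JSON body -> {"task_id": "..."}
#   2. POST /task-info with {"task_id": "..."} to poll queue position and progress
#   3. POST /get-generate-output with {"task_id": "..."} to fetch the base64-encoded images once finished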

@app.post('/generate-text', response_model=Union[TextOutput, ErrorOutput])
def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        is_safe, corrected_text = model.sample(request)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")

        return TextOutput(is_safe=is_safe, corrected_text=corrected_text)
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))


@app.get('/predict-tags', response_model=Union[TagOutput, ErrorOutput])
async def predict_tags(prompt="", authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        #output = sanitize_input(config, request)
        #if output[0]:
        #    request = output[1]
        #else:
        #    return ErrorOutput(error=output[1])
        if not embedmodel:
            return ErrorOutput(error="Embedding model is not loaded, tag suggestions are unavailable.")
        tags = embedmodel.get_top_k(prompt)

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")

        return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])
    except Exception as e:
        traceback.print_exc()
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            torch_gc()
            # logger.error("GPU error, committing seppuku.")
            # os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))


@app.post('/task-info', response_model=Union[TaskDataOutput, ErrorOutput])
async def get_task_info(request: TaskIdRequest):
    task_data = queue.get_task_data(request.task_id)
    if task_data:
        return TaskDataOutput(
            status=task_data.status,
            position=task_data.position,
            current_step=task_data.current_step,
            total_steps=task_data.total_steps
        )
    else:
        return ErrorOutput(error="Cannot find current task.", code="ERR::TASK_NOT_FOUND")


@app.websocket('/task-info')
async def websocket_endpoint(websocket: WebSocket, task_id: str):
    await websocket.accept()
    try:
        task_data = queue.get_task_data(task_id)
        if not task_data:
            await websocket.send_json({
                "error": "ERR::TASK_NOT_FOUND",
                "message": "Cannot find current task."
            })
            await websocket.close()
            return

        # Send current task info
        await websocket.send_json(task_data.to_dict())
        if task_data.status == "finished" or task_data.status == "error":
            await websocket.close()
            return

        # Register this socket for push updates from the task queue.
        task_ws_controller.add(task_id, websocket)
        while running and websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_json({
                    "pong": "pong"
                })
        if websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
            await websocket.close()
    except WebSocketDisconnect:
        task_ws_controller.remove(task_id, websocket)


@app.get('/')
def index():
    return FileResponse('static/index.html')


app.mount("/", StaticFiles(directory="static/"), name="static")


def start():
    uvicorn.run("main:app", host="0.0.0.0", port=4315, log_level="info")


if __name__ == "__main__":
    start()