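"""Hydra node HTTP server.

A FastAPI app that queues Stable Diffusion generation tasks, reports progress
over HTTP and WebSocket, and returns PNG images as base64 (JSON or SSE).
"""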
import asyncio
import os
import re
import signal
from fastapi import FastAPI, Request, Depends, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from fastapi.responses import Response
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse
from starlette.websockets import WebSocketState
from controller import TaskDataSocketController
from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc
from typing import Dict, List, Optional, Union
from typing_extensions import TypedDict
import socket
from hydra_node.sanitize import sanitize_input
from task import TaskData, TaskQueue, TaskQueueFullException
import uvicorn
import time
import gc
import io
import base64
import traceback
import threading
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json
genlock = threading.Lock()
running = True
MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
TOKEN = os.getenv("TOKEN", None)
print(f"Starting Hydra Node HTTP TOKEN={TOKEN}")
# Initialize model and config
model, config, model_hash = init_config_model()
try:
    embedmodel = EmbedderModel()
except Exception as e:
    print("couldn't load embed model, suggestions won't work:", e)
    embedmodel = None
logger = config.logger
try:
    with open("gunicorn.pid", "r") as f:
        config.mainpid = int(f.read())
except FileNotFoundError:
    config.mainpid = os.getpid()
mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False
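# Task queue; limits are tunable via the QUEUE_MAX_SIZE / QUEUE_RECOVERY_SIZE env vars.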
queue = TaskQueue(
logger=logger,
max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6))
)
task_ws_controller = TaskDataSocketController()
def on_task_update(task_info: Dict):
    # This may run on the queue's worker thread, which has no event loop yet.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    loop.run_until_complete(task_ws_controller.emit(task_info["task_id"], task_info))
queue.on_update = on_task_update
def verify_token(req: Request):
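    """FastAPI dependency that enforces `Authorization: Bearer <TOKEN>` when TOKEN is set."""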
if TOKEN:
        valid = req.headers.get("Authorization") == "Bearer " + TOKEN
if not valid:
raise HTTPException(
status_code=401,
detail="Unauthorized"
)
return True
# Initialize FastAPI
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
@app.on_event("startup")
def startup_event():
logger.info("FastAPI Started, serving")
@app.on_event("shutdown")
def shutdown_event():
logger.info("FastAPI Shutdown, exiting")
queue.stop()
time.sleep(2)
os.kill(mainpid, signal.SIGTERM)
class Masker(TypedDict):
seed: int
mask: str
class Tags(TypedDict):
tag: str
count: int
confidence: float
class GenerationRequest(BaseModel):
    prompt: str
    image: Optional[str] = None
    n_samples: int = 1
    steps: int = 50
    sampler: str = "plms"
    fixed_code: bool = False
    ddim_eta: float = 0.0
    height: int = 512
    width: int = 512
    latent_channels: int = 4
    downsampling_factor: int = 8
    scale: float = 7.0
    dynamic_threshold: Optional[float] = None
    seed: Optional[int] = None
    temp: float = 1.0
    top_k: int = 256
    grid_size: int = 4
    advanced: bool = False
    stage_two_seed: Optional[int] = None
    strength: float = 0.69
    noise: float = 0.667
    mitigate: bool = False
    module: Optional[str] = None
    masks: Optional[List[Masker]] = None
    uc: Optional[str] = None
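# Example /generate request body (illustrative values only; defaults shown above):
#   {"prompt": "a watercolor landscape", "steps": 28, "sampler": "plms",
#    "width": 512, "height": 512, "scale": 7.0, "seed": 42, "n_samples": 1}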
class TaskIdOutput(BaseModel):
task_id: str
class TextRequest(BaseModel):
prompt: str
class TagOutput(BaseModel):
tags: List[Tags]
class TextOutput(BaseModel):
is_safe: str
corrected_text: str
class TaskIdRequest(BaseModel):
task_id: str
class TaskDataOutput(BaseModel):
status: str
position: int
current_step: int
total_steps: int
class GenerationOutput(BaseModel):
output: List[str]
class ErrorOutput(BaseModel):
    error: str
    code: Optional[str] = None
def saveimage(image, request):
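    """Save PNG bytes under ./images, deriving a collision-free filename from the prompt and seed."""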
os.makedirs("images", exist_ok=True)
filename = request.prompt.replace('masterpiece, best quality, ', '')
filename = re.sub(r'[/\\<>:"|]', '', filename)
filename = filename[:128]
filename += f' s-{request.seed}'
filename = os.path.join("images", filename.strip())
for n in range(1000000):
suff = '.png'
if n:
suff = f'-{n}.png'
if not os.path.exists(filename + suff):
break
try:
with open(filename + suff, "wb") as f:
f.write(image)
except Exception as e:
print("failed to save image:", e)
def _generate_stream(request: GenerationRequest, task_info: TaskData):
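    """Queue worker for the streaming endpoints: runs sampling, embeds PNG metadata, and returns base64-encoded images."""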
try:
task_info.update({
"total_steps": request.steps + 1
})
def _on_step(step_num, total_steps):
task_info.update({
"total_steps": total_steps + 1,
"current_step": step_num
})
        if request.advanced:
            if request.n_samples > 1:
                # Raise so the task is marked "error" and the message reaches the
                # caller, instead of an ErrorOutput being mistaken for image data.
                raise ValueError("advanced mode does not support n_samples > 1")
            images = model.sample_two_stages(request, callback=_on_step)
else:
images = model.sample(request, callback=_on_step)
logger.info("Sample finished.")
seed = request.seed
images_encoded = []
for x in range(len(images)):
if seed is not None:
request.seed = seed
seed += 1
            comment = json.dumps({
                "steps": request.steps,
                "sampler": request.sampler,
                "seed": request.seed,
                "strength": request.strength,
                "noise": request.noise,
                "scale": request.scale,
                "uc": request.uc,
            })
metadata = PngInfo()
metadata.add_text("Title", "AI generated image")
metadata.add_text("Description", request.prompt)
metadata.add_text("Software", "NovelAI")
metadata.add_text("Source", "Stable Diffusion "+model_hash)
metadata.add_text("Comment", comment)
image = Image.fromarray(images[x])
            # serialize the PIL image to PNG in memory
output = io.BytesIO()
image.save(output, format='PNG', pnginfo=metadata)
image = output.getvalue()
if config.savefiles:
saveimage(image, request)
            # base64-encode the PNG bytes
image = base64.b64encode(image).decode("ascii")
images_encoded.append(image)
task_info.update({
"current_step": task_info.current_step + 1
})
del images
logger.info("Images encoded.")
return images_encoded
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
else:
raise e
@app.post('/generate-stream')
async def generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
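    """Queue a generation task, wait for it to finish, and return the images as server-sent events."""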
t = time.perf_counter()
try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the sample-count cap
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please try again in a few minutes.")
images_encoded = []
while running:
task_data = queue.get_task_data(task_id)
if not task_data:
raise Exception("Task not found")
if task_data.status == "finished":
images_encoded = task_data.response
break
elif task_data.status == "error":
return {"error": str(task_data.response)}
await asyncio.sleep(0.1)
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
        data = ""
        for ptr, x in enumerate(images_encoded, start=1):
            data += f"event: newImage\nid: {ptr}\ndata:{x}\n\n"
return Response(content=data, media_type="text/event-stream")
#return GenerationOutput(output=images)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/start-generate-stream')
async def start_generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
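    """Queue a generation task and immediately return its task_id, without waiting for the result."""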
try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the sample-count cap
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please try again in a few minutes.")
return TaskIdOutput(task_id=task_id)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/get-generate-stream-output')
async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
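    """Return the finished output of a queued task as server-sent events."""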
try:
task_id = request.task_id
task_data = queue.get_task_data(task_id)
if not task_data:
return ErrorOutput(error="Task not found.")
if task_data.status == "finished":
images_encoded = task_data.response
            data = ""
            for ptr, x in enumerate(images_encoded, start=1):
                data += f"event: newImage\nid: {ptr}\ndata:{x}\n\n"
return Response(content=data, media_type="text/event-stream")
elif task_data.status == "error":
raise task_data.response
else:
return ErrorOutput(error="Task is not finished.")
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
def _generate(request: GenerationRequest, task_info: TaskData):
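    """Queue worker for /generate and /start-generate: runs sampling and returns base64-encoded PNGs."""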
try:
task_info.update({
"total_steps": request.steps + 1
})
def _on_step(step_num, total_steps):
task_info.update({
"total_steps": total_steps + 1,
"current_step": step_num
})
images = model.sample(request, callback=_on_step)
images_encoded = []
        for x in range(len(images)):
            comment = json.dumps({
                "steps": request.steps,
                "sampler": request.sampler,
                "seed": request.seed,
                "strength": request.strength,
                "noise": request.noise,
                "scale": request.scale,
                "uc": request.uc,
            })
metadata = PngInfo()
metadata.add_text("Title", "AI generated image")
metadata.add_text("Description", request.prompt)
metadata.add_text("Software", "NovelAI")
metadata.add_text("Source", "Stable Diffusion "+model_hash)
metadata.add_text("Comment", comment)
image = Image.fromarray(images[x])
            # serialize the PIL image to PNG in memory
output = io.BytesIO()
image.save(output, format='PNG', pnginfo=metadata)
image = output.getvalue()
if config.savefiles:
saveimage(image, request)
            # base64-encode the PNG bytes
image = base64.b64encode(image).decode("ascii")
images_encoded.append(image)
task_info.update({
"current_step": task_info.current_step + 1
})
del images
return images_encoded
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
else:
raise e
@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
async def generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
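    """Queue a generation task, wait for it to finish, and return the images as JSON."""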
t = time.perf_counter()
try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the sample-count cap
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please try again in a few minutes.")
images_encoded = []
while running:
task_data = queue.get_task_data(task_id)
if not task_data:
raise Exception("Task not found")
if task_data.status == "finished":
images_encoded = task_data.response
break
elif task_data.status == "error":
return {"error": str(task_data.response)}
await asyncio.sleep(0.1)
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
return GenerationOutput(output=images_encoded)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/start-generate', response_model=Union[TaskIdOutput, ErrorOutput])
async def start_generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
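    """Queue a generation task and immediately return its task_id."""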
try:
        request.n_samples = min(request.n_samples, MAX_N_SAMPLES)  # enforce the sample-count cap
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
        task_id = None
        try:
            task_id = queue.add_task(lambda t: _generate(request, t), request)
        except TaskQueueFullException:
            return ErrorOutput(error="Task queue is full, please try again in a few minutes.")
return TaskIdOutput(task_id=task_id)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/get-generate-output')
async def generate_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
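    """Return the finished output of a queued task as JSON."""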
try:
task_id = request.task_id
task_data = queue.get_task_data(task_id)
if not task_data:
return ErrorOutput(error="Task not found.")
if task_data.status == "finished":
images_encoded = task_data.response
return GenerationOutput(output=images_encoded)
elif task_data.status == "error":
raise task_data.response
else:
return ErrorOutput(error="Task is not finished.")
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/generate-text', response_model=Union[TextOutput, ErrorOutput])
def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)):
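    """Run the text model over the prompt and return a safety flag plus the corrected text."""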
t = time.perf_counter()
try:
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
is_safe, corrected_text = model.sample(request)
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
return TextOutput(is_safe=is_safe, corrected_text=corrected_text)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return ErrorOutput(error=str(e))
@app.get('/predict-tags', response_model=Union[TagOutput, ErrorOutput])
async def predict_tags(prompt="", authorized: bool = Depends(verify_token)):
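    """Return top-k tag suggestions for a partial prompt from the embedding model."""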
t = time.perf_counter()
try:
#output = sanitize_input(config, request)
#if output[0]:
# request = output[1]
#else:
# return ErrorOutput(error=output[1])
        if not embedmodel:
            return ErrorOutput(error="Embed model is not loaded, tag suggestions are unavailable.")
        tags = embedmodel.get_top_k(prompt)
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return ErrorOutput(error=str(e))
@app.post('/task-info', response_model=Union[TaskDataOutput, ErrorOutput])
async def get_task_info(request: TaskIdRequest):
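    """Poll the status, queue position, and step progress of a queued task."""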
task_data = queue.get_task_data(request.task_id)
if task_data:
return TaskDataOutput(
status=task_data.status,
position=task_data.position,
current_step=task_data.current_step,
total_steps=task_data.total_steps
)
else:
return ErrorOutput(error="Cannot find current task.", code="ERR::TASK_NOT_FOUND")
@app.websocket('/task-info')
async def websocket_endpoint(websocket: WebSocket, task_id: str):
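    """Stream task progress updates over a WebSocket; answers "ping" messages with a pong."""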
await websocket.accept()
try:
task_data = queue.get_task_data(task_id)
if not task_data:
await websocket.send_json({
"error": "ERR::TASK_NOT_FOUND",
"message": "Cannot find current task."
})
await websocket.close()
return
# Send current task info
await websocket.send_json(task_data.to_dict())
if task_data.status == "finished" or task_data.status == "error":
await websocket.close()
return
task_ws_controller.add(task_id, websocket)
while running and websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
data = await websocket.receive_text()
if data == "ping":
await websocket.send_json({
"pong": "pong"
})
        if websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
            await websocket.close()
        task_ws_controller.remove(task_id, websocket)
    except WebSocketDisconnect:
        task_ws_controller.remove(task_id, websocket)
@app.get('/')
def index():
return FileResponse('static/index.html')
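# Mounted after the API routes so they take precedence over the static files.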
app.mount("/", StaticFiles(directory="static/"), name="static")
def start():
uvicorn.run("main:app", host="0.0.0.0", port=4315, log_level="info")
if __name__ == "__main__":
start()