Upload project

master
落雨楓 2 years ago
commit e23b2b2f95

@@ -0,0 +1,29 @@
# Naifu (Isekai Wiki Edition)

To make playing together with friends from our group chat easier, a queue has been added so that multiple users get a smoother experience.

For slower cards such as the GTX 1050 Ti, the web page now shows a generation progress bar, which is convenient when generating from a phone browser.

* Added generation progress reporting
* Added a limit on parallel generation
* Added a simple queue

Frontend source: [https://git.isekai.cn/hyperzlib/naifu-frontend](https://git.isekai.cn/hyperzlib/naifu-frontend)
## Installation

Copy the files over the Naifu program directory.

On low-VRAM devices, additionally copy the files under ```hydra_node-lowvram``` into ```hydra_node```.

Frontend download: [https://git.isekai.cn/hyperzlib/naifu-frontend/releases](https://git.isekai.cn/hyperzlib/naifu-frontend/releases)

Download ```static.zip``` and copy the extracted files over the Naifu program directory (a shell sketch of these steps follows below).
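For reference, a minimal shell sketch of the steps above, assuming a Linux shell and using ```NAIFU_DIR``` as a placeholder for your actual Naifu install path (on Windows, make the equivalent copies by hand):

```sh
# Copy the backend files from this repository over the Naifu program directory.
cp -r ./* "$NAIFU_DIR/"

# Low-VRAM devices only: overwrite hydra_node with the low-VRAM variant.
cp -r hydra_node-lowvram/* "$NAIFU_DIR/hydra_node/"

# Frontend: extract static.zip over the program directory.
unzip -o static.zip -d "$NAIFU_DIR"
```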
## Environment variables

The following settings can be changed in ```run.bat``` or ```run.sh```.

| Name | Summary | Details |
| --- | --- | --- |
| QUEUE_MAX_SIZE | Maximum queue size (number) | Once the queue reaches this size it enters the "queue full" state, until the number of queued tasks drops below ```QUEUE_RECOVERY_SIZE``` |
| QUEUE_RECOVERY_SIZE | Queue recovery size (number) | While in the "queue full" state, clicking "Generate" shows a "queue is full" message, until the number of queued tasks drops below this value |
| MAX_N_SAMPLES | Maximum parallel generation count (number) | Parallel generation blocks the queue for a long time, so limit how many images a user may generate per request (not needed for local use) |
| LOWVRAM | Enable low-VRAM mode (1: on, 0: off) | Required on low-VRAM devices; enabling it on a high-VRAM device slows generation down considerably |
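As a concrete example, the relevant lines in ```run.sh``` might look like the sketch below; the values are illustrative only, not recommendations (in ```run.bat```, use ```set NAME=value``` instead of ```export```):

```sh
# Queue: refuse new jobs once 20 tasks are queued, accept again below 10.
export QUEUE_MAX_SIZE=20
export QUEUE_RECOVERY_SIZE=10

# Cap each request at 4 images so one user cannot block the queue for long.
export MAX_N_SAMPLES=4

# Only on low-VRAM cards; high-VRAM GPUs generate much slower with this on.
export LOWVRAM=1
```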

@@ -0,0 +1,192 @@
import os
import torch
import logging
import os
import platform
import socket
import sys
import time
from dotmap import DotMap
from hydra_node.models import StableDiffusionModel, DalleMiniModel, BasedformerModel, EmbedderModel
from hydra_node import lowvram
import traceback
import zlib
from pathlib import Path
from ldm.modules.attention import CrossAttention, HyperLogic
from . import lautocast
model_map = {
"stable-diffusion": StableDiffusionModel,
"dalle-mini": DalleMiniModel,
"basedformer": BasedformerModel,
"embedder": EmbedderModel,
}
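# Temporarily replace reset_parameters on Linear/Embedding/LayerNorm with a no-op while loading_code runs, so checkpoint loading skips redundant weight initialization; the originals are restored afterwards.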
def no_init(loading_code):
def dummy(self):
return
modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
original = {}
for mod in modules:
original[mod] = mod.reset_parameters
mod.reset_parameters = dummy
result = loading_code()
for mod in modules:
mod.reset_parameters = original[mod]
return result
def crc32(filename, chunksize=65536):
"""Compute the CRC-32 checksum of the contents of the given filename"""
with open(filename, "rb") as f:
checksum = 0
while (chunk := f.read(chunksize)) :
checksum = zlib.crc32(chunk, checksum)
return '%08X' % (checksum & 0xFFFFFFFF)
def load_modules(path):
path = Path(path)
modules = {}
if not path.is_dir():
return
for file in path.iterdir():
module = load_module(file, "cpu")
modules[file.stem] = module
print(f"Loaded module {file.stem}")
return modules
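# A module checkpoint maps attention widths (320/640/768/1280) to a pair of HyperLogic hypernetwork layers; load each pair onto the requested device.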
def load_module(path, device):
path = Path(path)
if not path.is_file():
print("Module path {} is not a file".format(path))
network = {
768: (HyperLogic(768).to(device), HyperLogic(768).to(device)),
1280: (HyperLogic(1280).to(device), HyperLogic(1280).to(device)),
640: (HyperLogic(640).to(device), HyperLogic(640).to(device)),
320: (HyperLogic(320).to(device), HyperLogic(320).to(device)),
}
state_dict = torch.load(path)
for key in state_dict.keys():
network[key][0].load_state_dict(state_dict[key][0])
network[key][1].load_state_dict(state_dict[key][1])
return network
def init_config_model():
config = DotMap()
config.savefiles = os.getenv("SAVE_FILES", False)
config.dtype = os.getenv("DTYPE", "float16")
config.device = os.getenv("DEVICE", "cuda")
config.amp = os.getenv("AMP", False)
if config.amp == "1":
config.amp = True
elif config.amp == "0":
config.amp = False
is_dev = ""
environment = "production"
if os.environ['DEV'] == "True":
is_dev = "_dev"
environment = "staging"
config.is_dev = is_dev
# Setup logger
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
fh = logging.StreamHandler()
fh_formatter = logging.Formatter(
"%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
)
fh.setFormatter(fh_formatter)
logger.addHandler(fh)
config.logger = logger
# Gather node information
config.cuda_dev = torch.cuda.current_device()
cpu_id = platform.processor()
if os.path.exists('/proc/cpuinfo'):
cpu_id = [line for line in open("/proc/cpuinfo", 'r').readlines() if
'model name' in line][0].rstrip().split(': ')[-1]
config.cpu_id = cpu_id
config.gpu_id = torch.cuda.get_device_name(config.cuda_dev)
config.node_id = platform.node()
# Report on our CUDA memory and model.
gb_gpu = int(torch.cuda.get_device_properties(
config.cuda_dev).total_memory / (1000 * 1000 * 1000))
logger.info(f"CPU: {config.cpu_id}")
logger.info(f"GPU: {config.gpu_id}")
logger.info(f"GPU RAM: {gb_gpu}gb")
config.model_name = os.environ['MODEL']
logger.info(f"MODEL: {config.model_name}")
# Resolve where we get our model and data from.
config.model_path = os.getenv('MODEL_PATH', None)
config.enable_ema = os.getenv('ENABLE_EMA', "1")
config.basedformer = os.getenv('BASEDFORMER', "0")
config.penultimate = os.getenv('PENULTIMATE', "0")
config.vae_path = os.getenv('VAE_PATH', None)
config.module_path = os.getenv('MODULE_PATH', None)
config.prior_path = os.getenv('PRIOR_PATH', None)
config.default_config = os.getenv('DEFAULT_CONFIG', None)
config.quality_hack = os.getenv('QUALITY_HACK', "0")
config.clip_contexts = os.getenv('CLIP_CONTEXTS', "1")
try:
config.clip_contexts = int(config.clip_contexts)
if config.clip_contexts < 1 or config.clip_contexts > 10:
config.clip_contexts = 1
except:
config.clip_contexts = 1
# Misc settings
config.model_alias = os.getenv('MODEL_ALIAS')
# Instantiate our actual model.
load_time = time.time()
model_hash = None
try:
if config.model_name != "dalle-mini":
model = no_init(lambda: model_map[config.model_name](config))
else:
model = model_map[config.model_name](config)
except Exception as e:
traceback.print_exc()
logger.error(f"Failed to load model: {str(e)}")
#exit gunicorn
sys.exit(4)
if config.model_name == "stable-diffusion":
folder = Path(config.model_path)
if (folder / "pruned.ckpt").is_file():
model_path = folder / "pruned.ckpt"
else:
model_path = folder / "model.ckpt"
model_hash = crc32(model_path)
#Load Modules
if config.module_path is not None:
modules = load_modules(config.module_path)
#attach it to the model
model.premodules = modules
if lautocast.lowvram == False and lautocast.medvram == False:
lowvram.setup_for_low_vram(model.model, True)
config.model = model
time_load = time.time() - load_time
logger.info(f"Models loaded in {time_load:.2f}s")
return model, config, model_hash

@@ -0,0 +1,9 @@
import torch
import os
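# Feature flags read from run.sh / run.bat: LOWVRAM=1 selects the 4 GB code path, MEDVRAM=1 the 6 GB path, DTYPE picks the working precision.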
lowvram = os.environ.get('LOWVRAM') == "1"
medvram = os.environ.get('MEDVRAM') == "1"
dtype = torch.float32 if os.environ.get('DTYPE', 'float32') == 'float32' else torch.float16
print("using dtype: " + os.environ.get('DTYPE', 'float32'))

@@ -0,0 +1,225 @@
# from github.com/AUTOMATIC1111/stable-diffusion-webui
import torch
from torch.nn.functional import silu
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
module_in_gpu = None
cpu = torch.device("cpu")
device = gpu = torch.device("cuda")
def send_everything_to_cpu():
global module_in_gpu
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module_in_gpu = None
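# Keep at most one large submodule resident on the GPU: forward_pre_hooks move the module that is about to run onto the GPU and push the previously resident one back to the CPU.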
def setup_for_low_vram(sd_model, use_medvram):
parents = {}
def send_me_to_gpu(module, _):
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
be in CPU
"""
global module_in_gpu
module = parents.get(module, module)
if module_in_gpu == module:
return
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module.to(gpu)
module_in_gpu = module
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
def first_stage_model_encode_wrap(self, encoder, x):
send_me_to_gpu(self, None)
return encoder(x)
def first_stage_model_decode_wrap(self, decoder, z):
send_me_to_gpu(self, None)
return decoder(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
sd_model.to(device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff_model = sd_model.model.diffusion_model
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
ldm.modules.diffusionmodules.model.nonlinearity = silu
try:
import xformers
except ImportError:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
import math
import torch
from torch import einsum
from ldm.util import default
from einops import rearrange
# taken from https://github.com/Doggettx/stable-diffusion
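# Memory-aware attention: estimate free CUDA memory and split the softmax over the query sequence into slices small enough to fit.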
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3

@@ -0,0 +1,824 @@
import os
import bisect
import json
from re import S
import torch
import torch.nn as nn
from pathlib import Path
from omegaconf import OmegaConf
from dotmap import DotMap
import numpy as np
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.attention import CrossAttention, HyperLogic
from PIL import Image
import k_diffusion as K
import contextlib
import random
import base64
from .lowvram import setup_for_low_vram
from . import lautocast
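# Bridges per-step sampler callbacks into a single (current_step, total_steps) progress callback, used for the web progress reporting.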
class CallbackDelegate:
total_steps = 0
current_step = -1
callback = None
def __init__(self, callback, total_steps) -> None:
self.callback = callback
self.total_steps = total_steps
def update(self, n = None):
self.current_step += 1
if self.callback:
self.callback(self.current_step, self.total_steps)
return n
def seed_everything(seed: int):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def pil_upscale(image, scale=1):
device = image.device
dtype = image.dtype
image = Image.fromarray((image.cpu().permute(1,2,0).numpy().astype(np.float32) * 255.).astype(np.uint8))
if scale > 1:
image = image.resize((int(image.width * scale), int(image.height * scale)), resample=Image.LANCZOS)
image = np.array(image)
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
image = 2.*image - 1.
image = repeat(image, '1 ... -> b ...', b=1)
return image.to(device)
def fix_batch(tensor, bs):
return torch.stack([tensor.squeeze(0)]*bs, dim=0)
def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# make uc and prompt shapes match via padding for long prompts
# finetune
null_cond = None
def fix_cond_shapes(model, prompt_condition, uc):
global null_cond
if null_cond is None:
null_cond = model.get_learned_conditioning([""])
while prompt_condition.shape[1] > uc.shape[1]:
uc = torch.cat((uc, null_cond.repeat((uc.shape[0], 1, 1))), axis=1)
while prompt_condition.shape[1] < uc.shape[1]:
prompt_condition = torch.cat((prompt_condition, null_cond.repeat((prompt_condition.shape[0], 1, 1))), axis=1)
return prompt_condition, uc
# mix conditioning vectors for prompts
# @aero
def prompt_mixing(model, prompt_body, batch_size):
if "|" in prompt_body:
prompt_parts = prompt_body.split("|")
prompt_total_power = 0
prompt_sum = None
for prompt_part in prompt_parts:
prompt_power = 1
if ":" in prompt_part:
prompt_sub_parts = prompt_part.split(":")
try:
prompt_power = float(prompt_sub_parts[1])
prompt_part = prompt_sub_parts[0]
except:
print("Error parsing prompt power! Assuming 1")
prompt_vector = model.get_learned_conditioning([prompt_part])
if prompt_sum is None:
prompt_sum = prompt_vector * prompt_power
else:
prompt_sum, prompt_vector = fix_cond_shapes(model, prompt_sum, prompt_vector)
prompt_sum = prompt_sum + (prompt_vector * prompt_power)
prompt_total_power = prompt_total_power + prompt_power
return fix_batch(prompt_sum / prompt_total_power, batch_size)
else:
return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)
def sample_start_noise(seed, C, H, W, f, device="cuda"):
if seed:
gen = torch.Generator(device=device)
gen.manual_seed(seed)
noise = torch.randn([C, (H) // f, (W) // f], generator=gen, device=device).unsqueeze(0)
else:
noise = torch.randn([C, (H) // f, (W) // f], device=device).unsqueeze(0)
return noise
def sample_start_noise_special(seed, request, device="cuda"):
if seed:
gen = torch.Generator(device=device)
gen.manual_seed(seed)
noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], generator=gen, device=device).unsqueeze(0)
else:
noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], device=device).unsqueeze(0)
return noise
@torch.no_grad()
def encode_image(image, model):
if isinstance(image, Image.Image):
image = np.array(image)
image = torch.from_numpy(image)
if isinstance(image, np.ndarray):
image = torch.from_numpy(image)
#dtype = image.dtype
image = image.to(torch.float32)
#gets image as numpy array and returns as tensor
def preprocess_vqgan(x):
x = x / 255.0
x = 2.*x - 1.
return x
image = image.permute(2, 0, 1).unsqueeze(0).float().cuda()
image = preprocess_vqgan(image)
image = model.encode(image).sample()
#image = image.to(dtype)
return image
@torch.no_grad()
def decode_image(image, model):
def custom_to_pil(x):
x = x.detach().float().cpu()
x = torch.clamp(x, -1., 1.)
x = (x + 1.)/2.
x = x.permute(0, 2, 3, 1)#.numpy()
#x = (255*x).astype(np.uint8)
#x = Image.fromarray(x)
#if not x.mode == "RGB":
# x = x.convert("RGB")
return x
image = model.decode(image)
image = custom_to_pil(image)
return image
class VectorAdjustPrior(nn.Module):
def __init__(self, hidden_size, inter_dim=64):
super().__init__()
self.vector_proj = nn.Linear(hidden_size*2, inter_dim, bias=True)
self.out_proj = nn.Linear(hidden_size+inter_dim, hidden_size, bias=True)
def forward(self, z):
b, s = z.shape[0:2]
x1 = torch.mean(z, dim=1).repeat(s, 1)
x2 = z.reshape(b*s, -1)
x = torch.cat((x1, x2), dim=1)
x = self.vector_proj(x)
x = torch.cat((x2, x), dim=1)
x = self.out_proj(x)
x = x.reshape(b, s, -1)
return x
@classmethod
def load_model(cls, model_path, hidden_size=768, inter_dim=64):
model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
model.load_state_dict(torch.load(model_path)["state_dict"])
return model
class StableInterface(nn.Module):
def __init__(self, model, thresholder = None):
super().__init__()
self.inner_model = model
self.sigma_to_t = model.sigma_to_t
self.thresholder = thresholder
self.get_sigmas = model.get_sigmas
@torch.no_grad()
def forward(self, x, sigma, uncond, cond, cond_scale):
x_two = torch.cat([x] * 2)
sigma_two = torch.cat([sigma] * 2)
cond_full = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_two, sigma_two, cond=cond_full).chunk(2)
x_0 = uncond + (cond - uncond) * cond_scale
if self.thresholder is not None:
x_0 = self.thresholder(x_0)
return x_0
class StableDiffusionModel(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.config = config
self.premodules = None
if Path(self.config.model_path).is_dir():
config.logger.info(f"Loading model from folder {self.config.model_path}")
model, model_config = self.from_folder(config.model_path)
elif Path(self.config.model_path).is_file():
config.logger.info(f"Loading model from file {self.config.model_path}")
model, model_config = self.from_file(config.model_path)
else:
raise Exception("Invalid model path!")
if config.dtype == "float16":
typex = torch.float16
else:
typex = torch.float32
print("----------------------------------")
if lautocast.medvram == True:
print("passed lowvram setup for medvram (6G)")
setup_for_low_vram(model, True)
elif lautocast.lowvram == True:
print("setup for lowvram (4G and lower)")
setup_for_low_vram(model, False)
else:
print("loading model")
model.to(config.device)
print("----------------------------------")
self.model = model.to(typex)
# self.model = model.to(config.device).to(typex)
if self.config.vae_path:
ckpt=torch.load(self.config.vae_path, map_location="cpu")
loss = []
for i in ckpt["state_dict"].keys():
if i[0:4] == "loss":
loss.append(i)
for i in loss:
del ckpt["state_dict"][i]
model.first_stage_model = model.first_stage_model.float()
model.first_stage_model.load_state_dict(ckpt["state_dict"])
model.first_stage_model = model.first_stage_model.float()
del ckpt
del loss
config.logger.info(f"Using VAE from {self.config.vae_path}")
if self.config.penultimate == "1":
model.cond_stage_model.return_layer = -2
model.cond_stage_model.do_final_ln = True
config.logger.info(f"CLIP: Using penultimate layer")
if self.config.clip_contexts > 1:
model.cond_stage_model.clip_extend = True
model.cond_stage_model.max_clip_extend = 75 * self.config.clip_contexts
model.cond_stage_model.inference_mode = True
self.k_model = K.external.CompVisDenoiser(model)
self.k_model = StableInterface(self.k_model)
self.device = config.device
self.model_config = model_config
self.plms = PLMSSampler(model)
self.ddim = DDIMSampler(model)
self.ema_manager = self.model.ema_scope
if self.config.enable_ema == "0":
self.ema_manager = contextlib.nullcontext
config.logger.info("Disabling EMA")
else:
config.logger.info(f"Using EMA")
self.sampler_map = {
'plms': self.plms.sample,
'ddim': self.ddim.sample,
'k_euler': K.sampling.sample_euler,
'k_euler_ancestral': K.sampling.sample_euler_ancestral,
'k_heun': K.sampling.sample_heun,
'k_dpm_2': K.sampling.sample_dpm_2,
'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral,
'k_lms': K.sampling.sample_lms,
}
if config.prior_path:
self.prior = VectorAdjustPrior.load_model(config.prior_path).to(self.device)
self.copied_ema = False
@property
def get_default_config(self):
dict_config = {
'steps': 30,
'sampler': "k_euler_ancestral",
'n_samples': 1,
'image': None,
'fixed_code': False,
'ddim_eta': 0.0,
'height': 512,
'width': 512,
'latent_channels': 4,
'downsampling_factor': 8,
'scale': 12.0,
'dynamic_threshold': None,
'seed': None,
'stage_two_seed': None,
'module': None,
'masks': None,
'output': None,
}
return DotMap(dict_config)
def from_folder(self, folder):
folder = Path(folder)
model_config = OmegaConf.load(folder / "config.yaml")
if (folder / "pruned.ckpt").is_file():
model_path = folder / "pruned.ckpt"
else:
model_path = folder / "model.ckpt"
model = self.load_model_from_config(model_config, model_path)
return model, model_config
def from_path(self, file):
default_config = Path(self.config.default_config)
if not default_config.is_file():
raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG")
model_config = OmegaConf.load(default_config)
model = self.load_model_from_config(model_config, file)
return model, model_config
def load_model_from_config(self, config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd.get('state_dict', pl_sd)
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.eval()
return model
@torch.no_grad()
@torch.autocast("cuda", enabled=False, dtype=lautocast.dtype)
def sample(self, request, callback = None):
if request.module is not None:
if request.module == "vanilla":
pass
else:
module = self.premodules[request.module]
CrossAttention.set_hypernetwork(module)
if request.seed is not None:
seed_everything(request.seed)
if request.image is not None:
request.steps = 50
#request.sampler = "ddim_img2img" #enforce ddim for now
if request.sampler == "plms":
request.sampler = "k_lms"
if request.sampler == "ddim":
request.sampler = "k_lms"
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
start_code = self.model.get_first_stage_encoding(start_code)
start_code = torch.repeat_interleave(start_code, request.n_samples, dim=0)
main_noise = []
start_noise = []
for seed in range(request.seed, request.seed+request.n_samples):
main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
start_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
main_noise = torch.cat(main_noise, dim=0)
start_noise = torch.cat(start_noise, dim=0)
start_code = start_code + (start_noise * request.noise)
t_enc = int(request.strength * request.steps)
if request.sampler.startswith("k_"):
sampler = "k-diffusion"
elif request.sampler == 'ddim_img2img':
sampler = 'img2img'
else:
sampler = "normal"
if request.image is None:
main_noise = []
for seed_offset in range(request.n_samples):
if request.masks is not None:
noise_x = sample_start_noise_special(request.seed, request, self.device)
else:
noise_x = sample_start_noise_special(request.seed+seed_offset, request, self.device)
if request.masks is not None:
for maskobj in request.masks:
mask_seed = maskobj["seed"]
mask = maskobj["mask"]
mask = np.asarray(mask)
mask = torch.from_numpy(mask).clone().to(self.device).permute(2, 0, 1)
mask = mask.float() / 255.0
# convert RGB or grayscale image into 4-channel
mask = mask[0].unsqueeze(0)
mask = torch.repeat_interleave(mask, request.latent_channels, dim=0)
mask = (mask < 0.5).float()
# interpolate start noise
noise_x = (noise_x * (1-mask)) + (sample_start_noise_special(mask_seed+seed_offset, request, self.device) * mask)
main_noise.append(noise_x)
main_noise = torch.cat(main_noise, dim=0)
start_code = main_noise
prompt = [request.prompt]
prompt_condition = prompt_mixing(self.model, prompt[0], 1)
if hasattr(self, "prior") and request.mitigate:
prompt_condition = self.prior(prompt_condition)
uc = None
if request.scale != 1.0:
if request.uc is not None:
uc = [request.uc]
uc = prompt_mixing(self.model, uc[0], 1)
else:
if self.config.quality_hack == "1":
uc = ["Tags: lowres"]
uc = prompt_mixing(self.model, uc[0], 1)
else:
uc = self.model.get_learned_conditioning([""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
shape = [
request.latent_channels,
request.height // request.downsampling_factor,
request.width // request.downsampling_factor
]
c_dele = CallbackDelegate(callback, request.steps * request.n_samples)
# handle images one at a time because batches eat absurd VRAM
sampless = []
for main_noise, start_code in zip(main_noise.chunk(request.n_samples), start_code.chunk(request.n_samples)):
if sampler == "normal":
with self.ema_manager():
c_dele.update()
samples, _ = self.sampler_map[request.sampler](
S=request.steps,
conditioning=prompt_condition,
batch_size=1,
shape=shape,
verbose=False,
unconditional_guidance_scale=request.scale,
unconditional_conditioning=uc,
eta=request.ddim_eta,
dynamic_threshold=request.dynamic_threshold,
x_T=start_code,
callback=c_dele.update
)
elif sampler == "k-diffusion":
with self.ema_manager():
sigmas = self.k_model.get_sigmas(request.steps)
if request.image is not None:
noise = main_noise * sigmas[request.steps - t_enc - 1]
start_code = start_code + noise
sigmas = sigmas[request.steps - t_enc - 1:]
else:
start_code = start_code * sigmas[0]
extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
c_dele.update()
samples = self.sampler_map[request.sampler](
self.k_model,
start_code,
sigmas,
request.seed,
callback=c_dele.update,
extra_args=extra_args
)
sampless.append(samples)
torch_gc()
images = []
for samples in sampless:
with torch.autocast("cuda", enabled=self.config.amp):
x_samples_ddim = self.model.decode_first_stage(samples.float())
#x_samples_ddim = decode_image(samples, self.model.first_stage_model)
#x_samples_ddim = self.model.first_stage_model.decode(samples.float())
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
x_sample = np.ascontiguousarray(x_sample)
images.append(x_sample)
torch_gc()
if request.seed is not None:
torch.seed()
np.random.seed()
#set hypernetwork to none after generation
CrossAttention.set_hypernetwork(None)
return images
@torch.no_grad()
def sample_two_stages(self, request, callback = None):
request = DotMap(request)
if request.seed is not None:
seed_everything(request.seed)
if request.plms:
sampler = self.plms
else:
sampler = self.ddim
start_code = None
if request.fixed_code:
start_code = torch.randn([
request.n_samples,
request.latent_channels,
request.height // request.downsampling_factor,
request.width // request.downsampling_factor,
], device=self.device)
prompt = [request.prompt] * request.n_samples
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
uc = None
if request.scale != 1.0:
uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
shape = [
request.latent_channels,
request.height // request.downsampling_factor,
request.width // request.downsampling_factor
]
c_dele = CallbackDelegate(callback, request.steps * request.n_samples)
with torch.autocast("cuda", enabled=self.config.amp):
with self.ema_manager():
samples, _ = sampler.sample(
S=request.steps,
conditioning=prompt_condition,
batch_size=request.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=request.scale,
unconditional_conditioning=uc,
eta=request.ddim_eta,
dynamic_threshold=request.dynamic_threshold,
x_T=start_code,
callback=c_dele.update
)
x_samples_ddim = self.model.decode_first_stage(samples)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).squeeze(0)
x_samples_ddim = pil_upscale(x_samples_ddim, scale=2)
if request.stage_two_seed is not None:
torch.manual_seed(request.stage_two_seed)
np.random.seed(request.stage_two_seed)
with torch.autocast("cuda", enabled=self.config.amp):
with self.ema_manager():
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
t_enc = int(request.strength * request.steps)
print("init latent shape:")
print(init_latent.shape)
init_latent = init_latent + (torch.randn_like(init_latent) * request.noise)
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
uc = None
if request.scale != 1.0:
uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
# encode (scaled latent)
start_code_terped=None
z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
# decode it
samples = self.ddim.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale,
unconditional_conditioning=uc,)
x_samples_ddim = self.model.decode_first_stage(samples)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
images = []
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
x_sample = np.ascontiguousarray(x_sample)
images.append(x_sample)
if request.seed is not None:
torch.seed()
np.random.seed()
return images
@torch.no_grad()
def sample_from_image(self, request):
return
class DalleMiniModel(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
from min_dalle import MinDalle
self.config = config
self.model = MinDalle(
models_root=config.model_path,
dtype=torch.float16,
device='cuda',
is_mega=True,
is_reusable=True
)
@torch.no_grad()
def sample(self, request, callback = None):
if request.seed is not None:
seed = request.seed
else:
seed = -1
images = self.model.generate_images(
text=request.prompt,
seed=seed,
grid_size=request.grid_size,
is_seamless=False,
temperature=request.temp,
top_k=request.top_k,
supercondition_factor=request.scale,
is_verbose=False
)
images = images.to('cpu').numpy()
images = images.astype(np.uint8)
images = np.ascontiguousarray(images)
if request.seed is not None:
torch.seed()
np.random.seed()
return images
def apply_temp(logits, temperature):
logits = logits / temperature
return logits
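# Autoregressive sampling loop for BasedformerModel: applies the logit ops in ops_list per sample, then inspects sentinel token ids to return a safe/notsafe verdict together with the decoded output.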
@torch.no_grad()
def generate(forward, prompt_tokens, tokenizer, tokens_to_generate=50, ds=False, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
in_tokens = prompt_tokens
context = prompt_tokens
generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
kv = None
if non_deterministic:
torch.seed()
#soft_required = ["top_k", "top_p"]
op_map = {
"temp": apply_temp,
}
for _ in range(tokens_to_generate):
if ds:
logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
else:
logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
logits = logits[:, -1, :] #get the last token in the seq
logits = torch.log_softmax(logits, dim=-1)
batch = []
for i, ops in enumerate(ops_list):
item = logits[i, ...].unsqueeze(0)
ctx = context[i, ...].unsqueeze(0)
for op, value in ops.items():
if op == "rep_pen":
item = op_map[op](ctx, item, **value)
else:
item = op_map[op](item, value)
batch.append(item)
logits = torch.cat(batch, dim=0)
logits = torch.softmax(logits, dim=-1)
#fully_deterministic makes it deterministic across the batch
if fully_deterministic:
logits = logits.split(1, dim=0)
logit_list = []
for logit in logits:
torch.manual_seed(69)
logit_list.append(torch.multinomial(logit, 1))
logits = torch.cat(logit_list, dim=0)
else:
logits = torch.multinomial(logits, 1)
if logits[0, 0] == 48585:
if generated[0, -1] == 1400:
pass
elif generated[0, -1] == 3363:
return "safe", "none"
else:
return "notsafe", tokenizer.decode(generated.squeeze()).split("Output: ")[-1]
generated = torch.cat([generated, logits], dim=-1)
context = torch.cat([context, logits], dim=-1)
in_tokens = logits
return "unknown", tokenizer.decode(generated.squeeze())
class BasedformerModel(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
from basedformer import lm_utils
from transformers import GPT2TokenizerFast
self.config = config
self.model = lm_utils.load_from_path(config.model_path).half().cuda()
self.model = self.model.convert_to_ds()
self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
@torch.no_grad()
def sample(self, request, callback = None):
prompt = request.prompt
prompt = self.tokenizer.encode("Input: " + prompt, return_tensors='pt').cuda().long()
prompt = torch.cat([prompt, torch.tensor([[49527]], dtype=torch.long).cuda()], dim=1)
is_safe, corrected = generate(self.model.module, prompt, self.tokenizer, tokens_to_generate=150, ds=True)
return is_safe, corrected
class EmbedderModel(nn.Module):
def __init__(self, config=None):
nn.Module.__init__(self)
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer('./models/sentence-transformers_all-MiniLM-L6-v2').cuda()
self.tags = [tuple(x) for x in json.load(open("models/tags.json"))]
self.knn = self.load_knn("models/tags.index")
print("Loaded tag suggestion model using phrase embeddings")
def load_knn(self, filename):
import faiss
try:
return faiss.read_index(filename)
except RuntimeError:
print(f"Generating tag embedding index for {len(self.tags)} tags.")
i = faiss.IndexFlatL2(384)
i.add(self([name for name, count in self.tags]))
faiss.write_index(i, filename)
return i
def __call__(self, sentences):
with torch.no_grad():
sentence_embeddings = self.model.encode(sentences)
return sentence_embeddings
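# Tag suggestion: combine prefix matches from the sorted tag list with nearest-neighbour matches from the FAISS embedding index, then rank and trim the results.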
def get_top_k(self, text):
#check if text is a substring in tag_count.keys()
found = []
a = bisect.bisect_left(self.tags, (text,))
b = bisect.bisect_left(self.tags, (text + '\xff',), lo=a)
for tag, count in self.tags[a:b]:
if len(tag) >= len(text) and tag.startswith(text):
found.append([tag, count, 0])
results = []
embedding = self([text])
k = 15
D, I = self.knn.search(embedding, k)
D, I = D.squeeze(), I.squeeze()
for id, prob in zip(I, D):
tag, count = self.tags[id]
results.append([tag, count, prob])
found.sort(key=lambda x: x[1], reverse=True)
found = found[:5]
# found = heapq.nlargest(5, found, key=lambda x: x[1])
results_tags = [x[0] for x in found]
for result in results.copy():
if result[0] in results_tags:
results.remove(result)
results = sorted(results, key=lambda x: x[2], reverse=True)
#filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4
results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])]
if found:
results = found + results
#max 10 results
results = results[:10]
results = sorted(results, key=lambda x: x[1], reverse=True)
return results

@@ -0,0 +1,221 @@
import traceback
from dotmap import DotMap
import math
from io import BytesIO
import base64
import random
v1pp_defaults = {
'steps': 50,
'sampler': "plms",
'image': None,
'fixed_code': False,
'ddim_eta': 0.0,
'height': 512,
'width': 512,
'latent_channels': 4,
'downsampling_factor': 8,
'scale': 7.0,
'dynamic_threshold': None,
'seed': None,
'stage_two_seed': None,
'module': None,
'masks': None,
}
v1pp_forced_defaults = {
'latent_channels': 4,
'downsampling_factor': 8,
}
dalle_mini_defaults = {
'temp': 1.0,
'top_k': 256,
'scale': 16,
'grid_size': 4,
}
dalle_mini_forced_defaults = {
}
defaults = {
'stable-diffusion': (v1pp_defaults, v1pp_forced_defaults),
'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
'basedformer': ({}, {}),
'embedder': ({}, {}),
}
samplers = [
"plms",
"ddim",
"k_euler",
"k_euler_ancestral",
"k_heun",
"k_dpm_2",
"k_dpm_2_ancestral",
"k_lms"
]
def closest_multiple(num, mult):
num_int = int(num)
floor = math.floor(num_int / mult) * mult
ceil = math.ceil(num_int / mult) * mult
return floor if (num_int - floor) < (ceil - num_int) else ceil
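# Validate a stable-diffusion request: bounds-check the numeric fields, pick a random seed if none was given, and decode/resize any base64 image or mask inputs; returns (True, request) or (False, error message).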
def sanitize_stable_diffusion(request, config):
if request.steps > 50:
return False, "steps must be smaller than 50"
if request.width * request.height == 0:
return False, "width and height must be non-zero"
if request.width <= 0:
return False, "width must be positive"
if request.height <= 0:
return False, "height must be positive"
if request.steps <= 0:
return False, "steps must be positive"
if request.ddim_eta < 0:
return False, "ddim_eta shouldn't be negative"
if request.scale < 1.0:
return False, "scale should be at least 1.0"
if request.dynamic_threshold is not None and request.dynamic_threshold < 0:
return False, "dynamic_threshold shouldn't be negative"
if request.width * request.height >= 1024*1025:
return False, "width and height must be less than 1024*1025"
if request.strength < 0.0 or request.strength >= 1.0:
return False, "strength should be more than 0.0 and less than 1.0"
if request.noise < 0.0 or request.noise > 1.0:
return False, "noise should be more than 0.0 and less than 1.0"
if request.advanced:
request.width = closest_multiple(request.width // 2, 64)
request.height = closest_multiple(request.height // 2, 64)
if request.sampler not in samplers:
return False, "sampler should be one of {}".format(samplers)
if request.seed is None:
state = random.getstate()
request.seed = random.randint(0, 2**32)
random.setstate(state)
if request.module is not None:
if request.module not in config.model.premodules and request.module != "vanilla":
return False, "module should be one of: " + ", ".join(config.model.premodules)
max_gens = 100
if 0:
num_gen_tiers = [(1024*512, 4), (640*640, 6), (704*512, 8), (512*512, 16), (384*640, 18)]
pixel_count = request.width * request.height
for tier in num_gen_tiers:
if pixel_count <= tier[0]:
max_gens = tier[1]
else:
break
if request.n_samples > max_gens:
return False, f"requested more ({request.n_samples}) images than possible at this resolution"
if request.image is not None:
#decode from base64
try:
request.image = base64.b64decode(request.image.encode('utf-8'))
except Exception as e:
traceback.print_exc()
return False, "image is not valid base64"
#check if image is valid
try:
from PIL import Image
image = Image.open(BytesIO(request.image))
image.verify()
except Exception as e:
traceback.print_exc()
return False, "image is not valid"
#image header is valid; reopen and fully decode it (verify() does not decode pixel data, so decoding can still fail)
try:
image = Image.open(BytesIO(request.image))
image = image.convert('RGB')
image = image.resize((request.width, request.height), resample=Image.Resampling.LANCZOS)
request.image = image
except Exception as e:
traceback.print_exc()
return False, "Error while opening and cleaning image"
if request.masks is not None:
masks = request.masks
for x in range(len(masks)):
image = masks[x]["mask"]
try:
image_bytes = base64.b64decode(image.encode('utf-8'))
except Exception as e:
traceback.print_exc()
return False, "image is not valid base64"
try:
from PIL import Image
image = Image.open(BytesIO(image_bytes))
image.verify()
except Exception as e:
traceback.print_exc()
return False, "image is not valid"
#image header is valid; reopen and fully decode it (verify() does not decode pixel data, so decoding can still fail)
try:
image = Image.open(BytesIO(image_bytes))
#image = image.convert('RGB')
image = image.resize((request.width//request.downsampling_factor, request.height//request.downsampling_factor), resample=Image.Resampling.LANCZOS)
except Exception as e:
traceback.print_exc()
return False, "Error while opening and cleaning image"
masks[x]["mask"] = image
return True, request
def sanitize_dalle_mini(request):
return True, request
def sanitize_basedformer(request):
return True, request
def sanitize_embedder(request):
return True, request
def sanitize_input(config, request):
"""
Sanitize the input data and set defaults
"""
request = DotMap(request)
default, forced_default = defaults[config.model_name]
for k, v in default.items():
if k not in request:
request[k] = v
for k, v in forced_default.items():
request[k] = v
if config.model_name == 'stable-diffusion':
return sanitize_stable_diffusion(request, config)
elif config.model_name == 'dalle-mini':
return sanitize_dalle_mini(request)
elif config.model_name == 'basedformer':
return sanitize_basedformer(request)
elif config.model_name == "embedder":
return sanitize_embedder(request)

@@ -0,0 +1,193 @@
import os
import torch
import logging
import os
import platform
import socket
import sys
import time
from dotmap import DotMap
from hydra_node.models import StableDiffusionModel, DalleMiniModel, BasedformerModel, EmbedderModel
from hydra_node import lowvram
import traceback
import zlib
from pathlib import Path
from ldm.modules.attention import CrossAttention, HyperLogic
model_map = {
"stable-diffusion": StableDiffusionModel,
"dalle-mini": DalleMiniModel,
"basedformer": BasedformerModel,
"embedder": EmbedderModel,
}
def no_init(loading_code):
def dummy(self):
return
modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
original = {}
for mod in modules:
original[mod] = mod.reset_parameters
mod.reset_parameters = dummy
result = loading_code()
for mod in modules:
mod.reset_parameters = original[mod]
return result
def crc32(filename, chunksize=65536):
"""Compute the CRC-32 checksum of the contents of the given filename"""
with open(filename, "rb") as f:
checksum = 0
while True:
chunk = f.read(chunksize)
if not chunk:
break
checksum = zlib.crc32(chunk, checksum)
#while (chunk := f.read(chunksize)) :
# checksum = zlib.crc32(chunk, checksum)
return '%08X' % (checksum & 0xFFFFFFFF)
def load_modules(path):
path = Path(path)
modules = {}
if not path.is_dir():
return
for file in path.iterdir():
module = load_module(file, "cpu")
modules[file.stem] = module
print(f"Loaded module {file.stem}")
return modules
def load_module(path, device):
path = Path(path)
if not path.is_file():
print("Module path {} is not a file".format(path))
network = {
768: (HyperLogic(768).to(device), HyperLogic(768).to(device)),
1280: (HyperLogic(1280).to(device), HyperLogic(1280).to(device)),
640: (HyperLogic(640).to(device), HyperLogic(640).to(device)),
320: (HyperLogic(320).to(device), HyperLogic(320).to(device)),
}
state_dict = torch.load(path)
for key in state_dict.keys():
network[key][0].load_state_dict(state_dict[key][0])
network[key][1].load_state_dict(state_dict[key][1])
return network
def init_config_model():
config = DotMap()
config.savefiles = os.getenv("SAVE_FILES", False)
config.dtype = os.getenv("DTYPE", "float16")
config.device = os.getenv("DEVICE", "cuda")
config.amp = os.getenv("AMP", False)
if config.amp == "1":
config.amp = True
elif config.amp == "0":
config.amp = False
is_dev = ""
environment = "production"
if os.environ['DEV'] == "True":
is_dev = "_dev"
environment = "staging"
config.is_dev = is_dev
# Setup logger
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
fh = logging.StreamHandler()
fh_formatter = logging.Formatter(
"%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
)
fh.setFormatter(fh_formatter)
logger.addHandler(fh)
config.logger = logger
# Gather node information
config.cuda_dev = torch.cuda.current_device()
cpu_id = platform.processor()
if os.path.exists('/proc/cpuinfo'):
cpu_id = [line for line in open("/proc/cpuinfo", 'r').readlines() if
'model name' in line][0].rstrip().split(': ')[-1]
config.cpu_id = cpu_id
config.gpu_id = torch.cuda.get_device_name(config.cuda_dev)
config.node_id = platform.node()
# Report on our CUDA memory and model.
gb_gpu = int(torch.cuda.get_device_properties(
config.cuda_dev).total_memory / (1000 * 1000 * 1000))
logger.info(f"CPU: {config.cpu_id}")
logger.info(f"GPU: {config.gpu_id}")
logger.info(f"GPU RAM: {gb_gpu}gb")
config.model_name = os.environ['MODEL']
logger.info(f"MODEL: {config.model_name}")
# Resolve where we get our model and data from.
config.model_path = os.getenv('MODEL_PATH', None)
config.enable_ema = os.getenv('ENABLE_EMA', "1")
config.basedformer = os.getenv('BASEDFORMER', "0")
config.penultimate = os.getenv('PENULTIMATE', "0")
config.vae_path = os.getenv('VAE_PATH', None)
config.module_path = os.getenv('MODULE_PATH', None)
config.prior_path = os.getenv('PRIOR_PATH', None)
config.default_config = os.getenv('DEFAULT_CONFIG', None)
config.quality_hack = os.getenv('QUALITY_HACK', "0")
config.clip_contexts = os.getenv('CLIP_CONTEXTS', "1")
try:
config.clip_contexts = int(config.clip_contexts)
if config.clip_contexts < 1 or config.clip_contexts > 10:
config.clip_contexts = 1
except:
config.clip_contexts = 1
# Misc settings
config.model_alias = os.getenv('MODEL_ALIAS')
# Instantiate our actual model.
load_time = time.time()
model_hash = None
try:
if config.model_name != "dalle-mini":
model = no_init(lambda: model_map[config.model_name](config))
else:
model = model_map[config.model_name](config)
except Exception as e:
traceback.print_exc()
logger.error(f"Failed to load model: {str(e)}")
#exit gunicorn
sys.exit(4)
if config.model_name == "stable-diffusion":
folder = Path(config.model_path)
if (folder / "pruned.ckpt").is_file():
model_path = folder / "pruned.ckpt"
else:
model_path = folder / "model.ckpt"
model_hash = crc32(model_path)
#Load Modules
if config.module_path is not None:
modules = load_modules(config.module_path)
#attach it to the model
model.premodules = modules
lowvram.setup_for_low_vram(model.model, True)
config.model = model
time_load = time.time() - load_time
logger.info(f"Models loaded in {time_load:.2f}s")
return model, config, model_hash

@@ -0,0 +1,225 @@
# from github.com/AUTOMATIC1111/stable-diffusion-webui
import torch
from torch.nn.functional import silu
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
module_in_gpu = None
cpu = torch.device("cpu")
device = gpu = torch.device("cuda")
def send_everything_to_cpu():
global module_in_gpu
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module_in_gpu = None
def setup_for_low_vram(sd_model, use_medvram):
parents = {}
def send_me_to_gpu(module, _):
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
be in CPU
"""
global module_in_gpu
module = parents.get(module, module)
if module_in_gpu == module:
return
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module.to(gpu)
module_in_gpu = module
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
def first_stage_model_encode_wrap(self, encoder, x):
send_me_to_gpu(self, None)
return encoder(x)
def first_stage_model_decode_wrap(self, decoder, z):
send_me_to_gpu(self, None)
return decoder(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
sd_model.to(device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff_model = sd_model.model.diffusion_model
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
ldm.modules.diffusionmodules.model.nonlinearity = silu
try:
import xformers
except ImportError:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
import math
import torch
from torch import einsum
from ldm.util import default
from einops import rearrange
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3

@@ -0,0 +1,806 @@
import os
import bisect
import json
from re import S
import torch
import torch.nn as nn
from pathlib import Path
from omegaconf import OmegaConf
from dotmap import DotMap
import numpy as np
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.attention import CrossAttention, HyperLogic
from PIL import Image
import k_diffusion as K
import contextlib
import random
class CallbackDelegate:
total_steps = 0
current_step = -1
callback = None
def __init__(self, callback, total_steps) -> None:
self.callback = callback
self.total_steps = total_steps
def update(self, n = None):
self.current_step += 1
if self.callback:
self.callback(self.current_step, self.total_steps)
return n
def seed_everything(seed: int):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def pil_upscale(image, scale=1):
device = image.device
dtype = image.dtype
image = Image.fromarray((image.cpu().permute(1,2,0).numpy().astype(np.float32) * 255.).astype(np.uint8))
if scale > 1:
image = image.resize((int(image.width * scale), int(image.height * scale)), resample=Image.LANCZOS)
image = np.array(image)
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
image = 2.*image - 1.
image = repeat(image, '1 ... -> b ...', b=1)
return image.to(device)
def fix_batch(tensor, bs):
return torch.stack([tensor.squeeze(0)]*bs, dim=0)
def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# make uc and prompt shapes match via padding for long prompts
# finetune
null_cond = None
def fix_cond_shapes(model, prompt_condition, uc):
global null_cond
if null_cond is None:
null_cond = model.get_learned_conditioning([""])
while prompt_condition.shape[1] > uc.shape[1]:
uc = torch.cat((uc, null_cond.repeat((uc.shape[0], 1, 1))), axis=1)
while prompt_condition.shape[1] < uc.shape[1]:
prompt_condition = torch.cat((prompt_condition, null_cond.repeat((prompt_condition.shape[0], 1, 1))), axis=1)
return prompt_condition, uc
# mix conditioning vectors for prompts
# @aero
def prompt_mixing(model, prompt_body, batch_size):
if "|" in prompt_body:
prompt_parts = prompt_body.split("|")
prompt_total_power = 0
prompt_sum = None
for prompt_part in prompt_parts:
prompt_power = 1
if ":" in prompt_part:
prompt_sub_parts = prompt_part.split(":")
try:
prompt_power = float(prompt_sub_parts[1])
prompt_part = prompt_sub_parts[0]
except:
print("Error parsing prompt power! Assuming 1")
prompt_vector = model.get_learned_conditioning([prompt_part])
if prompt_sum is None:
prompt_sum = prompt_vector * prompt_power
else:
prompt_sum, prompt_vector = fix_cond_shapes(model, prompt_sum, prompt_vector)
prompt_sum = prompt_sum + (prompt_vector * prompt_power)
prompt_total_power = prompt_total_power + prompt_power
return fix_batch(prompt_sum / prompt_total_power, batch_size)
else:
return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)
def sample_start_noise(seed, C, H, W, f, device="cuda"):
if seed:
gen = torch.Generator(device=device)
gen.manual_seed(seed)
noise = torch.randn([C, (H) // f, (W) // f], generator=gen, device=device).unsqueeze(0)
else:
noise = torch.randn([C, (H) // f, (W) // f], device=device).unsqueeze(0)
return noise
def sample_start_noise_special(seed, request, device="cuda"):
if seed:
gen = torch.Generator(device=device)
gen.manual_seed(seed)
noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], generator=gen, device=device).unsqueeze(0)
else:
noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], device=device).unsqueeze(0)
return noise
@torch.no_grad()
#@torch.autocast("cuda", enabled=True, dtype=torch.float16)
def encode_image(image, model):
if isinstance(image, Image.Image):
image = np.array(image)
image = torch.from_numpy(image)
if isinstance(image, np.ndarray):
image = torch.from_numpy(image)
#dtype = image.dtype
image = image.to(torch.float32)
#gets image as numpy array and returns as tensor
def preprocess_vqgan(x):
x = x / 255.0
x = 2.*x - 1.
return x
image = image.permute(2, 0, 1).unsqueeze(0).float().cuda()
image = preprocess_vqgan(image)
image = model.encode(image).sample()
#image = image.to(dtype)
return image
@torch.no_grad()
def decode_image(image, model):
def custom_to_pil(x):
x = x.detach().float().cpu()
x = torch.clamp(x, -1., 1.)
x = (x + 1.)/2.
x = x.permute(0, 2, 3, 1)#.numpy()
#x = (255*x).astype(np.uint8)
#x = Image.fromarray(x)
#if not x.mode == "RGB":
# x = x.convert("RGB")
return x
image = model.decode(image)
image = custom_to_pil(image)
return image
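# Small MLP prior that adjusts the CLIP conditioning vectors; loaded from config.prior_path
# and applied during sampling when a request sets `mitigate`.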
class VectorAdjustPrior(nn.Module):
def __init__(self, hidden_size, inter_dim=64):
super().__init__()
self.vector_proj = nn.Linear(hidden_size*2, inter_dim, bias=True)
self.out_proj = nn.Linear(hidden_size+inter_dim, hidden_size, bias=True)
def forward(self, z):
b, s = z.shape[0:2]
x1 = torch.mean(z, dim=1).repeat(s, 1)
x2 = z.reshape(b*s, -1)
x = torch.cat((x1, x2), dim=1)
x = self.vector_proj(x)
x = torch.cat((x2, x), dim=1)
x = self.out_proj(x)
x = x.reshape(b, s, -1)
return x
@classmethod
def load_model(cls, model_path, hidden_size=768, inter_dim=64):
model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
model.load_state_dict(torch.load(model_path)["state_dict"])
return model
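# k-diffusion denoiser wrapper: batches the conditional and unconditional passes,
# applies classifier-free guidance with cond_scale, and optionally thresholds the output.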
class StableInterface(nn.Module):
def __init__(self, model, thresholder = None):
super().__init__()
self.inner_model = model
self.sigma_to_t = model.sigma_to_t
self.thresholder = thresholder
self.get_sigmas = model.get_sigmas
@torch.no_grad()
def forward(self, x, sigma, uncond, cond, cond_scale):
x_two = torch.cat([x] * 2)
sigma_two = torch.cat([sigma] * 2)
cond_full = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_two, sigma_two, cond=cond_full).chunk(2)
x_0 = uncond + (cond - uncond) * cond_scale
if self.thresholder is not None:
x_0 = self.thresholder(x_0)
return x_0
class StableDiffusionModel(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.config = config
self.premodules = None
if Path(self.config.model_path).is_dir():
config.logger.info(f"Loading model from folder {self.config.model_path}")
model, model_config = self.from_folder(config.model_path)
elif Path(self.config.model_path).is_file():
config.logger.info(f"Loading model from file {self.config.model_path}")
model, model_config = self.from_file(config.model_path)
else:
raise Exception("Invalid model path!")
if config.dtype == "float16":
typex = torch.float16
else:
typex = torch.float32
self.model = model.to(config.device).to(typex)
if self.config.vae_path:
ckpt=torch.load(self.config.vae_path, map_location="cpu")
loss = []
for i in ckpt["state_dict"].keys():
if i[0:4] == "loss":
loss.append(i)
for i in loss:
del ckpt["state_dict"][i]
model.first_stage_model = model.first_stage_model.float()
model.first_stage_model.load_state_dict(ckpt["state_dict"])
model.first_stage_model = model.first_stage_model.float()
del ckpt
del loss
config.logger.info(f"Using VAE from {self.config.vae_path}")
if self.config.penultimate == "1":
model.cond_stage_model.return_layer = -2
model.cond_stage_model.do_final_ln = True
config.logger.info(f"CLIP: Using penultimate layer")
if self.config.clip_contexts > 1:
model.cond_stage_model.clip_extend = True
model.cond_stage_model.max_clip_extend = 75 * self.config.clip_contexts
model.cond_stage_model.inference_mode = True
self.k_model = K.external.CompVisDenoiser(model)
self.k_model = StableInterface(self.k_model)
self.device = config.device
self.model_config = model_config
self.plms = PLMSSampler(model)
self.ddim = DDIMSampler(model)
self.ema_manager = self.model.ema_scope
if self.config.enable_ema == "0":
self.ema_manager = contextlib.nullcontext
config.logger.info("Disabling EMA")
else:
config.logger.info(f"Using EMA")
self.sampler_map = {
'plms': self.plms.sample,
'ddim': self.ddim.sample,
'k_euler': K.sampling.sample_euler,
'k_euler_ancestral': K.sampling.sample_euler_ancestral,
'k_heun': K.sampling.sample_heun,
'k_dpm_2': K.sampling.sample_dpm_2,
'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral,
'k_lms': K.sampling.sample_lms,
}
if config.prior_path:
self.prior = VectorAdjustPrior.load_model(config.prior_path).to(self.device)
self.copied_ema = False
@property
def get_default_config(self):
dict_config = {
'steps': 30,
'sampler': "k_euler_ancestral",
'n_samples': 1,
'image': None,
'fixed_code': False,
'ddim_eta': 0.0,
'height': 512,
'width': 512,
'latent_channels': 4,
'downsampling_factor': 8,
'scale': 12.0,
'dynamic_threshold': None,
'seed': None,
'stage_two_seed': None,
'module': None,
'masks': None,
'output': None,
}
return DotMap(dict_config)
def from_folder(self, folder):
folder = Path(folder)
model_config = OmegaConf.load(folder / "config.yaml")
if (folder / "pruned.ckpt").is_file():
model_path = folder / "pruned.ckpt"
else:
model_path = folder / "model.ckpt"
model = self.load_model_from_config(model_config, model_path)
return model, model_config
def from_file(self, file):
default_config = Path(self.config.default_config)
if not default_config.is_file():
raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG")
model_config = OmegaConf.load(default_config)
model = self.load_model_from_config(model_config, file)
return model, model_config
def load_model_from_config(self, config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd.get('state_dict', pl_sd)
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.eval()
return model
@torch.no_grad()
@torch.autocast("cuda", enabled=True, dtype=torch.float16)
def sample(self, request, callback=None):
if request.module is not None:
if request.module == "vanilla":
pass
else:
module = self.premodules[request.module]
CrossAttention.set_hypernetwork(module)
if request.seed is not None:
seed_everything(request.seed)
if request.image is not None:
request.steps = 50
#request.sampler = "ddim_img2img" #enforce ddim for now
if request.sampler == "plms":
request.sampler = "k_lms"
if request.sampler == "ddim":
request.sampler = "k_lms"
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
start_code = self.model.get_first_stage_encoding(start_code)
start_code = torch.repeat_interleave(start_code, request.n_samples, dim=0)
main_noise = []
start_noise = []
for seed in range(request.seed, request.seed+request.n_samples):
main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
start_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
main_noise = torch.cat(main_noise, dim=0)
start_noise = torch.cat(start_noise, dim=0)
start_code = start_code + (start_noise * request.noise)
t_enc = int(request.strength * request.steps)
if request.sampler.startswith("k_"):
sampler = "k-diffusion"
elif request.sampler == 'ddim_img2img':
sampler = 'img2img'
else:
sampler = "normal"
if request.image is None:
main_noise = []
for seed_offset in range(request.n_samples):
if request.masks is not None:
noise_x = sample_start_noise_special(request.seed, request, self.device)
else:
noise_x = sample_start_noise_special(request.seed+seed_offset, request, self.device)
if request.masks is not None:
for maskobj in request.masks:
mask_seed = maskobj["seed"]
mask = maskobj["mask"]
mask = np.asarray(mask)
mask = torch.from_numpy(mask).clone().to(self.device).permute(2, 0, 1)
mask = mask.float() / 255.0
# convert RGB or grayscale image into 4-channel
mask = mask[0].unsqueeze(0)
mask = torch.repeat_interleave(mask, request.latent_channels, dim=0)
mask = (mask < 0.5).float()
# interpolate start noise
noise_x = (noise_x * (1-mask)) + (sample_start_noise_special(mask_seed+seed_offset, request, self.device) * mask)
main_noise.append(noise_x)
main_noise = torch.cat(main_noise, dim=0)
start_code = main_noise
prompt = [request.prompt]
prompt_condition = prompt_mixing(self.model, prompt[0], 1)
if hasattr(self, "prior") and request.mitigate:
prompt_condition = self.prior(prompt_condition)
uc = None
if request.scale != 1.0:
if request.uc is not None:
uc = [request.uc]
uc = prompt_mixing(self.model, uc[0], 1)
else:
if self.config.quality_hack == "1":
uc = ["Tags: lowres"]
uc = prompt_mixing(self.model, uc[0], 1)
else:
uc = self.model.get_learned_conditioning([""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
shape = [
request.latent_channels,
request.height // request.downsampling_factor,
request.width // request.downsampling_factor
]
c_dele = CallbackDelegate(callback, request.steps * request.n_samples)
# handle images one at a time because batches eat absurd VRAM
sampless = []
for main_noise, start_code in zip(main_noise.chunk(request.n_samples), start_code.chunk(request.n_samples)):
if sampler == "normal":
with self.ema_manager():
samples, _ = self.sampler_map[request.sampler](
S=request.steps,
conditioning=prompt_condition,
batch_size=1,
shape=shape,
verbose=False,
unconditional_guidance_scale=request.scale,
unconditional_conditioning=uc,
eta=request.ddim_eta,
dynamic_threshold=request.dynamic_threshold,
x_T=start_code,
callback=c_dele.update
)
elif sampler == "k-diffusion":
with self.ema_manager():
sigmas = self.k_model.get_sigmas(request.steps)
if request.image is not None:
noise = main_noise * sigmas[request.steps - t_enc - 1]
start_code = start_code + noise
sigmas = sigmas[request.steps - t_enc - 1:]
else:
start_code = start_code * sigmas[0]
extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
samples = self.sampler_map[request.sampler](
self.k_model,
start_code,
sigmas,
request.seed,
callback=c_dele.update,
extra_args=extra_args
)
sampless.append(samples)
torch_gc()
images = []
for samples in sampless:
with torch.autocast("cuda", enabled=False):
x_samples_ddim = self.model.decode_first_stage(samples.float())
#x_samples_ddim = decode_image(samples, self.model.first_stage_model)
#x_samples_ddim = self.model.first_stage_model.decode(samples.float())
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
x_sample = np.ascontiguousarray(x_sample)
images.append(x_sample)
torch_gc()
if request.seed is not None:
torch.seed()
np.random.seed()
#set hypernetwork to none after generation
CrossAttention.set_hypernetwork(None)
return images
@torch.no_grad()
def sample_two_stages(self, request, callback = None):
request = DotMap(request)
if request.seed is not None:
seed_everything(request.seed)
if request.plms:
sampler = self.plms
else:
sampler = self.ddim
start_code = None
if request.fixed_code:
start_code = torch.randn([
request.n_samples,
request.latent_channels,
request.height // request.downsampling_factor,
request.width // request.downsampling_factor,
], device=self.device)
prompt = [request.prompt] * request.n_samples
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
uc = None
if request.scale != 1.0:
uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
shape = [
request.latent_channels,
request.height // request.downsampling_factor,
request.width // request.downsampling_factor
]
c_dele = CallbackDelegate(callback, request.steps * request.n_samples)
with torch.autocast("cuda", enabled=self.config.amp):
with self.ema_manager():
samples, _ = sampler.sample(
S=request.steps,
conditioning=prompt_condition,
batch_size=request.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=request.scale,
unconditional_conditioning=uc,
eta=request.ddim_eta,
dynamic_threshold=request.dynamic_threshold,
x_T=start_code,
callback=c_dele.update
)
x_samples_ddim = self.model.decode_first_stage(samples)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).squeeze(0)
x_samples_ddim = pil_upscale(x_samples_ddim, scale=2)
if request.stage_two_seed is not None:
torch.manual_seed(request.stage_two_seed)
np.random.seed(request.stage_two_seed)
with torch.autocast("cuda", enabled=self.config.amp):
with self.ema_manager():
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
t_enc = int(request.strength * request.steps)
print("init latent shape:")
print(init_latent.shape)
init_latent = init_latent + (torch.randn_like(init_latent) * request.noise)
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
uc = None
if request.scale != 1.0:
uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
# encode (scaled latent)
start_code_terped=None
z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
# decode it
samples = self.ddim.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale,
unconditional_conditioning=uc,)
x_samples_ddim = self.model.decode_first_stage(samples)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
images = []
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
x_sample = np.ascontiguousarray(x_sample)
images.append(x_sample)
if request.seed is not None:
torch.seed()
np.random.seed()
return images
@torch.no_grad()
def sample_from_image(self, request):
return
class DalleMiniModel(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
from min_dalle import MinDalle
self.config = config
self.model = MinDalle(
models_root=config.model_path,
dtype=torch.float16,
device='cuda',
is_mega=True,
is_reusable=True
)
@torch.no_grad()
def sample(self, request, callback=None):
if request.seed is not None:
seed = request.seed
else:
seed = -1
images = self.model.generate_images(
text=request.prompt,
seed=seed,
grid_size=request.grid_size,
is_seamless=False,
temperature=request.temp,
top_k=request.top_k,
supercondition_factor=request.scale,
is_verbose=False
)
images = images.to('cpu').numpy()
images = images.astype(np.uint8)
images = np.ascontiguousarray(images)
if request.seed is not None:
torch.seed()
np.random.seed()
return images
def apply_temp(logits, temperature):
logits = logits / temperature
return logits
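# Autoregressive sampler used by BasedformerModel: applies the ops in ops_list
# (currently only temperature) to the logits, samples with multinomial, and stops early
# when token 48585 appears, returning a safety verdict and, where relevant, the decoded correction.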
@torch.no_grad()
def generate(forward, prompt_tokens, tokenizer, tokens_to_generate=50, ds=False, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
in_tokens = prompt_tokens
context = prompt_tokens
generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
kv = None
if non_deterministic:
torch.seed()
#soft_required = ["top_k", "top_p"]
op_map = {
"temp": apply_temp,
}
for _ in range(tokens_to_generate):
if ds:
logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
else:
logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
logits = logits[:, -1, :] #get the last token in the seq
logits = torch.log_softmax(logits, dim=-1)
batch = []
for i, ops in enumerate(ops_list):
item = logits[i, ...].unsqueeze(0)
ctx = context[i, ...].unsqueeze(0)
for op, value in ops.items():
if op == "rep_pen":
item = op_map[op](ctx, item, **value)
else:
item = op_map[op](item, value)
batch.append(item)
logits = torch.cat(batch, dim=0)
logits = torch.softmax(logits, dim=-1)
#fully_deterministic makes it deterministic across the batch
if fully_deterministic:
logits = logits.split(1, dim=0)
logit_list = []
for logit in logits:
torch.manual_seed(69)
logit_list.append(torch.multinomial(logit, 1))
logits = torch.cat(logit_list, dim=0)
else:
logits = torch.multinomial(logits, 1)
if logits[0, 0] == 48585:
if generated[0, -1] == 1400:
pass
elif generated[0, -1] == 3363:
return "safe", "none"
else:
return "notsafe", tokenizer.decode(generated.squeeze()).split("Output: ")[-1]
generated = torch.cat([generated, logits], dim=-1)
context = torch.cat([context, logits], dim=-1)
in_tokens = logits
return "unknown", tokenizer.decode(generated.squeeze())
class BasedformerModel(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
from basedformer import lm_utils
from transformers import GPT2TokenizerFast
self.config = config
self.model = lm_utils.load_from_path(config.model_path).half().cuda()
self.model = self.model.convert_to_ds()
self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
@torch.no_grad()
def sample(self, request, callback=None):
prompt = request.prompt
prompt = self.tokenizer.encode("Input: " + prompt, return_tensors='pt').cuda().long()
prompt = torch.cat([prompt, torch.tensor([[49527]], dtype=torch.long).cuda()], dim=1)
is_safe, corrected = generate(self.model.module, prompt, self.tokenizer, tokens_to_generate=150, ds=True)
return is_safe, corrected
class EmbedderModel(nn.Module):
def __init__(self, config=None):
nn.Module.__init__(self)
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer('./models/sentence-transformers_all-MiniLM-L6-v2').cuda()
self.tags = [tuple(x) for x in json.load(open("models/tags.json"))]
self.knn = self.load_knn("models/tags.index")
print("Loaded tag suggestion model using phrase embeddings")
def load_knn(self, filename):
import faiss
try:
return faiss.read_index(filename)
except RuntimeError:
print(f"Generating tag embedding index for {len(self.tags)} tags.")
i = faiss.IndexFlatL2(384)
i.add(self([name for name, count in self.tags]))
faiss.write_index(i, filename)
return i
def __call__(self, sentences):
with torch.no_grad():
sentence_embeddings = self.model.encode(sentences)
return sentence_embeddings
def get_top_k(self, text):
#check if text is a substring in tag_count.keys()
found = []
a = bisect.bisect_left(self.tags, (text,))
b = bisect.bisect_left(self.tags, (text + '\xff',), lo=a)
for tag, count in self.tags[a:b]:
if len(tag) >= len(text) and tag.startswith(text):
found.append([tag, count, 0])
results = []
embedding = self([text])
k = 15
D, I = self.knn.search(embedding, k)
D, I = D.squeeze(), I.squeeze()
for id, prob in zip(I, D):
tag, count = self.tags[id]
results.append([tag, count, prob])
found.sort(key=lambda x: x[1], reverse=True)
found = found[:5]
# found = heapq.nlargest(5, found, key=lambda x: x[1])
results_tags = [x[0] for x in found]
for result in results.copy():
if result[0] in results_tags:
results.remove(result)
results = sorted(results, key=lambda x: x[2], reverse=True)
#filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4
results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])]
if found:
results = found + results
#max 10 results
results = results[:10]
results = sorted(results, key=lambda x: x[1], reverse=True)
return results

@ -0,0 +1,221 @@
import traceback
from dotmap import DotMap
import math
from io import BytesIO
import base64
import random
v1pp_defaults = {
'steps': 50,
'sampler': "plms",
'image': None,
'fixed_code': False,
'ddim_eta': 0.0,
'height': 512,
'width': 512,
'latent_channels': 4,
'downsampling_factor': 8,
'scale': 7.0,
'dynamic_threshold': None,
'seed': None,
'stage_two_seed': None,
'module': None,
'masks': None,
}
v1pp_forced_defaults = {
'latent_channels': 4,
'downsampling_factor': 8,
}
dalle_mini_defaults = {
'temp': 1.0,
'top_k': 256,
'scale': 16,
'grid_size': 4,
}
dalle_mini_forced_defaults = {
}
defaults = {
'stable-diffusion': (v1pp_defaults, v1pp_forced_defaults),
'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
'basedformer': ({}, {}),
'embedder': ({}, {}),
}
samplers = [
"plms",
"ddim",
"k_euler",
"k_euler_ancestral",
"k_heun",
"k_dpm_2",
"k_dpm_2_ancestral",
"k_lms"
]
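# Rounds num to the nearest multiple of mult; used in advanced mode to snap the halved
# width/height to multiples of 64.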
def closest_multiple(num, mult):
num_int = int(num)
floor = math.floor(num_int / mult) * mult
ceil = math.ceil(num_int / mult) * mult
return floor if (num_int - floor) < (ceil - num_int) else ceil
def sanitize_stable_diffusion(request, config):
if request.steps > 50:
return False, "steps must be smaller than 50"
if request.width * request.height == 0:
return False, "width and height must be non-zero"
if request.width <= 0:
return False, "width must be positive"
if request.height <= 0:
return False, "height must be positive"
if request.steps <= 0:
return False, "steps must be positive"
if request.ddim_eta < 0:
return False, "ddim_eta shouldn't be negative"
if request.scale < 1.0:
return False, "scale should be at least 1.0"
if request.dynamic_threshold is not None and request.dynamic_threshold < 0:
return False, "dynamic_threshold shouldn't be negative"
if request.width * request.height >= 1024*1025:
return False, "width and height must be less than 1024*1025"
if request.strength < 0.0 or request.strength >= 1.0:
return False, "strength should be at least 0.0 and less than 1.0"
if request.noise < 0.0 or request.noise > 1.0:
return False, "noise should be between 0.0 and 1.0"
if request.advanced:
request.width = closest_multiple(request.width // 2, 64)
request.height = closest_multiple(request.height // 2, 64)
if request.sampler not in samplers:
return False, "sampler should be one of {}".format(samplers)
if request.seed is None:
state = random.getstate()
request.seed = random.randint(0, 2**32)
random.setstate(state)
if request.module is not None:
if request.module not in config.model.premodules and request.module != "vanilla":
return False, "module should be one of: " + ", ".join(config.model.premodules)
max_gens = 100
if 0:
num_gen_tiers = [(1024*512, 4), (640*640, 6), (704*512, 8), (512*512, 16), (384*640, 18)]
pixel_count = request.width * request.height
for tier in num_gen_tiers:
if pixel_count <= tier[0]:
max_gens = tier[1]
else:
break
if request.n_samples > max_gens:
return False, f"requested more ({request.n_samples}) images than possible at this resolution"
if request.image is not None:
#decode from base64
try:
request.image = base64.b64decode(request.image.encode('utf-8'))
except Exception as e:
traceback.print_exc()
return False, "image is not valid base64"
#check if image is valid
try:
from PIL import Image
image = Image.open(BytesIO(request.image))
image.verify()
except Exception as e:
traceback.print_exc()
return False, "image is not valid"
#image is valid, load it again(still check again, verify() can't be sure as it doesn't decode.)
try:
image = Image.open(BytesIO(request.image))
image = image.convert('RGB')
image = image.resize((request.width, request.height), resample=Image.Resampling.LANCZOS)
request.image = image
except Exception as e:
traceback.print_exc()
return False, "Error while opening and cleaning image"
if request.masks is not None:
masks = request.masks
for x in range(len(masks)):
image = masks[x]["mask"]
try:
image_bytes = base64.b64decode(image.encode('utf-8'))
except Exception as e:
traceback.print_exc()
return False, "image is not valid base64"
try:
from PIL import Image
image = Image.open(BytesIO(image_bytes))
image.verify()
except Exception as e:
traceback.print_exc()
return False, "image is not valid"
#image is valid, load it again(still check again, verify() can't be sure as it doesn't decode.)
try:
image = Image.open(BytesIO(image_bytes))
#image = image.convert('RGB')
image = image.resize((request.width//request.downsampling_factor, request.height//request.downsampling_factor), resample=Image.Resampling.LANCZOS)
except Exception as e:
traceback.print_exc()
return False, "Error while opening and cleaning image"
masks[x]["mask"] = image
return True, request
def sanitize_dalle_mini(request):
return True, request
def sanitize_basedformer(request):
return True, request
def sanitize_embedder(request):
return True, request
def sanitize_input(config, request):
"""
Sanitize the input data and set defaults
"""
request = DotMap(request)
default, forced_default = defaults[config.model_name]
for k, v in default.items():
if k not in request:
request[k] = v
for k, v in forced_default.items():
request[k] = v
if config.model_name == 'stable-diffusion':
return sanitize_stable_diffusion(request, config)
elif config.model_name == 'dalle-mini':
return sanitize_dalle_mini(request)
elif config.model_name == 'basedformer':
return sanitize_basedformer(request)
elif config.model_name == "embedder":
return sanitize_embedder(request)

@ -0,0 +1,764 @@
import asyncio
import os
import random
import re
import string
import sys
from types import LambdaType
from fastapi import FastAPI, Request, Depends
from pydantic import BaseModel
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse
from hydra_node.config import init_config_model
from hydra_node.models import EmbedderModel, torch_gc
from typing import Optional, List, Any
from typing_extensions import TypedDict
import socket
from hydra_node.sanitize import sanitize_input
import uvicorn
from typing import Union, Dict
import time
import gc
import io
import signal
import base64
import traceback
import threading
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json
genlock = threading.Lock()
MAX_N_SAMPLES = int(os.getenv("MAX_N_SAMPLES", 100))
TOKEN = os.getenv("TOKEN", None)
print(f"Starting Hydra Node HTTP TOKEN={TOKEN}")
#Initialize model and config
model, config, model_hash = init_config_model()
try:
embedmodel = EmbedderModel()
except Exception as e:
print("couldn't load embed model, suggestions won't work:", e)
embedmodel = False
logger = config.logger
try:
config.mainpid = int(open("gunicorn.pid", "r").read())
except FileNotFoundError:
config.mainpid = os.getpid()
mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False
class TaskQueueFullException(Exception):
pass
class TaskData(TypedDict):
task_id: str
position: int
status: str
updated_time: float
callback: LambdaType
current_step: int
total_steps: int
payload: Any
response: Any
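# Single-worker FIFO task queue: tasks run one at a time on a background thread. Once the
# queue length reaches max_queue_size it reports "full" until it drops below
# recovery_queue_size (configured via QUEUE_MAX_SIZE / QUEUE_RECOVERY_SIZE).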
class TaskQueue:
max_queue_size = 10
recovery_queue_size = 6
is_busy = False
loop_thread = None
gc_loop_thread = None
running = False
gc_running = False
lock = threading.Lock()
queued_task_list: List[TaskData] = []
task_map: Dict[str, TaskData] = {}
def __init__(self, max_queue_size = 10, recovery_queue_size = 6) -> None:
self.max_queue_size = max_queue_size
self.recovery_queue_size = recovery_queue_size
self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
self.gc_running = True
self.gc_loop_thread.start()
logger.info("Task queue created")
def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
if self.is_busy:
raise TaskQueueFullException("Task queue is full")
if len(self.queued_task_list) >= self.max_queue_size: # mark busy
self.is_busy = True
raise TaskQueueFullException("Task queue is full")
task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
task = TaskData(
task_id=task_id,
position=0,
status="queued",
updated_time=time.time(),
callback=callback,
current_step=0,
total_steps=0,
payload=payload
)
with self.lock:
self.queued_task_list.append(task)
task["position"] = len(self.queued_task_list) - 1
logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
# create index
with self.lock:
self.task_map[task_id] = task
self._start()
return task_id
def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
if task_id in self.task_map:
return self.task_map[task_id]
else:
return False
def delete_task_data(self, task_id: str) -> bool:
if task_id in self.task_map:
with self.lock:
del(self.task_map[task_id])
return True
else:
return False
def stop(self):
self.running = False
self.gc_running = False
self.queued_task_list = []
def _start(self):
with self.lock:
if self.running == False:
self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
self.running = True
self.loop_thread.start()
def _loop(self):
while self.running and len(self.queued_task_list) > 0:
current_task = self.queued_task_list[0]
logger.info("Start task%s." % current_task["task_id"])
try:
current_task["status"] = "running"
current_task["updated_time"] = time.time()
# run task
res = current_task["callback"](current_task)
# call gc
gc.collect()
torch_gc()
current_task["status"] = "finished"
current_task["updated_time"] = time.time()
current_task["response"] = res
current_task["current_step"] = current_task["total_steps"]
except Exception as e:
current_task["status"] = "error"
current_task["updated_time"] = time.time()
current_task["response"] = e
gc.collect()
e_s = str(e)
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
traceback.print_exc()
logger.error(str(e))
logger.error("GPU error in task.")
torch_gc()
current_task["response"] = Exception("GPU error in task.")
# os.kill(mainpid, signal.SIGTERM)
with self.lock:
self.queued_task_list.pop(0)
# reset queue position
for i in range(0, len(self.queued_task_list)):
self.queued_task_list[i]["position"] = i
if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
self.is_busy = False
logger.info("Task%s finished, queue size: %d" % (current_task["task_id"], len(self.queued_task_list)))
with self.lock:
self.running = False
# Background GC loop: removes tasks that finished or errored more than two minutes ago
def _gc_loop(self):
while self.gc_running:
current_time = time.time()
with self.lock:
will_remove_keys = []
for (key, task_data) in self.task_map.items():
if task_data["status"] == "finished" or task_data["status"] == "error":
if task_data["updated_time"] + 120 < current_time: # delete task which finished 2 minutes ago
will_remove_keys.append(key)
for key in will_remove_keys:
del(self.task_map[key])
time.sleep(1)
queue = TaskQueue(
max_queue_size=int(os.getenv("QUEUE_MAX_SIZE", 10)),
recovery_queue_size=int(os.getenv("QUEUE_RECOVERY_SIZE", 6))
)
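# Optional bearer-token auth: when the TOKEN environment variable is set, endpoints that
# depend on this check require an "Authorization: Bearer <TOKEN>" header.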
def verify_token(req: Request):
if TOKEN:
valid = "Authorization" in req.headers and req.headers["Authorization"] == "Bearer "+TOKEN
if not valid:
raise HTTPException(
status_code=401,
detail="Unauthorized"
)
return True
#Initialize fastapi
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
@app.on_event("startup")
def startup_event():
logger.info("FastAPI Started, serving")
@app.on_event("shutdown")
def shutdown_event():
logger.info("FastAPI Shutdown, exiting")
queue.stop()
class Masker(TypedDict):
seed: int
mask: str
class Tags(TypedDict):
tag: str
count: int
confidence: float
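# Request schema shared by the image endpoints; defaults mirror the v1pp defaults applied
# by sanitize_input for the stable-diffusion model.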
class GenerationRequest(BaseModel):
prompt: str
image: str = None
n_samples: int = 1
steps: int = 50
sampler: str = "plms"
fixed_code: bool = False
ddim_eta: float = 0.0
height: int = 512
width: int = 512
latent_channels: int = 4
downsampling_factor: int = 8
scale: float = 7.0
dynamic_threshold: float = None
seed: int = None
temp: float = 1.0
top_k: int = 256
grid_size: int = 4
advanced: bool = False
stage_two_seed: int = None
strength: float = 0.69
noise: float = 0.667
mitigate: bool = False
module: str = None
masks: List[Masker] = None
uc: str = None
class TaskIdOutput(BaseModel):
task_id: str
class TextRequest(BaseModel):
prompt: str
class TagOutput(BaseModel):
tags: List[Tags]
class TextOutput(BaseModel):
is_safe: str
corrected_text: str
class TaskIdRequest(BaseModel):
task_id: str
class TaskDataOutput(BaseModel):
status: str
position: int
current_step: int
total_steps: int
class GenerationOutput(BaseModel):
output: List[str]
class ErrorOutput(BaseModel):
error: str
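# Writes the encoded PNG to ./images, naming the file after the sanitized prompt and seed
# and appending -1, -2, ... if the name already exists.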
def saveimage(image, request):
os.makedirs("images", exist_ok=True)
filename = request.prompt.replace('masterpiece, best quality, ', '')
filename = re.sub(r'[/\\<>:"|]', '', filename)
filename = filename[:128]
filename += f' s-{request.seed}'
filename = os.path.join("images", filename.strip())
for n in range(1000000):
suff = '.png'
if n:
suff = f'-{n}.png'
if not os.path.exists(filename + suff):
break
try:
with open(filename + suff, "wb") as f:
f.write(image)
except Exception as e:
print("failed to save image:", e)
def _generate_stream(request: GenerationRequest, task_info: TaskData = TaskData()):
try:
task_info["total_steps"] = request.steps + 1
def _on_step(step_num, total_steps):
task_info["total_steps"] = total_steps + 1
task_info["current_step"] = step_num
if request.advanced:
if request.n_samples > 1:
return ErrorOutput(error="advanced mode does not support n_samples > 1")
images = model.sample_two_stages(request, callback=_on_step)
else:
images = model.sample(request, callback=_on_step)
logger.info("Sample finished.")
seed = request.seed
images_encoded = []
for x in range(len(images)):
if seed is not None:
request.seed = seed
seed += 1
comment = json.dumps({"steps":request.steps,"sampler":request.sampler,"seed":request.seed,"strength":request.strength,"noise":request.noise,"scale":request.scale,"uc":request.uc})
metadata = PngInfo()
metadata.add_text("Title", "AI generated image")
metadata.add_text("Description", request.prompt)
metadata.add_text("Software", "NovelAI")
metadata.add_text("Source", "Stable Diffusion "+model_hash)
metadata.add_text("Comment", comment)
image = Image.fromarray(images[x])
#save pillow image with bytesIO
output = io.BytesIO()
image.save(output, format='PNG', pnginfo=metadata)
image = output.getvalue()
if config.savefiles:
saveimage(image, request)
#get base64 of image
image = base64.b64encode(image).decode("ascii")
images_encoded.append(image)
task_info["current_step"] += 1
del images
logger.info("Images encoded.")
return images_encoded
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
else:
raise e
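# Blocking endpoint: enqueues the task, polls the queue until it finishes, then returns
# every image as an "event: newImage" entry in a single text/event-stream response.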
@app.post('/generate-stream')
async def generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
t = time.perf_counter()
try:
request.n_samples = min(request.n_samples, MAX_N_SAMPLES) # enforce the parallel generation limit
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
task_id = None
try:
task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
except TaskQueueFullException:
return ErrorOutput(error="Task queue is full, please wait for minutes.")
except Exception as err:
raise err
images_encoded = []
while True:
task_data = queue.get_task_data(task_id)
if not task_data:
raise Exception("Task not found")
if task_data["status"] == "finished":
images_encoded = task_data["response"]
break
elif task_data["status"] == "error":
return {"error": str(task_data["response"])}
await asyncio.sleep(0.1)
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
data = ""
ptr = 0
for x in images_encoded:
ptr += 1
data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
return Response(content=data, media_type="text/event-stream")
#return GenerationOutput(output=images)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/start-generate-stream')
async def start_generate_stream(request: GenerationRequest, authorized: bool = Depends(verify_token)):
try:
request.n_samples = min(request.n_samples, MAX_N_SAMPLES) # enforce the parallel generation limit
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
task_id = None
try:
task_id = queue.add_task(lambda t: _generate_stream(request, t), request)
except TaskQueueFullException:
return ErrorOutput(error="Task queue is full, please wait for minutes.")
except Exception as err:
raise err
return TaskIdOutput(task_id=task_id)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/get-generate-stream-output')
async def generate_stream_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
try:
task_id = request.task_id
task_data = queue.get_task_data(task_id)
if not task_data:
return ErrorOutput(error="Task not found.")
if task_data["status"] == "finished":
images_encoded = task_data["response"]
data = ""
ptr = 0
for x in images_encoded:
ptr += 1
data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
return Response(content=data, media_type="text/event-stream")
elif task_data["status"] == "error":
raise task_data["response"]
else:
return ErrorOutput(error="Task is not finished.")
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
def _generate(request: GenerationRequest, task_info: TaskData):
try:
task_info["total_steps"] = request.steps + 1
def _on_step(step_num, total_steps):
task_info["total_steps"] = total_steps + 1
task_info["current_step"] = step_num
images = model.sample(request, callback=_on_step)
images_encoded = []
for x in range(len(images)):
comment = json.dumps({"steps":request.steps,"sampler":request.sampler,"seed":request.seed,"strength":request.strength,"noise":request.noise,"scale":request.scale,"uc":request.uc})
metadata = PngInfo()
metadata.add_text("Title", "AI generated image")
metadata.add_text("Description", request.prompt)
metadata.add_text("Software", "NovelAI")
metadata.add_text("Source", "Stable Diffusion "+model_hash)
metadata.add_text("Comment", comment)
image = Image.fromarray(images[x])
#save pillow image with bytesIO
output = io.BytesIO()
image.save(output, format='PNG', pnginfo=metadata)
image = output.getvalue()
if config.savefiles:
saveimage(image, request)
#get base64 of image
image = base64.b64encode(image).decode("ascii")
images_encoded.append(image)
task_info["current_step"] += 1
del images
return images_encoded
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
else:
raise e
@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
async def generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
t = time.perf_counter()
try:
request.n_samples = min(request.n_samples, MAX_N_SAMPLES) # enforce the parallel generation limit
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
task_id = None
try:
task_id = queue.add_task(lambda t: _generate(request, t), request)
except TaskQueueFullException:
return ErrorOutput(error="Task queue is full, please wait for minutes.")
except Exception as err:
raise err
images_encoded = []
while True:
task_data = queue.get_task_data(task_id)
if not task_data:
raise Exception("Task not found")
if task_data["status"] == "finished":
images_encoded = task_data["response"]
break
elif task_data["status"] == "error":
return {"error": str(task_data["response"])}
await asyncio.sleep(0.1)
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
return GenerationOutput(output=images_encoded)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/start-generate', response_model=Union[TaskIdOutput, ErrorOutput])
async def start_generate(request: GenerationRequest, authorized: bool = Depends(verify_token)):
try:
request.n_samples = min(request.n_samples, MAX_N_SAMPLES) # enforce the parallel generation limit
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
task_id = None
try:
task_id = queue.add_task(lambda t: _generate(request, t), request)
except TaskQueueFullException:
return ErrorOutput(error="Task queue is full, please wait for minutes.")
except Exception as err:
raise err
return TaskIdOutput(task_id=task_id)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/get-generate-output')
async def get_generate_output(request: TaskIdRequest, authorized: bool = Depends(verify_token)):
try:
task_id = request.task_id
task_data = queue.get_task_data(task_id)
if not task_data:
return ErrorOutput(error="Task not found.")
if task_data["status"] == "finished":
images_encoded = task_data["response"]
return GenerationOutput(output=images_encoded)
elif task_data["status"] == "error":
raise task_data["response"]
else:
return ErrorOutput(error="Task is not finished.")
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)}
@app.post('/generate-text', response_model=Union[TextOutput, ErrorOutput])
def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)):
t = time.perf_counter()
try:
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
is_safe, corrected_text = model.sample(request)
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
return TextOutput(is_safe=is_safe, corrected_text=corrected_text)
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return ErrorOutput(error=str(e))
@app.get('/predict-tags', response_model=Union[TagOutput, ErrorOutput])
async def predict_tags(prompt="", authorized: bool = Depends(verify_token)):
t = time.perf_counter()
try:
#output = sanitize_input(config, request)
#if output[0]:
# request = output[1]
#else:
# return ErrorOutput(error=output[1])
tags = embedmodel.get_top_k(prompt)
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])
except Exception as e:
traceback.print_exc()
logger.error(str(e))
e_s = str(e)
gc.collect()
if "CUDA out of memory" in e_s or \
"an illegal memory access" in e_s or "CUDA" in e_s:
torch_gc()
# logger.error("GPU error, committing seppuku.")
# os.kill(mainpid, signal.SIGTERM)
return ErrorOutput(error=str(e))
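# Lightweight status endpoint, intended to be polled for queue position and step progress.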
@app.post('/task-info', response_model=Union[TaskDataOutput, ErrorOutput])
async def get_task_info(request: TaskIdRequest):
task_data = queue.get_task_data(request.task_id)
if task_data:
return TaskDataOutput(
status=task_data["status"],
position=task_data["position"],
current_step=task_data["current_step"],
total_steps=task_data["total_steps"]
)
else:
return ErrorOutput(error="Cannot find current task.")
@app.get('/')
def index():
return FileResponse('static/index.html')
app.mount("/", StaticFiles(directory="static/"), name="static")
def start():
uvicorn.run("main:app", host="0.0.0.0", port=4315, log_level="info")
if __name__ == "__main__":
start()

@ -0,0 +1,37 @@
@echo off
:: UNCOMMENT (remove ::) if you want the backend to automatically save files for you
set SAVE_FILES="1"
if not defined PYTHON (set PYTHON=python)
if not defined VENV_DIR (set VENV_DIR=venv)
set ERROR_REPORTING=FALSE
:: set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
set DTYPE=float32
set CLIP_CONTEXTS=3
set AMP=1
set MODEL=stable-diffusion
set DEV=False
set MODEL_PATH=models/animefull-final-pruned
::these aren't actually used by the site?
::set MODULE_PATH=models/modules
::unclear if these are used either
::set PRIOR_PATH=models/vector_adjust_v2.pt
set ENABLE_EMA=1
set VAE_PATH=models/animevae.pt
set PENULTIMATE=1
set PYTHONDONTWRITEBYTECODE=1
set LOWVRAM=0
:: Queue size
set QUEUE_MAX_SIZE=10
set QUEUE_RECOVERY_SIZE=6
:: Maximum number of images generated in parallel
set MAX_N_SAMPLES=100
%PYTHON% -m uvicorn --host 0.0.0.0 --port=6969 --workers 1 main:app
pause

@ -0,0 +1,35 @@
#!/bin/bash
# UNCOMMENT if you want the backend to automatically save files for you
# export SAVE_FILES="1"
export DTYPE="float32"
export CLIP_CONTEXTS=3
export AMP="1"
export MODEL="stable-diffusion"
export DEV="True"
export MODEL_PATH="models/animefull-final-pruned"
#these aren't actually used by the site
#export MODULE_PATH="models/modules"
#unclear if these are used either
#export PRIOR_PATH="models/vector_adjust_v2.pt"
export ENABLE_EMA="1"
export VAE_PATH="models/animevae.pt"
export PENULTIMATE="1"
export PYTHONDONTWRITEBYTECODE=1
export LOWVRAM=0
# Queue size
export QUEUE_MAX_SIZE=10
export QUEUE_RECOVERY_SIZE=6
# Maximum number of images generated in parallel
export MAX_N_SAMPLES=100
if [[ -f venv/bin/python ]]; then
PYTHON=venv/bin/python
else
PYTHON=python
fi
$PYTHON -m uvicorn --host 0.0.0.0 --port=6969 main:app & bore local 6969 --to bore.pub