Upload project
commit
e23b2b2f95
@@ -0,0 +1,192 @@
import os
import torch
import logging
import os
import platform
import socket
import sys
import time
from dotmap import DotMap
from hydra_node.models import StableDiffusionModel, DalleMiniModel, BasedformerModel, EmbedderModel
from hydra_node import lowvram
import traceback
import zlib
from pathlib import Path
from ldm.modules.attention import CrossAttention, HyperLogic

from . import lautocast

model_map = {
    "stable-diffusion": StableDiffusionModel,
    "dalle-mini": DalleMiniModel,
    "basedformer": BasedformerModel,
    "embedder": EmbedderModel,
}

def no_init(loading_code):
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {}
    for mod in modules:
        original[mod] = mod.reset_parameters
        mod.reset_parameters = dummy

    result = loading_code()
    for mod in modules:
        mod.reset_parameters = original[mod]

    return result
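
# Note: no_init() works by temporarily swapping reset_parameters() on the common
# layer types for a no-op, running the loading code, then restoring the originals.
# This skips the expensive random weight initialisation that would be overwritten
# by the checkpoint anyway. Illustrative use (hypothetical model class, not part
# of this file):
#   model = no_init(lambda: MyBigModel(config))        # weights left uninitialised
#   model.load_state_dict(torch.load("model.ckpt"))    # real weights arrive here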

def crc32(filename, chunksize=65536):
    """Compute the CRC-32 checksum of the contents of the given filename"""
    with open(filename, "rb") as f:
        checksum = 0
        while (chunk := f.read(chunksize)):
            checksum = zlib.crc32(chunk, checksum)
        return '%08X' % (checksum & 0xFFFFFFFF)
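
# Example: crc32("model.ckpt") returns an 8-character uppercase hex string such as
# "1A2B3C4D" (illustrative value only); the file is read in 64 KiB chunks so large
# checkpoints never have to fit in memory at once.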

def load_modules(path):
    path = Path(path)
    modules = {}
    if not path.is_dir():
        return

    for file in path.iterdir():
        module = load_module(file, "cpu")
        modules[file.stem] = module
        print(f"Loaded module {file.stem}")

    return modules

def load_module(path, device):
    path = Path(path)
    if not path.is_file():
        print("Module path {} is not a file".format(path))

    network = {
        768: (HyperLogic(768).to(device), HyperLogic(768).to(device)),
        1280: (HyperLogic(1280).to(device), HyperLogic(1280).to(device)),
        640: (HyperLogic(640).to(device), HyperLogic(640).to(device)),
        320: (HyperLogic(320).to(device), HyperLogic(320).to(device)),
    }

    state_dict = torch.load(path)
    for key in state_dict.keys():
        network[key][0].load_state_dict(state_dict[key][0])
        network[key][1].load_state_dict(state_dict[key][1])

    return network
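
# A "module" file on disk is expected to be a dict keyed by attention width
# (320/640/768/1280), each value holding state dicts for a pair of HyperLogic
# networks (presumably applied to the key/value context paths of CrossAttention).
# load_modules() builds a {file_stem: network_dict} map that is later attached to
# the model as model.premodules and selected per request.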

def init_config_model():
    config = DotMap()
    config.savefiles = os.getenv("SAVE_FILES", False)
    config.dtype = os.getenv("DTYPE", "float16")
    config.device = os.getenv("DEVICE", "cuda")
    config.amp = os.getenv("AMP", False)
    if config.amp == "1":
        config.amp = True
    elif config.amp == "0":
        config.amp = False

    is_dev = ""
    environment = "production"
    if os.environ['DEV'] == "True":
        is_dev = "_dev"
        environment = "staging"
    config.is_dev = is_dev

    # Setup logger
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    fh = logging.StreamHandler()
    fh_formatter = logging.Formatter(
        "%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
    )
    fh.setFormatter(fh_formatter)
    logger.addHandler(fh)
    config.logger = logger

    # Gather node information
    config.cuda_dev = torch.cuda.current_device()
    cpu_id = platform.processor()
    if os.path.exists('/proc/cpuinfo'):
        cpu_id = [line for line in open("/proc/cpuinfo", 'r').readlines() if
                  'model name' in line][0].rstrip().split(': ')[-1]

    config.cpu_id = cpu_id
    config.gpu_id = torch.cuda.get_device_name(config.cuda_dev)
    config.node_id = platform.node()

    # Report on our CUDA memory and model.
    gb_gpu = int(torch.cuda.get_device_properties(
        config.cuda_dev).total_memory / (1000 * 1000 * 1000))
    logger.info(f"CPU: {config.cpu_id}")
    logger.info(f"GPU: {config.gpu_id}")
    logger.info(f"GPU RAM: {gb_gpu}gb")

    config.model_name = os.environ['MODEL']
    logger.info(f"MODEL: {config.model_name}")

    # Resolve where we get our model and data from.
    config.model_path = os.getenv('MODEL_PATH', None)
    config.enable_ema = os.getenv('ENABLE_EMA', "1")
    config.basedformer = os.getenv('BASEDFORMER', "0")
    config.penultimate = os.getenv('PENULTIMATE', "0")
    config.vae_path = os.getenv('VAE_PATH', None)
    config.module_path = os.getenv('MODULE_PATH', None)
    config.prior_path = os.getenv('PRIOR_PATH', None)
    config.default_config = os.getenv('DEFAULT_CONFIG', None)
    config.quality_hack = os.getenv('QUALITY_HACK', "0")
    config.clip_contexts = os.getenv('CLIP_CONTEXTS', "1")
    try:
        config.clip_contexts = int(config.clip_contexts)
        if config.clip_contexts < 1 or config.clip_contexts > 10:
            config.clip_contexts = 1
    except:
        config.clip_contexts = 1

    # Misc settings
    config.model_alias = os.getenv('MODEL_ALIAS')

    # Instantiate our actual model.
    load_time = time.time()
    model_hash = None

    try:
        if config.model_name != "dalle-mini":
            model = no_init(lambda: model_map[config.model_name](config))
        else:
            model = model_map[config.model_name](config)

    except Exception as e:
        traceback.print_exc()
        logger.error(f"Failed to load model: {str(e)}")
        #exit gunicorn
        sys.exit(4)

    if config.model_name == "stable-diffusion":
        folder = Path(config.model_path)
        if (folder / "pruned.ckpt").is_file():
            model_path = folder / "pruned.ckpt"
        else:
            model_path = folder / "model.ckpt"
        model_hash = crc32(model_path)

    #Load Modules
    if config.module_path is not None:
        modules = load_modules(config.module_path)
        #attach it to the model
        model.premodules = modules


    if lautocast.lowvram == False and lautocast.medvram == False:
        lowvram.setup_for_low_vram(model.model, True)

    config.model = model

    time_load = time.time() - load_time
    logger.info(f"Models loaded in {time_load:.2f}s")

    return model, config, model_hash
@@ -0,0 +1,9 @@
import torch
import os

lowvram = True if os.environ.get('LOWVRAM') == "1" else False
medvram = True if os.environ.get('MEDVRAM') == "1" else False
dtype = torch.float32 if os.environ.get('DTYPE', 'float32') == 'float32' else torch.float16


print("using dtype: " + os.environ.get('DTYPE', 'float32'))
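
# These flags are read once at import time. A minimal sketch of how a node would
# be launched with them (env var names come from this file; the values and the
# launch command itself are examples only):
#   LOWVRAM=1 DTYPE=float16 <launch command>
# LOWVRAM/MEDVRAM only change where weights live (CPU vs GPU) during sampling,
# while DTYPE selects the autocast dtype (float32 unless set to something else).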
@@ -0,0 +1,225 @@
# from github.com/AUTOMATIC1111/stable-diffusion-webui

import torch
from torch.nn.functional import silu

import ldm.modules.attention
import ldm.modules.diffusionmodules.model



module_in_gpu = None
cpu = torch.device("cpu")
device = gpu = torch.device("cuda")


def send_everything_to_cpu():
    global module_in_gpu

    if module_in_gpu is not None:
        module_in_gpu.to(cpu)

    module_in_gpu = None


def setup_for_low_vram(sd_model, use_medvram):
    parents = {}

    def send_me_to_gpu(module, _):
        """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
        we add this as forward_pre_hook to a lot of modules and this way all but one of them will
        be in CPU
        """
        global module_in_gpu

        module = parents.get(module, module)

        if module_in_gpu == module:
            return

        if module_in_gpu is not None:
            module_in_gpu.to(cpu)

        module.to(gpu)
        module_in_gpu = module

    # see below for register_forward_pre_hook;
    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
    # useless here, and we just replace those methods
    def first_stage_model_encode_wrap(self, encoder, x):
        send_me_to_gpu(self, None)
        return encoder(x)

    def first_stage_model_decode_wrap(self, decoder, z):
        send_me_to_gpu(self, None)
        return decoder(z)

    # remove three big modules, cond, first_stage, and unet from the model and then
    # send the model to GPU. Then put modules back. the modules will be in CPU.
    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
    sd_model.to(device)
    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored

    # register hooks for those the first two models
    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
    sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    if use_medvram:
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
    else:
        diff_model = sd_model.model.diffusion_model

        # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
        # so that only one of them is in GPU at a time
        stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
        sd_model.model.to(device)
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

        # install hooks for bits of third model
        diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.input_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)
        diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.output_blocks:
            block.register_forward_pre_hook(send_me_to_gpu)
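
# The invariant maintained here: at any moment at most one of the hooked modules
# is resident on the GPU (tracked in module_in_gpu); everything else sits in CPU
# RAM. forward_pre_hook fires right before a module's forward(), so weights are
# shuttled in just in time. With use_medvram the whole UNet is swapped as one
# unit; without it even the UNet's input/middle/output blocks are swapped
# individually, trading speed for a roughly 4 GB VRAM budget.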

ldm.modules.diffusionmodules.model.nonlinearity = silu

import math
import torch
from torch import einsum

from ldm.util import default
from einops import rearrange



# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)
    k_in = self.to_k(context) * self.scale
    v_in = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

        s2 = s1.softmax(dim=-1, dtype=q.dtype)
        del s1

        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2

    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
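
# Rough memory model used above: the attention score matrix needs about
# q.shape[0] * q.shape[1] * k.shape[1] elements. When that exceeds the free
# CUDA memory, the query dimension is split into `steps` slices, where steps is
# the next power of two with mem_required / steps <= mem_free_total; e.g. needing
# 6 GB with 2 GB free gives steps = 4, i.e. q.shape[1] // 4 rows per einsum.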

def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)  # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3

# applied after the functions above are defined, so the fallback assignments can
# resolve them when xformers is unavailable
try:
    import xformers
except ImportError:
    ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
@@ -0,0 +1,824 @@
import os
import bisect
import json
from re import S
import torch
import torch.nn as nn
from pathlib import Path
from omegaconf import OmegaConf
from dotmap import DotMap
import numpy as np
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.attention import CrossAttention, HyperLogic
from PIL import Image
import k_diffusion as K
import contextlib
import random
import base64

from .lowvram import setup_for_low_vram
from . import lautocast

class CallbackDelegate:
    total_steps = 0
    current_step = -1
    callback = None

    def __init__(self, callback, total_steps) -> None:
        self.callback = callback
        self.total_steps = total_steps

    def update(self, n = None):
        self.current_step += 1
        if self.callback:
            self.callback(self.current_step, self.total_steps)
        return n
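
# Minimal usage sketch (hypothetical callback, not part of this file):
#   def report(step, total):
#       print(f"{step}/{total}")
#   c_dele = CallbackDelegate(report, total_steps=30)
#   c_dele.update()   # -> prints "0/30"; update() returns its argument unchanged,
#                     # so it can be passed straight to samplers as `callback`.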


def seed_everything(seed: int):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def pil_upscale(image, scale=1):
    device = image.device
    dtype = image.dtype
    image = Image.fromarray((image.cpu().permute(1,2,0).numpy().astype(np.float32) * 255.).astype(np.uint8))
    if scale > 1:
        image = image.resize((int(image.width * scale), int(image.height * scale)), resample=Image.LANCZOS)
    image = np.array(image)
    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    image = 2.*image - 1.
    image = repeat(image, '1 ... -> b ...', b=1)
    return image.to(device)

def fix_batch(tensor, bs):
    return torch.stack([tensor.squeeze(0)]*bs, dim=0)

def torch_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

# make uc and prompt shapes match via padding for long prompts
# finetune
null_cond = None
def fix_cond_shapes(model, prompt_condition, uc):
    global null_cond
    if null_cond is None:
        null_cond = model.get_learned_conditioning([""])
    while prompt_condition.shape[1] > uc.shape[1]:
        uc = torch.cat((uc, null_cond.repeat((uc.shape[0], 1, 1))), axis=1)
    while prompt_condition.shape[1] < uc.shape[1]:
        prompt_condition = torch.cat((prompt_condition, null_cond.repeat((prompt_condition.shape[0], 1, 1))), axis=1)
    return prompt_condition, uc
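
# When one prompt tokenises to more CLIP chunks than the other, the shorter
# conditioning tensor is padded along the sequence axis with copies of the
# empty-prompt ("") conditioning until both have the same length, so the sampler
# can run them side by side. For the standard CLIP text encoder each appended
# chunk is one 77-token context.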

# mix conditioning vectors for prompts
# @aero
def prompt_mixing(model, prompt_body, batch_size):
    if "|" in prompt_body:
        prompt_parts = prompt_body.split("|")
        prompt_total_power = 0
        prompt_sum = None
        for prompt_part in prompt_parts:
            prompt_power = 1
            if ":" in prompt_part:
                prompt_sub_parts = prompt_part.split(":")
                try:
                    prompt_power = float(prompt_sub_parts[1])
                    prompt_part = prompt_sub_parts[0]
                except:
                    print("Error parsing prompt power! Assuming 1")
            prompt_vector = model.get_learned_conditioning([prompt_part])
            if prompt_sum is None:
                prompt_sum = prompt_vector * prompt_power
            else:
                prompt_sum, prompt_vector = fix_cond_shapes(model, prompt_sum, prompt_vector)
                prompt_sum = prompt_sum + (prompt_vector * prompt_power)
            prompt_total_power = prompt_total_power + prompt_power
        return fix_batch(prompt_sum / prompt_total_power, batch_size)
    else:
        return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)
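
# Prompt syntax handled above: "|" separates sub-prompts and ":<number>" sets a
# weight for the part before it. For example
#   "castle on a hill|oil painting:0.5"
# produces (cond("castle on a hill") * 1 + cond("oil painting") * 0.5) / 1.5,
# i.e. a weighted average of the conditioning vectors, which is then repeated to
# the requested batch size.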

def sample_start_noise(seed, C, H, W, f, device="cuda"):
    if seed:
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)
        noise = torch.randn([C, (H) // f, (W) // f], generator=gen, device=device).unsqueeze(0)
    else:
        noise = torch.randn([C, (H) // f, (W) // f], device=device).unsqueeze(0)
    return noise

def sample_start_noise_special(seed, request, device="cuda"):
    if seed:
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)
        noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], generator=gen, device=device).unsqueeze(0)
    else:
        noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], device=device).unsqueeze(0)
    return noise

@torch.no_grad()
def encode_image(image, model):
    if isinstance(image, Image.Image):
        image = np.array(image)
        image = torch.from_numpy(image)

    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image)

    #dtype = image.dtype
    image = image.to(torch.float32)
    #gets image as numpy array and returns as tensor
    def preprocess_vqgan(x):
        x = x / 255.0
        x = 2.*x - 1.
        return x

    image = image.permute(2, 0, 1).unsqueeze(0).float().cuda()
    image = preprocess_vqgan(image)
    image = model.encode(image).sample()
    #image = image.to(dtype)

    return image

@torch.no_grad()
def decode_image(image, model):
    def custom_to_pil(x):
        x = x.detach().float().cpu()
        x = torch.clamp(x, -1., 1.)
        x = (x + 1.)/2.
        x = x.permute(0, 2, 3, 1)#.numpy()
        #x = (255*x).astype(np.uint8)
        #x = Image.fromarray(x)
        #if not x.mode == "RGB":
        #    x = x.convert("RGB")
        return x

    image = model.decode(image)
    image = custom_to_pil(image)
    return image

class VectorAdjustPrior(nn.Module):
    def __init__(self, hidden_size, inter_dim=64):
        super().__init__()
        self.vector_proj = nn.Linear(hidden_size*2, inter_dim, bias=True)
        self.out_proj = nn.Linear(hidden_size+inter_dim, hidden_size, bias=True)

    def forward(self, z):
        b, s = z.shape[0:2]
        x1 = torch.mean(z, dim=1).repeat(s, 1)
        x2 = z.reshape(b*s, -1)
        x = torch.cat((x1, x2), dim=1)
        x = self.vector_proj(x)
        x = torch.cat((x2, x), dim=1)
        x = self.out_proj(x)
        x = x.reshape(b, s, -1)
        return x

    @classmethod
    def load_model(cls, model_path, hidden_size=768, inter_dim=64):
        model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
        model.load_state_dict(torch.load(model_path)["state_dict"])
        return model

class StableInterface(nn.Module):
    def __init__(self, model, thresholder = None):
        super().__init__()
        self.inner_model = model
        self.sigma_to_t = model.sigma_to_t
        self.thresholder = thresholder
        self.get_sigmas = model.get_sigmas

    @torch.no_grad()
    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_two = torch.cat([x] * 2)
        sigma_two = torch.cat([sigma] * 2)
        cond_full = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(x_two, sigma_two, cond=cond_full).chunk(2)
        x_0 = uncond + (cond - uncond) * cond_scale
        if self.thresholder is not None:
            x_0 = self.thresholder(x_0)

        return x_0
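
# This is classifier-free guidance: the denoiser runs once on a doubled batch
# (unconditional + conditional) and the two predictions are combined as
#   x_0 = uncond + (cond - uncond) * cond_scale
# so cond_scale = 1 reproduces the conditional prediction and larger values push
# the result further towards the prompt.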

class StableDiffusionModel(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.config = config
        self.premodules = None
        if Path(self.config.model_path).is_dir():
            config.logger.info(f"Loading model from folder {self.config.model_path}")
            model, model_config = self.from_folder(config.model_path)

        elif Path(self.config.model_path).is_file():
            config.logger.info(f"Loading model from file {self.config.model_path}")
            model, model_config = self.from_file(config.model_path)

        else:
            raise Exception("Invalid model path!")

        if config.dtype == "float16":
            typex = torch.float16
        else:
            typex = torch.float32

        print("----------------------------------")
        if lautocast.medvram == True:
            print("passed lowvram setup for medvram (6G)")
            setup_for_low_vram(model, True)
        elif lautocast.lowvram == True:
            print("setup for lowvram (4G and lower)")
            setup_for_low_vram(model, False)
        else:
            print("loading model")
            model.to(config.device)
        print("----------------------------------")
        self.model = model.to(typex)
        # self.model = model.to(config.device).to(typex)
        if self.config.vae_path:
            ckpt=torch.load(self.config.vae_path, map_location="cpu")
            loss = []
            for i in ckpt["state_dict"].keys():
                if i[0:4] == "loss":
                    loss.append(i)
            for i in loss:
                del ckpt["state_dict"][i]

            model.first_stage_model = model.first_stage_model.float()
            model.first_stage_model.load_state_dict(ckpt["state_dict"])
            model.first_stage_model = model.first_stage_model.float()
            del ckpt
            del loss
            config.logger.info(f"Using VAE from {self.config.vae_path}")

        if self.config.penultimate == "1":
            model.cond_stage_model.return_layer = -2
            model.cond_stage_model.do_final_ln = True
            config.logger.info(f"CLIP: Using penultimate layer")

        if self.config.clip_contexts > 1:
            model.cond_stage_model.clip_extend = True
            model.cond_stage_model.max_clip_extend = 75 * self.config.clip_contexts

        model.cond_stage_model.inference_mode = True
        self.k_model = K.external.CompVisDenoiser(model)
        self.k_model = StableInterface(self.k_model)
        self.device = config.device
        self.model_config = model_config
        self.plms = PLMSSampler(model)
        self.ddim = DDIMSampler(model)
        self.ema_manager = self.model.ema_scope
        if self.config.enable_ema == "0":
            self.ema_manager = contextlib.nullcontext
            config.logger.info("Disabling EMA")
        else:
            config.logger.info(f"Using EMA")
        self.sampler_map = {
            'plms': self.plms.sample,
            'ddim': self.ddim.sample,
            'k_euler': K.sampling.sample_euler,
            'k_euler_ancestral': K.sampling.sample_euler_ancestral,
            'k_heun': K.sampling.sample_heun,
            'k_dpm_2': K.sampling.sample_dpm_2,
            'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral,
            'k_lms': K.sampling.sample_lms,
        }
        if config.prior_path:
            self.prior = VectorAdjustPrior.load_model(config.prior_path).to(self.device)
        self.copied_ema = False

    @property
    def get_default_config(self):
        dict_config = {
            'steps': 30,
            'sampler': "k_euler_ancestral",
            'n_samples': 1,
            'image': None,
            'fixed_code': False,
            'ddim_eta': 0.0,
            'height': 512,
            'width': 512,
            'latent_channels': 4,
            'downsampling_factor': 8,
            'scale': 12.0,
            'dynamic_threshold': None,
            'seed': None,
            'stage_two_seed': None,
            'module': None,
            'masks': None,
            'output': None,
        }
        return DotMap(dict_config)

    def from_folder(self, folder):
        folder = Path(folder)
        model_config = OmegaConf.load(folder / "config.yaml")
        if (folder / "pruned.ckpt").is_file():
            model_path = folder / "pruned.ckpt"
        else:
            model_path = folder / "model.ckpt"
        model = self.load_model_from_config(model_config, model_path)
        return model, model_config

    def from_file(self, file):
        default_config = Path(self.config.default_config)
        if not default_config.is_file():
            raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG")
        model_config = OmegaConf.load(default_config)
        model = self.load_model_from_config(model_config, file)
        return model, model_config

    def load_model_from_config(self, config, ckpt, verbose=False):
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")

        sd = pl_sd.get('state_dict', pl_sd)

        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)

        model.eval()
        return model

    @torch.no_grad()
    @torch.autocast("cuda", enabled=False, dtype=lautocast.dtype)
    def sample(self, request, callback = None):
        if request.module is not None:
            if request.module == "vanilla":
                pass

            else:
                module = self.premodules[request.module]
                CrossAttention.set_hypernetwork(module)

        if request.seed is not None:
            seed_everything(request.seed)

        if request.image is not None:
            request.steps = 50
            #request.sampler = "ddim_img2img" #enforce ddim for now
            if request.sampler == "plms":
                request.sampler = "k_lms"
            if request.sampler == "ddim":
                request.sampler = "k_lms"

            self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
            start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
            start_code = self.model.get_first_stage_encoding(start_code)
            start_code = torch.repeat_interleave(start_code, request.n_samples, dim=0)

            main_noise = []
            start_noise = []
            for seed in range(request.seed, request.seed+request.n_samples):
                main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
                start_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))

            main_noise = torch.cat(main_noise, dim=0)
            start_noise = torch.cat(start_noise, dim=0)

            start_code = start_code + (start_noise * request.noise)
            t_enc = int(request.strength * request.steps)

        if request.sampler.startswith("k_"):
            sampler = "k-diffusion"

        elif request.sampler == 'ddim_img2img':
            sampler = 'img2img'

        else:
            sampler = "normal"

        if request.image is None:
            main_noise = []
            for seed_offset in range(request.n_samples):
                if request.masks is not None:
                    noise_x = sample_start_noise_special(request.seed, request, self.device)
                else:
                    noise_x = sample_start_noise_special(request.seed+seed_offset, request, self.device)

                if request.masks is not None:
                    for maskobj in request.masks:
                        mask_seed = maskobj["seed"]
                        mask = maskobj["mask"]
                        mask = np.asarray(mask)
                        mask = torch.from_numpy(mask).clone().to(self.device).permute(2, 0, 1)
                        mask = mask.float() / 255.0
                        # convert RGB or grayscale image into 4-channel
                        mask = mask[0].unsqueeze(0)
                        mask = torch.repeat_interleave(mask, request.latent_channels, dim=0)
                        mask = (mask < 0.5).float()

                        # interpolate start noise
                        noise_x = (noise_x * (1-mask)) + (sample_start_noise_special(mask_seed+seed_offset, request, self.device) * mask)

                main_noise.append(noise_x)

            main_noise = torch.cat(main_noise, dim=0)
            start_code = main_noise

        prompt = [request.prompt]
        prompt_condition = prompt_mixing(self.model, prompt[0], 1)
        if hasattr(self, "prior") and request.mitigate:
            prompt_condition = self.prior(prompt_condition)

        uc = None
        if request.scale != 1.0:
            if request.uc is not None:
                uc = [request.uc]
                uc = prompt_mixing(self.model, uc[0], 1)
            else:
                if self.config.quality_hack == "1":
                    uc = ["Tags: lowres"]
                    uc = prompt_mixing(self.model, uc[0], 1)
                else:
                    uc = self.model.get_learned_conditioning([""])
            prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)

        shape = [
            request.latent_channels,
            request.height // request.downsampling_factor,
            request.width // request.downsampling_factor
        ]

        c_dele = CallbackDelegate(callback, request.steps * request.n_samples)

        # handle images one at a time because batches eat absurd VRAM
        sampless = []
        for main_noise, start_code in zip(main_noise.chunk(request.n_samples), start_code.chunk(request.n_samples)):
            if sampler == "normal":
                with self.ema_manager():
                    c_dele.update()
                    samples, _ = self.sampler_map[request.sampler](
                        S=request.steps,
                        conditioning=prompt_condition,
                        batch_size=1,
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=request.scale,
                        unconditional_conditioning=uc,
                        eta=request.ddim_eta,
                        dynamic_threshold=request.dynamic_threshold,
                        x_T=start_code,
                        callback=c_dele.update
                    )

            elif sampler == "k-diffusion":
                with self.ema_manager():
                    sigmas = self.k_model.get_sigmas(request.steps)
                    if request.image is not None:
                        noise = main_noise * sigmas[request.steps - t_enc - 1]
                        start_code = start_code + noise
                        sigmas = sigmas[request.steps - t_enc - 1:]

                    else:
                        start_code = start_code * sigmas[0]

                    extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
                    c_dele.update()
                    samples = self.sampler_map[request.sampler](
                        self.k_model,
                        start_code,
                        sigmas,
                        request.seed,
                        callback=c_dele.update,
                        extra_args=extra_args
                    )

            sampless.append(samples)
            torch_gc()

        images = []
        for samples in sampless:
            with torch.autocast("cuda", enabled=self.config.amp):
                x_samples_ddim = self.model.decode_first_stage(samples.float())
                #x_samples_ddim = decode_image(samples, self.model.first_stage_model)
                #x_samples_ddim = self.model.first_stage_model.decode(samples.float())
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples_ddim:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                x_sample = x_sample.astype(np.uint8)
                x_sample = np.ascontiguousarray(x_sample)
                images.append(x_sample)

            torch_gc()

        if request.seed is not None:
            torch.seed()
            np.random.seed()

        #set hypernetwork to none after generation
        CrossAttention.set_hypernetwork(None)

        return images

    @torch.no_grad()
    def sample_two_stages(self, request, callback = None):
        request = DotMap(request)
        if request.seed is not None:
            seed_everything(request.seed)

        if request.plms:
            sampler = self.plms
        else:
            sampler = self.ddim

        start_code = None
        if request.fixed_code:
            start_code = torch.randn([
                request.n_samples,
                request.latent_channels,
                request.height // request.downsampling_factor,
                request.width // request.downsampling_factor,
            ], device=self.device)

        prompt = [request.prompt] * request.n_samples
        prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)

        uc = None
        if request.scale != 1.0:
            uc = self.model.get_learned_conditioning(request.n_samples * [""])
            prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)

        shape = [
            request.latent_channels,
            request.height // request.downsampling_factor,
            request.width // request.downsampling_factor
        ]

        c_dele = CallbackDelegate(callback, request.steps * request.n_samples)

        with torch.autocast("cuda", enabled=self.config.amp):
            with self.ema_manager():
                samples, _ = sampler.sample(
                    S=request.steps,
                    conditioning=prompt_condition,
                    batch_size=request.n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=request.scale,
                    unconditional_conditioning=uc,
                    eta=request.ddim_eta,
                    dynamic_threshold=request.dynamic_threshold,
                    x_T=start_code,
                    callback=c_dele.update
                )

        x_samples_ddim = self.model.decode_first_stage(samples)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).squeeze(0)
        x_samples_ddim = pil_upscale(x_samples_ddim, scale=2)

        if request.stage_two_seed is not None:
            torch.manual_seed(request.stage_two_seed)
            np.random.seed(request.stage_two_seed)

        with torch.autocast("cuda", enabled=self.config.amp):
            with self.ema_manager():
                init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
                self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
                t_enc = int(request.strength * request.steps)

                print("init latent shape:")
                print(init_latent.shape)

                init_latent = init_latent + (torch.randn_like(init_latent) * request.noise)

                prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)

                uc = None
                if request.scale != 1.0:
                    uc = self.model.get_learned_conditioning(request.n_samples * [""])
                    prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)

                # encode (scaled latent)
                start_code_terped=None
                z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
                # decode it
                samples = self.ddim.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale,
                                           unconditional_conditioning=uc,)

                x_samples_ddim = self.model.decode_first_stage(samples)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        images = []
        for x_sample in x_samples_ddim:
            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            x_sample = x_sample.astype(np.uint8)
            x_sample = np.ascontiguousarray(x_sample)
            images.append(x_sample)

        if request.seed is not None:
            torch.seed()
            np.random.seed()

        return images

    @torch.no_grad()
    def sample_from_image(self, request):
        return

class DalleMiniModel(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        from min_dalle import MinDalle

        self.config = config
        self.model = MinDalle(
            models_root=config.model_path,
            dtype=torch.float16,
            device='cuda',
            is_mega=True,
            is_reusable=True
        )

    @torch.no_grad()
    def sample(self, request, callback = None):
        if request.seed is not None:
            seed = request.seed
        else:
            seed = -1

        images = self.model.generate_images(
            text=request.prompt,
            seed=seed,
            grid_size=request.grid_size,
            is_seamless=False,
            temperature=request.temp,
            top_k=request.top_k,
            supercondition_factor=request.scale,
            is_verbose=False
        )
        images = images.to('cpu').numpy()
        images = images.astype(np.uint8)
        images = np.ascontiguousarray(images)

        if request.seed is not None:
            torch.seed()
            np.random.seed()

        return images

def apply_temp(logits, temperature):
    logits = logits / temperature
    return logits

@torch.no_grad()
def generate(forward, prompt_tokens, tokenizer, tokens_to_generate=50, ds=False, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
    in_tokens = prompt_tokens
    context = prompt_tokens
    generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
    kv = None
    if non_deterministic:
        torch.seed()
    #soft_required = ["top_k", "top_p"]
    op_map = {
        "temp": apply_temp,
    }

    for _ in range(tokens_to_generate):
        if ds:
            logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
        else:
            logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
        logits = logits[:, -1, :] #get the last token in the seq
        logits = torch.log_softmax(logits, dim=-1)

        batch = []
        for i, ops in enumerate(ops_list):
            item = logits[i, ...].unsqueeze(0)
            ctx = context[i, ...].unsqueeze(0)
            for op, value in ops.items():
                if op == "rep_pen":
                    item = op_map[op](ctx, item, **value)

                else:
                    item = op_map[op](item, value)

            batch.append(item)

        logits = torch.cat(batch, dim=0)
        logits = torch.softmax(logits, dim=-1)

        #fully_deterministic makes it deterministic across the batch
        if fully_deterministic:
            logits = logits.split(1, dim=0)
            logit_list = []
            for logit in logits:
                torch.manual_seed(69)
                logit_list.append(torch.multinomial(logit, 1))

            logits = torch.cat(logit_list, dim=0)

        else:
            logits = torch.multinomial(logits, 1)

        if logits[0, 0] == 48585:
            if generated[0, -1] == 1400:
                pass
            elif generated[0, -1] == 3363:
                return "safe", "none"
            else:
                return "notsafe", tokenizer.decode(generated.squeeze()).split("Output: ")[-1]

        generated = torch.cat([generated, logits], dim=-1)
        context = torch.cat([context, logits], dim=-1)
        in_tokens = logits

    return "unknown", tokenizer.decode(generated.squeeze())


class BasedformerModel(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        from basedformer import lm_utils
        from transformers import GPT2TokenizerFast
        self.config = config
        self.model = lm_utils.load_from_path(config.model_path).half().cuda()
        self.model = self.model.convert_to_ds()
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    @torch.no_grad()
    def sample(self, request, callback = None):
        prompt = request.prompt
        prompt = self.tokenizer.encode("Input: " + prompt, return_tensors='pt').cuda().long()
        prompt = torch.cat([prompt, torch.tensor([[49527]], dtype=torch.long).cuda()], dim=1)
        is_safe, corrected = generate(self.model.module, prompt, self.tokenizer, tokens_to_generate=150, ds=True)
        return is_safe, corrected

class EmbedderModel(nn.Module):
    def __init__(self, config=None):
        nn.Module.__init__(self)
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer('./models/sentence-transformers_all-MiniLM-L6-v2').cuda()
        self.tags = [tuple(x) for x in json.load(open("models/tags.json"))]
        self.knn = self.load_knn("models/tags.index")
        print("Loaded tag suggestion model using phrase embeddings")

    def load_knn(self, filename):
        import faiss
        try:
            return faiss.read_index(filename)
        except RuntimeError:
            print(f"Generating tag embedding index for {len(self.tags)} tags.")
            i = faiss.IndexFlatL2(384)
            i.add(self([name for name, count in self.tags]))
            faiss.write_index(i, filename)
            return i

    def __call__(self, sentences):
        with torch.no_grad():
            sentence_embeddings = self.model.encode(sentences)
        return sentence_embeddings

    def get_top_k(self, text):
        #check if text is a substring in tag_count.keys()
        found = []
        a = bisect.bisect_left(self.tags, (text,))
        b = bisect.bisect_left(self.tags, (text + '\xff',), lo=a)
        for tag, count in self.tags[a:b]:
            if len(tag) >= len(text) and tag.startswith(text):
                found.append([tag, count, 0])

        results = []
        embedding = self([text])
        k = 15
        D, I = self.knn.search(embedding, k)
        D, I = D.squeeze(), I.squeeze()
        for id, prob in zip(I, D):
            tag, count = self.tags[id]
            results.append([tag, count, prob])

        found.sort(key=lambda x: x[1], reverse=True)
        found = found[:5]
        # found = heapq.nlargest(5, found, key=lambda x: x[1])
        results_tags = [x[0] for x in found]
        for result in results.copy():
            if result[0] in results_tags:
                results.remove(result)

        results = sorted(results, key=lambda x: x[2], reverse=True)
        #filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4
        results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])]
        if found:
            results = found + results

        #max 10 results
        results = results[:10]
        results = sorted(results, key=lambda x: x[1], reverse=True)
        return results
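
# get_top_k() merges two candidate sources: exact prefix matches found by binary
# search over the sorted (tag, count) list, and nearest neighbours of the query's
# sentence embedding from the FAISS index (384-dim vectors). Embedding hits that
# fail the 0.5/0.4 score thresholds above are dropped, duplicates of the top
# prefix matches are removed, the two lists are concatenated, capped at 10
# entries, and finally re-sorted by tag popularity (count).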
@@ -0,0 +1,221 @@
import traceback
from dotmap import DotMap
import math
from io import BytesIO
import base64
import random

v1pp_defaults = {
    'steps': 50,
    'sampler': "plms",
    'image': None,
    'fixed_code': False,
    'ddim_eta': 0.0,
    'height': 512,
    'width': 512,
    'latent_channels': 4,
    'downsampling_factor': 8,
    'scale': 7.0,
    'dynamic_threshold': None,
    'seed': None,
    'stage_two_seed': None,
    'module': None,
    'masks': None,
}

v1pp_forced_defaults = {
    'latent_channels': 4,
    'downsampling_factor': 8,
}

dalle_mini_defaults = {
    'temp': 1.0,
    'top_k': 256,
    'scale': 16,
    'grid_size': 4,
}

dalle_mini_forced_defaults = {
}

defaults = {
    'stable-diffusion': (v1pp_defaults, v1pp_forced_defaults),
    'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
    'basedformer': ({}, {}),
    'embedder': ({}, {}),
}

samplers = [
    "plms",
    "ddim",
    "k_euler",
    "k_euler_ancestral",
    "k_heun",
    "k_dpm_2",
    "k_dpm_2_ancestral",
    "k_lms"
]

def closest_multiple(num, mult):
    num_int = int(num)
    floor = math.floor(num_int / mult) * mult
    ceil = math.ceil(num_int / mult) * mult
    return floor if (num_int - floor) < (ceil - num_int) else ceil
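
# Example: closest_multiple(513, 64) -> 512 and closest_multiple(575, 64) -> 576;
# exact ties such as closest_multiple(96, 64) round up to 128, since the floor is
# only kept when it is strictly closer.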

def sanitize_stable_diffusion(request, config):
    if request.steps > 50:
        return False, "steps must be smaller than 50"

    if request.width * request.height == 0:
        return False, "width and height must be non-zero"

    if request.width <= 0:
        return False, "width must be positive"

    if request.height <= 0:
        return False, "height must be positive"

    if request.steps <= 0:
        return False, "steps must be positive"

    if request.ddim_eta < 0:
        return False, "ddim_eta shouldn't be negative"

    if request.scale < 1.0:
        return False, "scale should be at least 1.0"

    if request.dynamic_threshold is not None and request.dynamic_threshold < 0:
        return False, "dynamic_threshold shouldn't be negative"

    if request.width * request.height >= 1024*1025:
        return False, "width and height must be less than 1024*1025"

    if request.strength < 0.0 or request.strength >= 1.0:
        return False, "strength should be more than 0.0 and less than 1.0"

    if request.noise < 0.0 or request.noise > 1.0:
        return False, "noise should be more than 0.0 and less than 1.0"

    if request.advanced:
        request.width = closest_multiple(request.width // 2, 64)
        request.height = closest_multiple(request.height // 2, 64)

    if request.sampler not in samplers:
        return False, "sampler should be one of {}".format(samplers)

    if request.seed is None:
        state = random.getstate()
        request.seed = random.randint(0, 2**32)
        random.setstate(state)

    if request.module is not None:
        if request.module not in config.model.premodules and request.module != "vanilla":
            return False, "module should be one of: " + ", ".join(config.model.premodules)

    max_gens = 100
    if 0:
        num_gen_tiers = [(1024*512, 4), (640*640, 6), (704*512, 8), (512*512, 16), (384*640, 18)]
        pixel_count = request.width * request.height
        for tier in num_gen_tiers:
            if pixel_count <= tier[0]:
                max_gens = tier[1]
            else:
                break
    if request.n_samples > max_gens:
        return False, f"requested more ({request.n_samples}) images than possible at this resolution"

    if request.image is not None:
        #decode from base64
        try:
            request.image = base64.b64decode(request.image.encode('utf-8'))

        except Exception as e:
            traceback.print_exc()
            return False, "image is not valid base64"
        #check if image is valid
        try:
            from PIL import Image
            image = Image.open(BytesIO(request.image))
            image.verify()

        except Exception as e:
            traceback.print_exc()
            return False, "image is not valid"

        #image is valid, load it again(still check again, verify() can't be sure as it doesn't decode.)
        try:
            image = Image.open(BytesIO(request.image))
            image = image.convert('RGB')
            image = image.resize((request.width, request.height), resample=Image.Resampling.LANCZOS)
            request.image = image
        except Exception as e:
            traceback.print_exc()
            return False, "Error while opening and cleaning image"

    if request.masks is not None:
        masks = request.masks
        for x in range(len(masks)):
            image = masks[x]["mask"]
            try:
                image_bytes = base64.b64decode(image.encode('utf-8'))

            except Exception as e:
                traceback.print_exc()
                return False, "image is not valid base64"

            try:
                from PIL import Image
                image = Image.open(BytesIO(image_bytes))
                image.verify()

            except Exception as e:
                traceback.print_exc()
                return False, "image is not valid"

            #image is valid, load it again(still check again, verify() can't be sure as it doesn't decode.)
            try:
                image = Image.open(BytesIO(image_bytes))
                #image = image.convert('RGB')
                image = image.resize((request.width//request.downsampling_factor, request.height//request.downsampling_factor), resample=Image.Resampling.LANCZOS)

            except Exception as e:
                traceback.print_exc()
                return False, "Error while opening and cleaning image"

            masks[x]["mask"] = image

    return True, request

def sanitize_dalle_mini(request):
    return True, request

def sanitize_basedformer(request):
    return True, request

def sanitize_embedder(request):
    return True, request

def sanitize_input(config, request):
    """
    Sanitize the input data and set defaults
    """
    request = DotMap(request)
    default, forced_default = defaults[config.model_name]
    for k, v in default.items():
        if k not in request:
            request[k] = v

    for k, v in forced_default.items():
        request[k] = v

    if config.model_name == 'stable-diffusion':
        return sanitize_stable_diffusion(request, config)

    elif config.model_name == 'dalle-mini':
        return sanitize_dalle_mini(request)

    elif config.model_name == 'basedformer':
        return sanitize_basedformer(request)

    elif config.model_name == "embedder":
        return sanitize_embedder(request)
@@ -0,0 +1,193 @@
import os
import torch
import logging
import os
import platform
import socket
import sys
import time
from dotmap import DotMap
from hydra_node.models import StableDiffusionModel, DalleMiniModel, BasedformerModel, EmbedderModel
from hydra_node import lowvram
import traceback
import zlib
from pathlib import Path
from ldm.modules.attention import CrossAttention, HyperLogic

model_map = {
    "stable-diffusion": StableDiffusionModel,
    "dalle-mini": DalleMiniModel,
    "basedformer": BasedformerModel,
    "embedder": EmbedderModel,
}

def no_init(loading_code):
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {}
    for mod in modules:
        original[mod] = mod.reset_parameters
        mod.reset_parameters = dummy

    result = loading_code()
    for mod in modules:
        mod.reset_parameters = original[mod]

    return result

def crc32(filename, chunksize=65536):
    """Compute the CRC-32 checksum of the contents of the given filename"""
    with open(filename, "rb") as f:
        checksum = 0
        while True:
            chunk = f.read(chunksize)
            if not chunk:
                break
            checksum = zlib.crc32(chunk, checksum)
        #while (chunk := f.read(chunksize)):
        #    checksum = zlib.crc32(chunk, checksum)
        return '%08X' % (checksum & 0xFFFFFFFF)

def load_modules(path):
    path = Path(path)
    modules = {}
    if not path.is_dir():
        return

    for file in path.iterdir():
        module = load_module(file, "cpu")
        modules[file.stem] = module
        print(f"Loaded module {file.stem}")

    return modules

def load_module(path, device):
    path = Path(path)
    if not path.is_file():
        print("Module path {} is not a file".format(path))

    network = {
        768: (HyperLogic(768).to(device), HyperLogic(768).to(device)),
        1280: (HyperLogic(1280).to(device), HyperLogic(1280).to(device)),
        640: (HyperLogic(640).to(device), HyperLogic(640).to(device)),
        320: (HyperLogic(320).to(device), HyperLogic(320).to(device)),
    }

    state_dict = torch.load(path)
    for key in state_dict.keys():
        network[key][0].load_state_dict(state_dict[key][0])
        network[key][1].load_state_dict(state_dict[key][1])

    return network
|
||||
|
||||
def init_config_model():
|
||||
config = DotMap()
|
||||
config.savefiles = os.getenv("SAVE_FILES", False)
|
||||
config.dtype = os.getenv("DTYPE", "float16")
|
||||
config.device = os.getenv("DEVICE", "cuda")
|
||||
config.amp = os.getenv("AMP", False)
|
||||
if config.amp == "1":
|
||||
config.amp = True
|
||||
elif config.amp == "0":
|
||||
config.amp = False
|
||||
|
||||
is_dev = ""
|
||||
environment = "production"
|
||||
if os.environ['DEV'] == "True":
|
||||
is_dev = "_dev"
|
||||
environment = "staging"
|
||||
config.is_dev = is_dev
|
||||
|
||||
# Setup logger
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(level=logging.INFO)
|
||||
fh = logging.StreamHandler()
|
||||
fh_formatter = logging.Formatter(
|
||||
"%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
|
||||
)
|
||||
fh.setFormatter(fh_formatter)
|
||||
logger.addHandler(fh)
|
||||
config.logger = logger
|
||||
|
||||
# Gather node information
|
||||
config.cuda_dev = torch.cuda.current_device()
|
||||
cpu_id = platform.processor()
|
||||
if os.path.exists('/proc/cpuinfo'):
|
||||
cpu_id = [line for line in open("/proc/cpuinfo", 'r').readlines() if
|
||||
'model name' in line][0].rstrip().split(': ')[-1]
|
||||
|
||||
config.cpu_id = cpu_id
|
||||
config.gpu_id = torch.cuda.get_device_name(config.cuda_dev)
|
||||
config.node_id = platform.node()
|
||||
|
||||
# Report on our CUDA memory and model.
|
||||
gb_gpu = int(torch.cuda.get_device_properties(
|
||||
config.cuda_dev).total_memory / (1000 * 1000 * 1000))
|
||||
logger.info(f"CPU: {config.cpu_id}")
|
||||
logger.info(f"GPU: {config.gpu_id}")
|
||||
logger.info(f"GPU RAM: {gb_gpu}gb")
|
||||
|
||||
config.model_name = os.environ['MODEL']
|
||||
logger.info(f"MODEL: {config.model_name}")
|
||||
|
||||
# Resolve where we get our model and data from.
|
||||
config.model_path = os.getenv('MODEL_PATH', None)
|
||||
config.enable_ema = os.getenv('ENABLE_EMA', "1")
|
||||
config.basedformer = os.getenv('BASEDFORMER', "0")
|
||||
config.penultimate = os.getenv('PENULTIMATE', "0")
|
||||
config.vae_path = os.getenv('VAE_PATH', None)
|
||||
config.module_path = os.getenv('MODULE_PATH', None)
|
||||
config.prior_path = os.getenv('PRIOR_PATH', None)
|
||||
config.default_config = os.getenv('DEFAULT_CONFIG', None)
|
||||
config.quality_hack = os.getenv('QUALITY_HACK', "0")
|
||||
config.clip_contexts = os.getenv('CLIP_CONTEXTS', "1")
|
||||
try:
|
||||
config.clip_contexts = int(config.clip_contexts)
|
||||
if config.clip_contexts < 1 or config.clip_contexts > 10:
|
||||
config.clip_contexts = 1
|
||||
    except ValueError:
|
||||
config.clip_contexts = 1
|
||||
|
||||
# Misc settings
|
||||
config.model_alias = os.getenv('MODEL_ALIAS')
|
||||
|
||||
# Instantiate our actual model.
|
||||
load_time = time.time()
|
||||
model_hash = None
|
||||
|
||||
try:
|
||||
if config.model_name != "dalle-mini":
|
||||
model = no_init(lambda: model_map[config.model_name](config))
|
||||
else:
|
||||
model = model_map[config.model_name](config)
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"Failed to load model: {str(e)}")
|
||||
#exit gunicorn
|
||||
sys.exit(4)
|
||||
|
||||
if config.model_name == "stable-diffusion":
|
||||
folder = Path(config.model_path)
|
||||
if (folder / "pruned.ckpt").is_file():
|
||||
model_path = folder / "pruned.ckpt"
|
||||
else:
|
||||
model_path = folder / "model.ckpt"
|
||||
model_hash = crc32(model_path)
|
||||
|
||||
#Load Modules
|
||||
if config.module_path is not None:
|
||||
modules = load_modules(config.module_path)
|
||||
#attach it to the model
|
||||
model.premodules = modules
|
||||
|
||||
lowvram.setup_for_low_vram(model.model, True)
|
||||
|
||||
config.model = model
|
||||
|
||||
time_load = time.time() - load_time
|
||||
logger.info(f"Models loaded in {time_load:.2f}s")
|
||||
|
||||
return model, config, model_hash
|
@ -0,0 +1,225 @@
|
||||
# from github.com/AUTOMATIC1111/stable-diffusion-webui
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import silu
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
|
||||
|
||||
|
||||
module_in_gpu = None
|
||||
cpu = torch.device("cpu")
|
||||
device = gpu = torch.device("cuda")
|
||||
|
||||
|
||||
def send_everything_to_cpu():
|
||||
global module_in_gpu
|
||||
|
||||
if module_in_gpu is not None:
|
||||
module_in_gpu.to(cpu)
|
||||
|
||||
module_in_gpu = None
|
||||
|
||||
|
||||
def setup_for_low_vram(sd_model, use_medvram):
|
||||
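    # maps a hooked child module to the parent module that should actually be moved between devices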
parents = {}
|
||||
|
||||
def send_me_to_gpu(module, _):
|
||||
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
|
||||
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
|
||||
be in CPU
|
||||
"""
|
||||
global module_in_gpu
|
||||
|
||||
module = parents.get(module, module)
|
||||
|
||||
if module_in_gpu == module:
|
||||
return
|
||||
|
||||
if module_in_gpu is not None:
|
||||
module_in_gpu.to(cpu)
|
||||
|
||||
module.to(gpu)
|
||||
module_in_gpu = module
|
||||
|
||||
# see below for register_forward_pre_hook;
|
||||
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
||||
# useless here, and we just replace those methods
|
||||
def first_stage_model_encode_wrap(self, encoder, x):
|
||||
send_me_to_gpu(self, None)
|
||||
return encoder(x)
|
||||
|
||||
def first_stage_model_decode_wrap(self, decoder, z):
|
||||
send_me_to_gpu(self, None)
|
||||
return decoder(z)
|
||||
|
||||
    # remove the three big modules (cond, first_stage, and unet) from the model, send
    # the model to the GPU, then put the modules back; they remain on the CPU.
|
||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
||||
sd_model.to(device)
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
||||
|
||||
    # register hooks for the first two models
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
|
||||
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if use_medvram:
|
||||
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
else:
|
||||
diff_model = sd_model.model.diffusion_model
|
||||
|
||||
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
|
||||
# so that only one of them is in GPU at a time
|
||||
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
||||
sd_model.model.to(device)
|
||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
||||
|
||||
# install hooks for bits of third model
|
||||
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
|
||||
for block in diff_model.input_blocks:
|
||||
block.register_forward_pre_hook(send_me_to_gpu)
|
||||
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
|
||||
for block in diff_model.output_blocks:
|
||||
block.register_forward_pre_hook(send_me_to_gpu)
|
||||
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
|
||||
try:
|
||||
import xformers
|
||||
except ImportError:
|
||||
    # split_cross_attention_forward and cross_attention_attnblock_forward are defined
    # further down in this file; bind them via lambdas so the names are resolved at call time.
    ldm.modules.attention.CrossAttention.forward = lambda self, x, context=None, mask=None: split_cross_attention_forward(self, x, context=context, mask=mask)
    ldm.modules.diffusionmodules.model.AttnBlock.forward = lambda self, x: cross_attention_attnblock_forward(self, x)
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import einsum
|
||||
|
||||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
|
||||
# taken from https://github.com/Doggettx/stable-diffusion
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context) * self.scale
|
||||
v_in = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
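    # estimate available GPU memory: free CUDA memory plus memory torch has reserved but is not actively using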
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
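    # compute attention in slices along the query length so the (q x k) score matrix never exceeds available memory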
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
del q, k, v
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
|
||||
def cross_attention_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q1 = self.q(h_)
|
||||
k1 = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q1.shape
|
||||
|
||||
q2 = q1.reshape(b, c, h*w)
|
||||
del q1
|
||||
|
||||
q = q2.permute(0, 2, 1) # b,hw,c
|
||||
del q2
|
||||
|
||||
k = k1.reshape(b, c, h*w) # b,c,hw
|
||||
del k1
|
||||
|
||||
h_ = torch.zeros_like(k, device=q.device)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
|
||||
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w2 = w1 * (int(c)**(-0.5))
|
||||
del w1
|
||||
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
||||
del w2
|
||||
|
||||
# attend to values
|
||||
v1 = v.reshape(b, c, h*w)
|
||||
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
del w3
|
||||
|
||||
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
del v1, w4
|
||||
|
||||
h2 = h_.reshape(b, c, h, w)
|
||||
del h_
|
||||
|
||||
h3 = self.proj_out(h2)
|
||||
del h2
|
||||
|
||||
h3 += x
|
||||
|
||||
return h3
|
@ -0,0 +1,806 @@
|
||||
import os
|
||||
import bisect
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
from dotmap import DotMap
|
||||
import numpy as np
|
||||
from torch import autocast
|
||||
from einops import rearrange, repeat
|
||||
from torchvision.utils import make_grid
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.modules.attention import CrossAttention, HyperLogic
|
||||
from PIL import Image
|
||||
import k_diffusion as K
|
||||
import contextlib
|
||||
import random
|
||||
|
||||
class CallbackDelegate:
|
||||
total_steps = 0
|
||||
current_step = -1
|
||||
callback = None
|
||||
|
||||
def __init__(self, callback, total_steps) -> None:
|
||||
self.callback = callback
|
||||
self.total_steps = total_steps
|
||||
|
||||
def update(self, n = None):
|
||||
self.current_step += 1
|
||||
if self.callback:
|
||||
self.callback(self.current_step, self.total_steps)
|
||||
return n
|
||||
|
||||
|
||||
def seed_everything(seed: int):
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
def pil_upscale(image, scale=1):
|
||||
device = image.device
|
||||
dtype = image.dtype
|
||||
image = Image.fromarray((image.cpu().permute(1,2,0).numpy().astype(np.float32) * 255.).astype(np.uint8))
|
||||
if scale > 1:
|
||||
image = image.resize((int(image.width * scale), int(image.height * scale)), resample=Image.LANCZOS)
|
||||
image = np.array(image)
|
||||
image = image.astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
image = 2.*image - 1.
|
||||
image = repeat(image, '1 ... -> b ...', b=1)
|
||||
return image.to(device)
|
||||
|
||||
def fix_batch(tensor, bs):
|
||||
return torch.stack([tensor.squeeze(0)]*bs, dim=0)
|
||||
|
||||
def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
# make uc and prompt shapes match via padding for long prompts
|
||||
# finetune
|
||||
null_cond = None
|
||||
def fix_cond_shapes(model, prompt_condition, uc):
|
||||
global null_cond
|
||||
if null_cond is None:
|
||||
null_cond = model.get_learned_conditioning([""])
|
||||
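    # pad whichever conditioning tensor is shorter with the empty-prompt embedding until both sequence lengths match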
while prompt_condition.shape[1] > uc.shape[1]:
|
||||
uc = torch.cat((uc, null_cond.repeat((uc.shape[0], 1, 1))), axis=1)
|
||||
while prompt_condition.shape[1] < uc.shape[1]:
|
||||
prompt_condition = torch.cat((prompt_condition, null_cond.repeat((prompt_condition.shape[0], 1, 1))), axis=1)
|
||||
return prompt_condition, uc
|
||||
|
||||
# mix conditioning vectors for prompts
|
||||
# @aero
|
||||
def prompt_mixing(model, prompt_body, batch_size):
|
||||
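    # prompts of the form "a:1.5|b:0.5" are split on "|", each part is encoded separately,
    # and the conditioning vectors are averaged weighted by the given powers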
if "|" in prompt_body:
|
||||
prompt_parts = prompt_body.split("|")
|
||||
prompt_total_power = 0
|
||||
prompt_sum = None
|
||||
for prompt_part in prompt_parts:
|
||||
prompt_power = 1
|
||||
if ":" in prompt_part:
|
||||
prompt_sub_parts = prompt_part.split(":")
|
||||
try:
|
||||
prompt_power = float(prompt_sub_parts[1])
|
||||
prompt_part = prompt_sub_parts[0]
|
||||
                except ValueError:
|
||||
print("Error parsing prompt power! Assuming 1")
|
||||
prompt_vector = model.get_learned_conditioning([prompt_part])
|
||||
if prompt_sum is None:
|
||||
prompt_sum = prompt_vector * prompt_power
|
||||
else:
|
||||
prompt_sum, prompt_vector = fix_cond_shapes(model, prompt_sum, prompt_vector)
|
||||
prompt_sum = prompt_sum + (prompt_vector * prompt_power)
|
||||
prompt_total_power = prompt_total_power + prompt_power
|
||||
return fix_batch(prompt_sum / prompt_total_power, batch_size)
|
||||
else:
|
||||
return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)
|
||||
|
||||
def sample_start_noise(seed, C, H, W, f, device="cuda"):
|
||||
if seed:
|
||||
gen = torch.Generator(device=device)
|
||||
gen.manual_seed(seed)
|
||||
noise = torch.randn([C, (H) // f, (W) // f], generator=gen, device=device).unsqueeze(0)
|
||||
else:
|
||||
noise = torch.randn([C, (H) // f, (W) // f], device=device).unsqueeze(0)
|
||||
return noise
|
||||
|
||||
def sample_start_noise_special(seed, request, device="cuda"):
|
||||
if seed:
|
||||
gen = torch.Generator(device=device)
|
||||
gen.manual_seed(seed)
|
||||
noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], generator=gen, device=device).unsqueeze(0)
|
||||
else:
|
||||
noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], device=device).unsqueeze(0)
|
||||
return noise
|
||||
|
||||
@torch.no_grad()
|
||||
#@torch.autocast("cuda", enabled=True, dtype=torch.float16)
|
||||
def encode_image(image, model):
|
||||
if isinstance(image, Image.Image):
|
||||
image = np.array(image)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
#dtype = image.dtype
|
||||
image = image.to(torch.float32)
|
||||
#gets image as numpy array and returns as tensor
|
||||
def preprocess_vqgan(x):
|
||||
x = x / 255.0
|
||||
x = 2.*x - 1.
|
||||
return x
|
||||
|
||||
image = image.permute(2, 0, 1).unsqueeze(0).float().cuda()
|
||||
image = preprocess_vqgan(image)
|
||||
image = model.encode(image).sample()
|
||||
#image = image.to(dtype)
|
||||
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_image(image, model):
|
||||
def custom_to_pil(x):
|
||||
x = x.detach().float().cpu()
|
||||
x = torch.clamp(x, -1., 1.)
|
||||
x = (x + 1.)/2.
|
||||
x = x.permute(0, 2, 3, 1)#.numpy()
|
||||
#x = (255*x).astype(np.uint8)
|
||||
#x = Image.fromarray(x)
|
||||
#if not x.mode == "RGB":
|
||||
# x = x.convert("RGB")
|
||||
return x
|
||||
|
||||
image = model.decode(image)
|
||||
image = custom_to_pil(image)
|
||||
return image
|
||||
|
||||
class VectorAdjustPrior(nn.Module):
|
||||
def __init__(self, hidden_size, inter_dim=64):
|
||||
super().__init__()
|
||||
self.vector_proj = nn.Linear(hidden_size*2, inter_dim, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_size+inter_dim, hidden_size, bias=True)
|
||||
|
||||
def forward(self, z):
|
||||
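        # mix each token embedding with the mean embedding of the whole sequence through a small bottleneck projection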
b, s = z.shape[0:2]
|
||||
x1 = torch.mean(z, dim=1).repeat(s, 1)
|
||||
x2 = z.reshape(b*s, -1)
|
||||
x = torch.cat((x1, x2), dim=1)
|
||||
x = self.vector_proj(x)
|
||||
x = torch.cat((x2, x), dim=1)
|
||||
x = self.out_proj(x)
|
||||
x = x.reshape(b, s, -1)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model_path, hidden_size=768, inter_dim=64):
|
||||
model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
|
||||
model.load_state_dict(torch.load(model_path)["state_dict"])
|
||||
return model
|
||||
|
||||
class StableInterface(nn.Module):
|
||||
def __init__(self, model, thresholder = None):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.sigma_to_t = model.sigma_to_t
|
||||
self.thresholder = thresholder
|
||||
self.get_sigmas = model.get_sigmas
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
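        # classifier-free guidance: run the unconditional and conditional batches through
        # the denoiser in one pass, then blend the two predictions with cond_scale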
x_two = torch.cat([x] * 2)
|
||||
sigma_two = torch.cat([sigma] * 2)
|
||||
cond_full = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(x_two, sigma_two, cond=cond_full).chunk(2)
|
||||
x_0 = uncond + (cond - uncond) * cond_scale
|
||||
if self.thresholder is not None:
|
||||
x_0 = self.thresholder(x_0)
|
||||
|
||||
return x_0
|
||||
|
||||
class StableDiffusionModel(nn.Module):
|
||||
def __init__(self, config):
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.premodules = None
|
||||
if Path(self.config.model_path).is_dir():
|
||||
config.logger.info(f"Loading model from folder {self.config.model_path}")
|
||||
model, model_config = self.from_folder(config.model_path)
|
||||
|
||||
elif Path(self.config.model_path).is_file():
|
||||
config.logger.info(f"Loading model from file {self.config.model_path}")
|
||||
model, model_config = self.from_file(config.model_path)
|
||||
|
||||
else:
|
||||
raise Exception("Invalid model path!")
|
||||
|
||||
if config.dtype == "float16":
|
||||
typex = torch.float16
|
||||
else:
|
||||
typex = torch.float32
|
||||
self.model = model.to(config.device).to(typex)
|
||||
if self.config.vae_path:
|
||||
ckpt=torch.load(self.config.vae_path, map_location="cpu")
|
||||
loss = []
|
||||
for i in ckpt["state_dict"].keys():
|
||||
if i[0:4] == "loss":
|
||||
loss.append(i)
|
||||
for i in loss:
|
||||
del ckpt["state_dict"][i]
|
||||
|
||||
model.first_stage_model = model.first_stage_model.float()
|
||||
model.first_stage_model.load_state_dict(ckpt["state_dict"])
|
||||
model.first_stage_model = model.first_stage_model.float()
|
||||
del ckpt
|
||||
del loss
|
||||
config.logger.info(f"Using VAE from {self.config.vae_path}")
|
||||
|
||||
if self.config.penultimate == "1":
|
||||
model.cond_stage_model.return_layer = -2
|
||||
model.cond_stage_model.do_final_ln = True
|
||||
config.logger.info(f"CLIP: Using penultimate layer")
|
||||
|
||||
if self.config.clip_contexts > 1:
|
||||
model.cond_stage_model.clip_extend = True
|
||||
model.cond_stage_model.max_clip_extend = 75 * self.config.clip_contexts
|
||||
|
||||
model.cond_stage_model.inference_mode = True
|
||||
self.k_model = K.external.CompVisDenoiser(model)
|
||||
self.k_model = StableInterface(self.k_model)
|
||||
self.device = config.device
|
||||
self.model_config = model_config
|
||||
self.plms = PLMSSampler(model)
|
||||
self.ddim = DDIMSampler(model)
|
||||
self.ema_manager = self.model.ema_scope
|
||||
if self.config.enable_ema == "0":
|
||||
self.ema_manager = contextlib.nullcontext
|
||||
config.logger.info("Disabling EMA")
|
||||
else:
|
||||
config.logger.info(f"Using EMA")
|
||||
self.sampler_map = {
|
||||
'plms': self.plms.sample,
|
||||
'ddim': self.ddim.sample,
|
||||
'k_euler': K.sampling.sample_euler,
|
||||
'k_euler_ancestral': K.sampling.sample_euler_ancestral,
|
||||
'k_heun': K.sampling.sample_heun,
|
||||
'k_dpm_2': K.sampling.sample_dpm_2,
|
||||
'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral,
|
||||
'k_lms': K.sampling.sample_lms,
|
||||
}
|
||||
if config.prior_path:
|
||||
self.prior = VectorAdjustPrior.load_model(config.prior_path).to(self.device)
|
||||
self.copied_ema = False
|
||||
|
||||
@property
|
||||
def get_default_config(self):
|
||||
dict_config = {
|
||||
'steps': 30,
|
||||
'sampler': "k_euler_ancestral",
|
||||
'n_samples': 1,
|
||||
'image': None,
|
||||
'fixed_code': False,
|
||||
'ddim_eta': 0.0,
|
||||
'height': 512,
|
||||
'width': 512,
|
||||
'latent_channels': 4,
|
||||
'downsampling_factor': 8,
|
||||
'scale': 12.0,
|
||||
'dynamic_threshold': None,
|
||||
'seed': None,
|
||||
'stage_two_seed': None,
|
||||
'module': None,
|
||||
'masks': None,
|
||||
'output': None,
|
||||
}
|
||||
return DotMap(dict_config)
|
||||
|
||||
def from_folder(self, folder):
|
||||
folder = Path(folder)
|
||||
model_config = OmegaConf.load(folder / "config.yaml")
|
||||
if (folder / "pruned.ckpt").is_file():
|
||||
model_path = folder / "pruned.ckpt"
|
||||
else:
|
||||
model_path = folder / "model.ckpt"
|
||||
model = self.load_model_from_config(model_config, model_path)
|
||||
return model, model_config
|
||||
|
||||
    def from_file(self, file):
|
||||
default_config = Path(self.config.default_config)
|
||||
if not default_config.is_file():
|
||||
raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG")
|
||||
model_config = OmegaConf.load(default_config)
|
||||
model = self.load_model_from_config(model_config, file)
|
||||
return model, model_config
|
||||
|
||||
def load_model_from_config(self, config, ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
sd = pl_sd.get('state_dict', pl_sd)
|
||||
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", enabled=True, dtype=torch.float16)
|
||||
def sample(self, request, callback=None):
|
||||
if request.module is not None:
|
||||
if request.module == "vanilla":
|
||||
pass
|
||||
|
||||
else:
|
||||
module = self.premodules[request.module]
|
||||
CrossAttention.set_hypernetwork(module)
|
||||
|
||||
if request.seed is not None:
|
||||
seed_everything(request.seed)
|
||||
|
||||
if request.image is not None:
|
||||
request.steps = 50
|
||||
#request.sampler = "ddim_img2img" #enforce ddim for now
|
||||
if request.sampler == "plms":
|
||||
request.sampler = "k_lms"
|
||||
if request.sampler == "ddim":
|
||||
request.sampler = "k_lms"
|
||||
|
||||
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
|
||||
start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
|
||||
start_code = self.model.get_first_stage_encoding(start_code)
|
||||
start_code = torch.repeat_interleave(start_code, request.n_samples, dim=0)
|
||||
|
||||
main_noise = []
|
||||
start_noise = []
|
||||
for seed in range(request.seed, request.seed+request.n_samples):
|
||||
main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
|
||||
start_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
|
||||
|
||||
main_noise = torch.cat(main_noise, dim=0)
|
||||
start_noise = torch.cat(start_noise, dim=0)
|
||||
|
||||
start_code = start_code + (start_noise * request.noise)
|
||||
t_enc = int(request.strength * request.steps)
|
||||
|
||||
if request.sampler.startswith("k_"):
|
||||
sampler = "k-diffusion"
|
||||
|
||||
elif request.sampler == 'ddim_img2img':
|
||||
sampler = 'img2img'
|
||||
|
||||
else:
|
||||
sampler = "normal"
|
||||
|
||||
if request.image is None:
|
||||
main_noise = []
|
||||
for seed_offset in range(request.n_samples):
|
||||
if request.masks is not None:
|
||||
noise_x = sample_start_noise_special(request.seed, request, self.device)
|
||||
else:
|
||||
noise_x = sample_start_noise_special(request.seed+seed_offset, request, self.device)
|
||||
|
||||
if request.masks is not None:
|
||||
for maskobj in request.masks:
|
||||
mask_seed = maskobj["seed"]
|
||||
mask = maskobj["mask"]
|
||||
mask = np.asarray(mask)
|
||||
mask = torch.from_numpy(mask).clone().to(self.device).permute(2, 0, 1)
|
||||
mask = mask.float() / 255.0
|
||||
# convert RGB or grayscale image into 4-channel
|
||||
mask = mask[0].unsqueeze(0)
|
||||
mask = torch.repeat_interleave(mask, request.latent_channels, dim=0)
|
||||
mask = (mask < 0.5).float()
|
||||
|
||||
# interpolate start noise
|
||||
noise_x = (noise_x * (1-mask)) + (sample_start_noise_special(mask_seed+seed_offset, request, self.device) * mask)
|
||||
|
||||
main_noise.append(noise_x)
|
||||
|
||||
main_noise = torch.cat(main_noise, dim=0)
|
||||
start_code = main_noise
|
||||
|
||||
prompt = [request.prompt]
|
||||
prompt_condition = prompt_mixing(self.model, prompt[0], 1)
|
||||
if hasattr(self, "prior") and request.mitigate:
|
||||
prompt_condition = self.prior(prompt_condition)
|
||||
|
||||
uc = None
|
||||
if request.scale != 1.0:
|
||||
if request.uc is not None:
|
||||
uc = [request.uc]
|
||||
uc = prompt_mixing(self.model, uc[0], 1)
|
||||
else:
|
||||
if self.config.quality_hack == "1":
|
||||
uc = ["Tags: lowres"]
|
||||
uc = prompt_mixing(self.model, uc[0], 1)
|
||||
else:
|
||||
uc = self.model.get_learned_conditioning([""])
|
||||
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
|
||||
|
||||
shape = [
|
||||
request.latent_channels,
|
||||
request.height // request.downsampling_factor,
|
||||
request.width // request.downsampling_factor
|
||||
]
|
||||
|
||||
c_dele = CallbackDelegate(callback, request.steps * request.n_samples)
|
||||
|
||||
# handle images one at a time because batches eat absurd VRAM
|
||||
sampless = []
|
||||
for main_noise, start_code in zip(main_noise.chunk(request.n_samples), start_code.chunk(request.n_samples)):
|
||||
if sampler == "normal":
|
||||
with self.ema_manager():
|
||||
samples, _ = self.sampler_map[request.sampler](
|
||||
S=request.steps,
|
||||
conditioning=prompt_condition,
|
||||
batch_size=1,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=request.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=request.ddim_eta,
|
||||
dynamic_threshold=request.dynamic_threshold,
|
||||
x_T=start_code,
|
||||
callback=c_dele.update
|
||||
)
|
||||
|
||||
elif sampler == "k-diffusion":
|
||||
with self.ema_manager():
|
||||
sigmas = self.k_model.get_sigmas(request.steps)
|
||||
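                    # img2img: scale the added noise by the sigma at the truncated start step and only run the remaining t_enc steps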
if request.image is not None:
|
||||
noise = main_noise * sigmas[request.steps - t_enc - 1]
|
||||
start_code = start_code + noise
|
||||
sigmas = sigmas[request.steps - t_enc - 1:]
|
||||
|
||||
else:
|
||||
start_code = start_code * sigmas[0]
|
||||
|
||||
extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
|
||||
samples = self.sampler_map[request.sampler](
|
||||
self.k_model,
|
||||
start_code,
|
||||
sigmas,
|
||||
request.seed,
|
||||
callback=c_dele.update,
|
||||
extra_args=extra_args
|
||||
)
|
||||
|
||||
sampless.append(samples)
|
||||
torch_gc()
|
||||
|
||||
images = []
|
||||
for samples in sampless:
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
x_samples_ddim = self.model.decode_first_stage(samples.float())
|
||||
#x_samples_ddim = decode_image(samples, self.model.first_stage_model)
|
||||
#x_samples_ddim = self.model.first_stage_model.decode(samples.float())
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
for x_sample in x_samples_ddim:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
x_sample = np.ascontiguousarray(x_sample)
|
||||
images.append(x_sample)
|
||||
|
||||
torch_gc()
|
||||
|
||||
if request.seed is not None:
|
||||
torch.seed()
|
||||
np.random.seed()
|
||||
|
||||
#set hypernetwork to none after generation
|
||||
CrossAttention.set_hypernetwork(None)
|
||||
|
||||
return images
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_two_stages(self, request, callback = None):
|
||||
request = DotMap(request)
|
||||
if request.seed is not None:
|
||||
seed_everything(request.seed)
|
||||
|
||||
if request.plms:
|
||||
sampler = self.plms
|
||||
else:
|
||||
sampler = self.ddim
|
||||
|
||||
start_code = None
|
||||
if request.fixed_code:
|
||||
start_code = torch.randn([
|
||||
request.n_samples,
|
||||
request.latent_channels,
|
||||
request.height // request.downsampling_factor,
|
||||
request.width // request.downsampling_factor,
|
||||
], device=self.device)
|
||||
|
||||
prompt = [request.prompt] * request.n_samples
|
||||
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
|
||||
|
||||
uc = None
|
||||
if request.scale != 1.0:
|
||||
uc = self.model.get_learned_conditioning(request.n_samples * [""])
|
||||
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
|
||||
|
||||
shape = [
|
||||
request.latent_channels,
|
||||
request.height // request.downsampling_factor,
|
||||
request.width // request.downsampling_factor
|
||||
]
|
||||
|
||||
c_dele = CallbackDelegate(callback, request.steps * request.n_samples)
|
||||
|
||||
with torch.autocast("cuda", enabled=self.config.amp):
|
||||
            with self.ema_manager():
|
||||
samples, _ = sampler.sample(
|
||||
S=request.steps,
|
||||
conditioning=prompt_condition,
|
||||
batch_size=request.n_samples,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=request.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=request.ddim_eta,
|
||||
dynamic_threshold=request.dynamic_threshold,
|
||||
x_T=start_code,
|
||||
callback=c_dele.update
|
||||
)
|
||||
|
||||
x_samples_ddim = self.model.decode_first_stage(samples)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).squeeze(0)
|
||||
x_samples_ddim = pil_upscale(x_samples_ddim, scale=2)
|
||||
|
||||
if request.stage_two_seed is not None:
|
||||
torch.manual_seed(request.stage_two_seed)
|
||||
np.random.seed(request.stage_two_seed)
|
||||
|
||||
with torch.autocast("cuda", enabled=self.config.amp):
|
||||
            with self.ema_manager():
|
||||
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
|
||||
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
|
||||
t_enc = int(request.strength * request.steps)
|
||||
|
||||
print("init latent shape:")
|
||||
print(init_latent.shape)
|
||||
|
||||
init_latent = init_latent + (torch.randn_like(init_latent) * request.noise)
|
||||
|
||||
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
|
||||
|
||||
uc = None
|
||||
if request.scale != 1.0:
|
||||
uc = self.model.get_learned_conditioning(request.n_samples * [""])
|
||||
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
|
||||
|
||||
# encode (scaled latent)
|
||||
start_code_terped=None
|
||||
z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
|
||||
# decode it
|
||||
samples = self.ddim.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale,
|
||||
unconditional_conditioning=uc,)
|
||||
|
||||
x_samples_ddim = self.model.decode_first_stage(samples)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
images = []
|
||||
for x_sample in x_samples_ddim:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
x_sample = np.ascontiguousarray(x_sample)
|
||||
images.append(x_sample)
|
||||
|
||||
if request.seed is not None:
|
||||
torch.seed()
|
||||
np.random.seed()
|
||||
|
||||
return images
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_from_image(self, request):
|
||||
return
|
||||
|
||||
class DalleMiniModel(nn.Module):
|
||||
def __init__(self, config):
|
||||
nn.Module.__init__(self)
|
||||
from min_dalle import MinDalle
|
||||
|
||||
self.config = config
|
||||
self.model = MinDalle(
|
||||
models_root=config.model_path,
|
||||
dtype=torch.float16,
|
||||
device='cuda',
|
||||
is_mega=True,
|
||||
is_reusable=True
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, request, callback=None):
|
||||
if request.seed is not None:
|
||||
seed = request.seed
|
||||
else:
|
||||
seed = -1
|
||||
|
||||
images = self.model.generate_images(
|
||||
text=request.prompt,
|
||||
seed=seed,
|
||||
grid_size=request.grid_size,
|
||||
is_seamless=False,
|
||||
temperature=request.temp,
|
||||
top_k=request.top_k,
|
||||
supercondition_factor=request.scale,
|
||||
is_verbose=False
|
||||
)
|
||||
images = images.to('cpu').numpy()
|
||||
images = images.astype(np.uint8)
|
||||
images = np.ascontiguousarray(images)
|
||||
|
||||
if request.seed is not None:
|
||||
torch.seed()
|
||||
np.random.seed()
|
||||
|
||||
return images
|
||||
|
||||
def apply_temp(logits, temperature):
|
||||
logits = logits / temperature
|
||||
return logits
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(forward, prompt_tokens, tokenizer, tokens_to_generate=50, ds=False, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
|
||||
in_tokens = prompt_tokens
|
||||
context = prompt_tokens
|
||||
generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
|
||||
kv = None
|
||||
if non_deterministic:
|
||||
torch.seed()
|
||||
#soft_required = ["top_k", "top_p"]
|
||||
op_map = {
|
||||
"temp": apply_temp,
|
||||
}
|
||||
|
||||
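    # autoregressive sampling loop: each step feeds the previously sampled token back in and reuses the cached key/values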
for _ in range(tokens_to_generate):
|
||||
if ds:
|
||||
logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
|
||||
else:
|
||||
logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
|
||||
logits = logits[:, -1, :] #get the last token in the seq
|
||||
logits = torch.log_softmax(logits, dim=-1)
|
||||
|
||||
batch = []
|
||||
for i, ops in enumerate(ops_list):
|
||||
item = logits[i, ...].unsqueeze(0)
|
||||
ctx = context[i, ...].unsqueeze(0)
|
||||
for op, value in ops.items():
|
||||
if op == "rep_pen":
|
||||
item = op_map[op](ctx, item, **value)
|
||||
|
||||
else:
|
||||
item = op_map[op](item, value)
|
||||
|
||||
batch.append(item)
|
||||
|
||||
logits = torch.cat(batch, dim=0)
|
||||
logits = torch.softmax(logits, dim=-1)
|
||||
|
||||
#fully_deterministic makes it deterministic across the batch
|
||||
if fully_deterministic:
|
||||
logits = logits.split(1, dim=0)
|
||||
logit_list = []
|
||||
for logit in logits:
|
||||
torch.manual_seed(69)
|
||||
logit_list.append(torch.multinomial(logit, 1))
|
||||
|
||||
logits = torch.cat(logit_list, dim=0)
|
||||
|
||||
else:
|
||||
logits = torch.multinomial(logits, 1)
|
||||
|
||||
if logits[0, 0] == 48585:
|
||||
if generated[0, -1] == 1400:
|
||||
pass
|
||||
elif generated[0, -1] == 3363:
|
||||
return "safe", "none"
|
||||
else:
|
||||
return "notsafe", tokenizer.decode(generated.squeeze()).split("Output: ")[-1]
|
||||
|
||||
generated = torch.cat([generated, logits], dim=-1)
|
||||
context = torch.cat([context, logits], dim=-1)
|
||||
in_tokens = logits
|
||||
|
||||
return "unknown", tokenizer.decode(generated.squeeze())
|
||||
|
||||
|
||||
class BasedformerModel(nn.Module):
|
||||
def __init__(self, config):
|
||||
nn.Module.__init__(self)
|
||||
from basedformer import lm_utils
|
||||
from transformers import GPT2TokenizerFast
|
||||
self.config = config
|
||||
self.model = lm_utils.load_from_path(config.model_path).half().cuda()
|
||||
self.model = self.model.convert_to_ds()
|
||||
self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, request, callback=None):
|
||||
prompt = request.prompt
|
||||
prompt = self.tokenizer.encode("Input: " + prompt, return_tensors='pt').cuda().long()
|
||||
prompt = torch.cat([prompt, torch.tensor([[49527]], dtype=torch.long).cuda()], dim=1)
|
||||
is_safe, corrected = generate(self.model.module, prompt, self.tokenizer, tokens_to_generate=150, ds=True)
|
||||
return is_safe, corrected
|
||||
|
||||
class EmbedderModel(nn.Module):
|
||||
def __init__(self, config=None):
|
||||
nn.Module.__init__(self)
|
||||
from sentence_transformers import SentenceTransformer
|
||||
self.model = SentenceTransformer('./models/sentence-transformers_all-MiniLM-L6-v2').cuda()
|
||||
self.tags = [tuple(x) for x in json.load(open("models/tags.json"))]
|
||||
self.knn = self.load_knn("models/tags.index")
|
||||
print("Loaded tag suggestion model using phrase embeddings")
|
||||
|
||||
def load_knn(self, filename):
|
||||
import faiss
|
||||
try:
|
||||
return faiss.read_index(filename)
|
||||
except RuntimeError:
|
||||
print(f"Generating tag embedding index for {len(self.tags)} tags.")
|
||||
i = faiss.IndexFlatL2(384)
|
||||
i.add(self([name for name, count in self.tags]))
|
||||
faiss.write_index(i, filename)
|
||||
return i
|
||||
|
||||
def __call__(self, sentences):
|
||||
with torch.no_grad():
|
||||
sentence_embeddings = self.model.encode(sentences)
|
||||
return sentence_embeddings
|
||||
|
||||
def get_top_k(self, text):
|
||||
#check if text is a substring in tag_count.keys()
|
||||
found = []
|
||||
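        # binary-search the sorted (tag, count) list for the range of tags that start with the query text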
a = bisect.bisect_left(self.tags, (text,))
|
||||
b = bisect.bisect_left(self.tags, (text + '\xff',), lo=a)
|
||||
for tag, count in self.tags[a:b]:
|
||||
if len(tag) >= len(text) and tag.startswith(text):
|
||||
found.append([tag, count, 0])
|
||||
|
||||
results = []
|
||||
embedding = self([text])
|
||||
k = 15
|
||||
D, I = self.knn.search(embedding, k)
|
||||
D, I = D.squeeze(), I.squeeze()
|
||||
for id, prob in zip(I, D):
|
||||
tag, count = self.tags[id]
|
||||
results.append([tag, count, prob])
|
||||
|
||||
found.sort(key=lambda x: x[1], reverse=True)
|
||||
found = found[:5]
|
||||
# found = heapq.nlargest(5, found, key=lambda x: x[1])
|
||||
results_tags = [x[0] for x in found]
|
||||
for result in results.copy():
|
||||
if result[0] in results_tags:
|
||||
results.remove(result)
|
||||
|
||||
results = sorted(results, key=lambda x: x[2], reverse=True)
|
||||
#filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4
|
||||
results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])]
|
||||
if found:
|
||||
results = found + results
|
||||
|
||||
#max 10 results
|
||||
results = results[:10]
|
||||
results = sorted(results, key=lambda x: x[1], reverse=True)
|
||||
return results
|
@ -0,0 +1,221 @@
|
||||
import traceback
|
||||
from dotmap import DotMap
|
||||
import math
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import random
|
||||
|
||||
v1pp_defaults = {
|
||||
'steps': 50,
|
||||
'sampler': "plms",
|
||||
'image': None,
|
||||
'fixed_code': False,
|
||||
'ddim_eta': 0.0,
|
||||
'height': 512,
|
||||
'width': 512,
|
||||
'latent_channels': 4,
|
||||
'downsampling_factor': 8,
|
||||
'scale': 7.0,
|
||||
'dynamic_threshold': None,
|
||||
'seed': None,
|
||||
'stage_two_seed': None,
|
||||
'module': None,
|
||||
'masks': None,
|
||||
}
|
||||
|
||||
v1pp_forced_defaults = {
|
||||
'latent_channels': 4,
|
||||
'downsampling_factor': 8,
|
||||
}
|
||||
|
||||
dalle_mini_defaults = {
|
||||
'temp': 1.0,
|
||||
'top_k': 256,
|
||||
'scale': 16,
|
||||
'grid_size': 4,
|
||||
}
|
||||
|
||||
dalle_mini_forced_defaults = {
|
||||
}
|
||||
|
||||
defaults = {
|
||||
'stable-diffusion': (v1pp_defaults, v1pp_forced_defaults),
|
||||
'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
|
||||
'basedformer': ({}, {}),
|
||||
'embedder': ({}, {}),
|
||||
}
|
||||
|
||||
samplers = [
|
||||
"plms",
|
||||
"ddim",
|
||||
"k_euler",
|
||||
"k_euler_ancestral",
|
||||
"k_heun",
|
||||
"k_dpm_2",
|
||||
"k_dpm_2_ancestral",
|
||||
"k_lms"
|
||||
]
|
||||
|
||||
def closest_multiple(num, mult):
|
||||
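    # round num to the nearest multiple of mult, breaking ties upward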
num_int = int(num)
|
||||
floor = math.floor(num_int / mult) * mult
|
||||
ceil = math.ceil(num_int / mult) * mult
|
||||
return floor if (num_int - floor) < (ceil - num_int) else ceil
|
||||
|
||||
def sanitize_stable_diffusion(request, config):
|
||||
if request.steps > 50:
|
||||
return False, "steps must be smaller than 50"
|
||||
|
||||
if request.width * request.height == 0:
|
||||
return False, "width and height must be non-zero"
|
||||
|
||||
if request.width <= 0:
|
||||
return False, "width must be positive"
|
||||
|
||||
if request.height <= 0:
|
||||
return False, "height must be positive"
|
||||
|
||||
if request.steps <= 0:
|
||||
return False, "steps must be positive"
|
||||
|
||||
if request.ddim_eta < 0:
|
||||
return False, "ddim_eta shouldn't be negative"
|
||||
|
||||
if request.scale < 1.0:
|
||||
return False, "scale should be at least 1.0"
|
||||
|
||||
if request.dynamic_threshold is not None and request.dynamic_threshold < 0:
|
||||
return False, "dynamic_threshold shouldn't be negative"
|
||||
|
||||
if request.width * request.height >= 1024*1025:
|
||||
return False, "width and height must be less than 1024*1025"
|
||||
|
||||
if request.strength < 0.0 or request.strength >= 1.0:
|
||||
return False, "strength should be more than 0.0 and less than 1.0"
|
||||
|
||||
if request.noise < 0.0 or request.noise > 1.0:
|
||||
return False, "noise should be more than 0.0 and less than 1.0"
|
||||
|
||||
if request.advanced:
|
||||
request.width = closest_multiple(request.width // 2, 64)
|
||||
request.height = closest_multiple(request.height // 2, 64)
|
||||
|
||||
if request.sampler not in samplers:
|
||||
return False, "sampler should be one of {}".format(samplers)
|
||||
|
||||
if request.seed is None:
|
||||
state = random.getstate()
|
||||
request.seed = random.randint(0, 2**32)
|
||||
random.setstate(state)
|
||||
|
||||
if request.module is not None:
|
||||
if request.module not in config.model.premodules and request.module != "vanilla":
|
||||
return False, "module should be one of: " + ", ".join(config.model.premodules)
|
||||
|
||||
max_gens = 100
|
||||
if 0:
|
||||
num_gen_tiers = [(1024*512, 4), (640*640, 6), (704*512, 8), (512*512, 16), (384*640, 18)]
|
||||
pixel_count = request.width * request.height
|
||||
for tier in num_gen_tiers:
|
||||
if pixel_count <= tier[0]:
|
||||
max_gens = tier[1]
|
||||
else:
|
||||
break
|
||||
if request.n_samples > max_gens:
|
||||
return False, f"requested more ({request.n_samples}) images than possible at this resolution"
|
||||
|
||||
if request.image is not None:
|
||||
#decode from base64
|
||||
try:
|
||||
request.image = base64.b64decode(request.image.encode('utf-8'))
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return False, "image is not valid base64"
|
||||
#check if image is valid
|
||||
try:
|
||||
from PIL import Image
|
||||
image = Image.open(BytesIO(request.image))
|
||||
image.verify()
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return False, "image is not valid"
|
||||
|
||||
        # the image header is valid; reopen and fully decode it, since verify() does not decode the image data
|
||||
try:
|
||||
image = Image.open(BytesIO(request.image))
|
||||
image = image.convert('RGB')
|
||||
image = image.resize((request.width, request.height), resample=Image.Resampling.LANCZOS)
|
||||
request.image = image
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return False, "Error while opening and cleaning image"
|
||||
|
||||
if request.masks is not None:
|
||||
masks = request.masks
|
||||
for x in range(len(masks)):
|
||||
image = masks[x]["mask"]
|
||||
try:
|
||||
image_bytes = base64.b64decode(image.encode('utf-8'))
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return False, "image is not valid base64"
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
image.verify()
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return False, "image is not valid"
|
||||
|
||||
            # the image header is valid; reopen and fully decode it, since verify() does not decode the image data
|
||||
try:
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
#image = image.convert('RGB')
|
||||
image = image.resize((request.width//request.downsampling_factor, request.height//request.downsampling_factor), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return False, "Error while opening and cleaning image"
|
||||
|
||||
masks[x]["mask"] = image
|
||||
|
||||
return True, request
|
||||
|
||||
def sanitize_dalle_mini(request):
|
||||
return True, request
|
||||
|
||||
def sanitize_basedformer(request):
|
||||
return True, request
|
||||
|
||||
def sanitize_embedder(request):
|
||||
return True, request
|
||||
|
||||
def sanitize_input(config, request):
|
||||
"""
|
||||
Sanitize the input data and set defaults
|
||||
"""
|
||||
request = DotMap(request)
|
||||
default, forced_default = defaults[config.model_name]
|
||||
for k, v in default.items():
|
||||
if k not in request:
|
||||
request[k] = v
|
||||
|
||||
for k, v in forced_default.items():
|
||||
request[k] = v
|
||||
|
||||
if config.model_name == 'stable-diffusion':
|
||||
return sanitize_stable_diffusion(request, config)
|
||||
|
||||
elif config.model_name == 'dalle-mini':
|
||||
return sanitize_dalle_mini(request)
|
||||
|
||||
elif config.model_name == 'basedformer':
|
||||
return sanitize_basedformer(request)
|
||||
|
||||
elif config.model_name == "embedder":
|
||||
return sanitize_embedder(request)
|
@ -0,0 +1,37 @@
|
||||
@echo off
|
||||
|
||||
:: COMMENT OUT (add :: in front of the next line) if you do not want the backend to automatically save files for you
|
||||
set SAVE_FILES="1"
|
||||
|
||||
if not defined PYTHON (set PYTHON=python)
|
||||
if not defined VENV_DIR (set VENV_DIR=venv)
|
||||
|
||||
set ERROR_REPORTING=FALSE
|
||||
|
||||
:: set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
|
||||
|
||||
set DTYPE=float32
|
||||
set CLIP_CONTEXTS=3
|
||||
set AMP=1
|
||||
set MODEL=stable-diffusion
|
||||
set DEV=False
|
||||
set MODEL_PATH=models/animefull-final-pruned
|
||||
::these aren't actually used by the site?
|
||||
::set MODULE_PATH=models/modules
|
||||
::unclear if these are used either
|
||||
::set PRIOR_PATH=models/vector_adjust_v2.pt
|
||||
set ENABLE_EMA=1
|
||||
set VAE_PATH=models/animevae.pt
|
||||
set PENULTIMATE=1
|
||||
set PYTHONDONTWRITEBYTECODE=1
|
||||
set LOWVRAM=0
|
||||
|
||||
:: Queue size
|
||||
set QUEUE_MAX_SIZE=10
|
||||
set QUEUE_RECOVERY_SIZE=6
|
||||
|
||||
:: Maximum number of images generated in parallel
|
||||
set MAX_N_SAMPLES=100
|
||||
|
||||
%PYTHON% -m uvicorn --host 0.0.0.0 --port=6969 --workers 1 main:app
|
||||
pause
|
@ -0,0 +1,35 @@
|
||||
#!/bin/bash
|
||||
|
||||
# UNCOMMENT if you want the backend to automatically save files for you
|
||||
# export SAVE_FILES="1"
|
||||
|
||||
export DTYPE="float32"
|
||||
export CLIP_CONTEXTS=3
|
||||
export AMP="1"
|
||||
export MODEL="stable-diffusion"
|
||||
export DEV="True"
|
||||
export MODEL_PATH="models/animefull-final-pruned"
|
||||
#these aren't actually used by the site
|
||||
#export MODULE_PATH="models/modules"
|
||||
#unclear if these are used either
|
||||
#export PRIOR_PATH="models/vector_adjust_v2.pt"
|
||||
export ENABLE_EMA="1"
|
||||
export VAE_PATH="models/animevae.pt"
|
||||
export PENULTIMATE="1"
|
||||
export PYTHONDONTWRITEBYTECODE=1
|
||||
export LOWVRAM=0
|
||||
|
||||
# Queue size
|
||||
export QUEUE_MAX_SIZE=10
|
||||
export QUEUE_RECOVERY_SIZE=6
|
||||
|
||||
# Maximum number of images generated in parallel
|
||||
export MAX_N_SAMPLES=100
|
||||
|
||||
if [[ -f venv/bin/python ]]; then
|
||||
PYTHON=venv/bin/python
|
||||
else
|
||||
PYTHON=python
|
||||
fi
|
||||
|
||||
$PYTHON -m uvicorn --host 0.0.0.0 --port=6969 main:app & bore local 6969 --to bore.pub
|