You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
825 lines
31 KiB
Python
825 lines
31 KiB
Python
2 years ago
|
import os
|
||
|
import bisect
|
||
|
import json
|
||
|
from re import S
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from pathlib import Path
|
||
|
from omegaconf import OmegaConf
|
||
|
from dotmap import DotMap
|
||
|
import numpy as np
|
||
|
from torch import autocast
|
||
|
from einops import rearrange, repeat
|
||
|
from torchvision.utils import make_grid
|
||
|
from ldm.util import instantiate_from_config
|
||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||
|
from ldm.modules.attention import CrossAttention, HyperLogic
|
||
|
from PIL import Image
|
||
|
import k_diffusion as K
|
||
|
import contextlib
|
||
|
import random
|
||
|
import base64
|
||
|
|
||
|
from .lowvram import setup_for_low_vram
|
||
|
from . import lautocast
|
||
|
|
||
|
class CallbackDelegate:
    """Adapt a ``callback(step, total_steps)`` progress hook to per-step sampler callbacks.

    ``update`` is what gets handed to samplers: each call advances an internal
    step counter, forwards ``(current_step, total_steps)`` to the wrapped
    callback when one is set, and returns its argument unchanged so it can be
    used in pass-through positions.
    """

    # Class-level defaults; __init__ overrides callback/total_steps per instance
    # and the first update() creates a per-instance current_step.
    total_steps = 0
    current_step = -1
    callback = None

    def __init__(self, callback, total_steps) -> None:
        self.callback = callback
        self.total_steps = total_steps

    def update(self, n = None):
        """Advance one step, notify the wrapped callback (if truthy) and return ``n``."""
        self.current_step = self.current_step + 1
        hook = self.callback
        if hook:
            hook(self.current_step, self.total_steps)
        return n
|
||
|
|
||
|
|
||
|
def seed_everything(seed: int):
    """Seed torch, numpy and Python's ``random`` with the same ``seed`` (in that order)."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
|
||
|
|
||
|
def pil_upscale(image, scale=1):
    """Resize a CHW image tensor with PIL's Lanczos filter.

    ``image`` is expected in [0, 1]: it is scaled by 255 and cast to uint8 (no
    clamping) before being handed to PIL. The image is only resized when
    ``scale > 1``.

    Returns:
        A float32 tensor of shape (1, C, H', W') rescaled to [-1, 1] and moved
        back to ``image``'s device.
    """
    target_device = image.device
    hwc = image.cpu().permute(1, 2, 0).numpy().astype(np.float32)
    pil_img = Image.fromarray((hwc * 255.).astype(np.uint8))

    if scale > 1:
        new_size = (int(pil_img.width * scale), int(pil_img.height * scale))
        pil_img = pil_img.resize(new_size, resample=Image.LANCZOS)

    pixels = np.array(pil_img).astype(np.float32) / 255.0
    batch = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
    batch = 2. * batch - 1.
    batch = repeat(batch, '1 ... -> b ...', b=1)
    return batch.to(target_device)
|
||
|
|
||
|
def fix_batch(tensor, bs):
    """Drop a leading singleton dim of ``tensor`` (if any) and stack ``bs`` copies along a new dim 0."""
    row = tensor.squeeze(0)
    copies = [row for _ in range(bs)]
    return torch.stack(copies, dim=0)
|
||
|
|
||
|
def torch_gc():
    """Release cached CUDA memory and CUDA IPC handles; a no-op when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
|
||
|
|
||
|
# Conditioning of the empty prompt "", computed lazily on the first call to
# fix_cond_shapes and used as padding.
# NOTE(review): cached module-wide, so it comes from whichever model was passed
# first — assumes a single model per process.
null_cond = None


def fix_cond_shapes(model, prompt_condition, uc):
    """Pad prompt and unconditional conditioning to a common length along dim 1.

    Long prompts yield conditioning with more tokens (dim 1) than the
    unconditional tensor, or vice versa. The shorter tensor is extended with
    copies of the empty-prompt conditioning (repeated over its batch) until it
    is at least as long as the other. The lengths only end up equal when they
    differ by a multiple of ``null_cond``'s length.

    Returns:
        ``(prompt_condition, uc)`` with padding applied.
    """
    global null_cond
    if null_cond is None:
        null_cond = model.get_learned_conditioning([""])

    def _append_null(cond):
        filler = null_cond.repeat((cond.shape[0], 1, 1))
        return torch.cat((cond, filler), axis=1)

    while uc.shape[1] < prompt_condition.shape[1]:
        uc = _append_null(uc)
    while prompt_condition.shape[1] < uc.shape[1]:
        prompt_condition = _append_null(prompt_condition)
    return prompt_condition, uc
|
||
|
|
||
|
# Mix conditioning vectors for "|"-separated, optionally weighted sub-prompts.
# (credit: @aero)
def prompt_mixing(model, prompt_body, batch_size):
    """Encode ``prompt_body`` into conditioning tiled to ``batch_size``.

    ``prompt_body`` may hold several sub-prompts separated by ``"|"``, each
    optionally suffixed with ``":<weight>"`` (e.g. ``"a cat:2|a dog"``). Each
    sub-prompt is encoded with ``model.get_learned_conditioning``. The results
    are padded to a common length with fix_cond_shapes and combined as a
    weighted average (default weight 1). A weight that is not a float prints a
    warning, falls back to 1, and leaves the sub-prompt text unchanged.

    Returns:
        The conditioning stacked ``batch_size`` times along dim 0 (fix_batch).
    """
    if "|" not in prompt_body:
        return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)

    prompt_total_power = 0
    prompt_sum = None
    for prompt_part in prompt_body.split("|"):
        prompt_power = 1
        if ":" in prompt_part:
            # The weight follows the LAST colon, so sub-prompts that contain a
            # colon themselves (e.g. "Tags: cat:0.5") keep their weight.
            # Splitting on the first colon used to reject such weights.
            text, _, weight = prompt_part.rpartition(":")
            try:
                prompt_power = float(weight)
                prompt_part = text
            except ValueError:
                print("Error parsing prompt power! Assuming 1")

        prompt_vector = model.get_learned_conditioning([prompt_part])
        if prompt_sum is None:
            prompt_sum = prompt_vector * prompt_power
        else:
            prompt_sum, prompt_vector = fix_cond_shapes(model, prompt_sum, prompt_vector)
            prompt_sum = prompt_sum + (prompt_vector * prompt_power)
        prompt_total_power = prompt_total_power + prompt_power

    # NOTE(review): weights summing to 0 (e.g. "a:1|b:-1") divide by zero and
    # give inf/NaN conditioning; consider rejecting such prompts upstream.
    return fix_batch(prompt_sum / prompt_total_power, batch_size)
|
||
|
|
||
|
def sample_start_noise(seed, C, H, W, f, device="cuda"):
    """Draw a standard-normal starting latent of shape (1, C, H // f, W // f).

    If ``seed`` is not None (0 included), the noise comes from a fresh
    ``torch.Generator`` seeded with it on ``device``. It is then reproducible
    regardless of global RNG state. With ``seed=None`` the global RNG is used.
    """
    shape = [C, H // f, W // f]
    # `is not None` rather than truthiness: 0 is a valid seed and previously
    # fell through to the unseeded global RNG.
    if seed is not None:
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)
        noise = torch.randn(shape, generator=gen, device=device)
    else:
        noise = torch.randn(shape, device=device)
    return noise.unsqueeze(0)
|
||
|
|
||
|
def sample_start_noise_special(seed, request, device="cuda"):
    """Draw a starting latent shaped from ``request``.

    The shape is (1, request.latent_channels, request.height // f,
    request.width // f), where f = request.downsampling_factor. If ``seed`` is
    not None (0 included), a dedicated seeded ``torch.Generator`` is used;
    otherwise the global RNG is used.
    """
    f = request.downsampling_factor
    shape = [request.latent_channels, request.height // f, request.width // f]
    # `is not None` rather than truthiness so seed 0 is honoured.
    if seed is not None:
        gen = torch.Generator(device=device)
        gen.manual_seed(seed)
        noise = torch.randn(shape, generator=gen, device=device)
    else:
        noise = torch.randn(shape, device=device)
    return noise.unsqueeze(0)
|
||
|
|
||
|
@torch.no_grad()
def encode_image(image, model):
    """Encode an HWC image into a latent sample with ``model``.

    ``image`` may be a PIL image, a numpy array or a tensor. It is cast to
    float32, permuted HWC -> 1CHW, moved to CUDA, mapped from [0, 255] to
    [-1, 1], and passed to ``model.encode``. The result of ``.sample()`` is
    returned.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)
    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image)

    pixels = image.to(torch.float32)
    pixels = pixels.permute(2, 0, 1).unsqueeze(0).float().cuda()
    # [0, 255] -> [0, 1] -> [-1, 1], the range the VQGAN/VAE encoder expects.
    pixels = 2. * (pixels / 255.0) - 1.
    return model.encode(pixels).sample()
|
||
|
|
||
|
@torch.no_grad()
def decode_image(image, model):
    """Decode latents with ``model.decode`` into an NHWC float tensor in [0, 1] on the CPU.

    The decoder output is clamped to [-1, 1], mapped to [0, 1] and permuted
    from NCHW to NHWC. No PIL conversion is done.
    """
    def _to_unit_nhwc(decoded):
        decoded = torch.clamp(decoded.detach().float().cpu(), -1., 1.)
        return ((decoded + 1.) / 2.).permute(0, 2, 3, 1)

    return _to_unit_nhwc(model.decode(image))
|
||
|
|
||
|
class VectorAdjustPrior(nn.Module):
    """Small learned adjustment applied to prompt conditioning.

    For every token vector it concatenates the sequence mean with the token,
    projects that to ``inter_dim``, then maps [token, projection] back to
    ``hidden_size``.
    """

    def __init__(self, hidden_size, inter_dim=64):
        super().__init__()
        self.vector_proj = nn.Linear(hidden_size*2, inter_dim, bias=True)
        self.out_proj = nn.Linear(hidden_size+inter_dim, hidden_size, bias=True)

    def forward(self, z):
        """Map ``z`` of shape (batch, seq, hidden_size) to a tensor of the same shape."""
        b, s = z.shape[0:2]
        # Per-item sequence mean, repeated once per token of *that* item, so
        # row i lines up with row i of x2 (item i // s). The former
        # `.repeat(s, 1)` tiled the whole batch, pairing tokens with other
        # items' means whenever b > 1. The result is unchanged for b == 1.
        x1 = torch.mean(z, dim=1).repeat_interleave(s, dim=0)
        x2 = z.reshape(b*s, -1)
        x = torch.cat((x1, x2), dim=1)
        x = self.vector_proj(x)
        x = torch.cat((x2, x), dim=1)
        x = self.out_proj(x)
        x = x.reshape(b, s, -1)
        return x

    @classmethod
    def load_model(cls, model_path, hidden_size=768, inter_dim=64):
        """Build a prior and load the ``"state_dict"`` entry of the checkpoint at ``model_path``.

        Tensors are mapped to the CPU while loading, so a checkpoint saved on a
        GPU also loads on CPU-only hosts. Move the returned module to the
        target device afterwards.
        """
        model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
        model.load_state_dict(torch.load(model_path, map_location="cpu")["state_dict"])
        return model
|
||
|
|
||
|
class StableInterface(nn.Module):
    """Classifier-free-guidance wrapper around a k-diffusion denoiser.

    ``forward`` evaluates the wrapped model on the unconditional and
    conditional inputs in one doubled batch and blends the two predictions
    with ``cond_scale``. An optional ``thresholder`` is applied to the blended
    prediction. ``sigma_to_t`` and ``get_sigmas`` are re-exposed from the
    wrapped model.
    """

    def __init__(self, model, thresholder = None):
        super().__init__()
        self.inner_model = model
        self.sigma_to_t = model.sigma_to_t
        self.thresholder = thresholder
        self.get_sigmas = model.get_sigmas

    @torch.no_grad()
    def forward(self, x, sigma, uncond, cond, cond_scale):
        batched_x = torch.cat((x, x))
        batched_sigma = torch.cat((sigma, sigma))
        denoised = self.inner_model(batched_x, batched_sigma, cond=torch.cat([uncond, cond]))
        uncond_out, cond_out = denoised.chunk(2)
        guided = uncond_out + (cond_out - uncond_out) * cond_scale
        return guided if self.thresholder is None else self.thresholder(guided)
|
||
|
|
||
|
class StableDiffusionModel(nn.Module):
    """Stable Diffusion (CompVis LDM) backend with PLMS/DDIM and k-diffusion samplers.

    Reads from ``config``: model_path, default_config, dtype, device, vae_path,
    penultimate, clip_contexts, enable_ema, prior_path, quality_hack, amp and
    logger.
    """

    def __init__(self, config):
        nn.Module.__init__(self)
        self.config = config
        # NOTE(review): sample() indexes this by request.module, but nothing in
        # this class fills it in; presumably the caller assigns a
        # name -> hypernetwork mapping. Confirm.
        self.premodules = None

        if Path(self.config.model_path).is_dir():
            config.logger.info(f"Loading model from folder {self.config.model_path}")
            model, model_config = self.from_folder(config.model_path)
        elif Path(self.config.model_path).is_file():
            config.logger.info(f"Loading model from file {self.config.model_path}")
            # Fixed: this used to call self.from_file, which this class does not
            # define, so loading a single checkpoint file always failed.
            model, model_config = self.from_path(config.model_path)
        else:
            raise Exception("Invalid model path!")

        typex = torch.float16 if config.dtype == "float16" else torch.float32

        print("----------------------------------")
        if lautocast.medvram == True:
            print("passed lowvram setup for medvram (6G)")
            setup_for_low_vram(model, True)
        elif lautocast.lowvram == True:
            print("setup for lowvram (4G and lower)")
            setup_for_low_vram(model, False)
        else:
            print("loading model")
            model.to(config.device)
        print("----------------------------------")
        self.model = model.to(typex)

        if self.config.vae_path:
            ckpt = torch.load(self.config.vae_path, map_location="cpu")
            state_dict = ckpt["state_dict"]
            # Drop training-loss weights so the (strict) load_state_dict below
            # does not reject them as unexpected keys.
            for key in [k for k in state_dict.keys() if k.startswith("loss")]:
                del state_dict[key]
            # Cast to fp32 *before* loading so the VAE weights are not rounded
            # through fp16 parameters; the VAE then stays fp32.
            model.first_stage_model = model.first_stage_model.float()
            model.first_stage_model.load_state_dict(state_dict)
            del ckpt, state_dict
            config.logger.info(f"Using VAE from {self.config.vae_path}")

        if self.config.penultimate == "1":
            model.cond_stage_model.return_layer = -2
            model.cond_stage_model.do_final_ln = True
            config.logger.info(f"CLIP: Using penultimate layer")

        if self.config.clip_contexts > 1:
            model.cond_stage_model.clip_extend = True
            model.cond_stage_model.max_clip_extend = 75 * self.config.clip_contexts

        model.cond_stage_model.inference_mode = True
        self.k_model = K.external.CompVisDenoiser(model)
        self.k_model = StableInterface(self.k_model)
        self.device = config.device
        self.model_config = model_config
        self.plms = PLMSSampler(model)
        self.ddim = DDIMSampler(model)
        self.ema_manager = self.model.ema_scope
        if self.config.enable_ema == "0":
            self.ema_manager = contextlib.nullcontext
            config.logger.info("Disabling EMA")
        else:
            config.logger.info(f"Using EMA")

        self.sampler_map = {
            'plms': self.plms.sample,
            'ddim': self.ddim.sample,
            'k_euler': K.sampling.sample_euler,
            'k_euler_ancestral': K.sampling.sample_euler_ancestral,
            'k_heun': K.sampling.sample_heun,
            'k_dpm_2': K.sampling.sample_dpm_2,
            'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral,
            'k_lms': K.sampling.sample_lms,
        }
        if config.prior_path:
            self.prior = VectorAdjustPrior.load_model(config.prior_path).to(self.device)
        self.copied_ema = False

    @property
    def get_default_config(self):
        """Default request settings as a DotMap (a new object on every access)."""
        dict_config = {
            'steps': 30,
            'sampler': "k_euler_ancestral",
            'n_samples': 1,
            'image': None,
            'fixed_code': False,
            'ddim_eta': 0.0,
            'height': 512,
            'width': 512,
            'latent_channels': 4,
            'downsampling_factor': 8,
            'scale': 12.0,
            'dynamic_threshold': None,
            'seed': None,
            'stage_two_seed': None,
            'module': None,
            'masks': None,
            'output': None,
        }
        return DotMap(dict_config)

    def from_folder(self, folder):
        """Load ``folder/config.yaml`` plus ``pruned.ckpt`` (preferred) or ``model.ckpt``.

        Returns ``(model, model_config)``.
        """
        folder = Path(folder)
        model_config = OmegaConf.load(folder / "config.yaml")
        if (folder / "pruned.ckpt").is_file():
            model_path = folder / "pruned.ckpt"
        else:
            model_path = folder / "model.ckpt"
        model = self.load_model_from_config(model_config, model_path)
        return model, model_config

    def from_path(self, file):
        """Load checkpoint ``file`` using the YAML config at ``self.config.default_config``.

        Returns ``(model, model_config)``; raises if the default config file is missing.
        """
        default_config = Path(self.config.default_config)
        if not default_config.is_file():
            raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG")
        model_config = OmegaConf.load(default_config)
        model = self.load_model_from_config(model_config, file)
        return model, model_config

    def load_model_from_config(self, config, ckpt, verbose=False):
        """Instantiate ``config.model`` and load weights from checkpoint ``ckpt``.

        Loading is non-strict; missing/unexpected keys are printed only when
        ``verbose`` is true. The model is returned in eval mode.
        """
        print(f"Loading model from {ckpt}")
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")

        # Lightning checkpoints nest weights under "state_dict"; bare ones do not.
        sd = pl_sd.get('state_dict', pl_sd)

        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)

        model.eval()
        return model

    @torch.no_grad()
    @torch.autocast("cuda", enabled=False, dtype=lautocast.dtype)
    def sample(self, request, callback = None):
        """Generate images for ``request``; return a list of uint8 HWC numpy arrays.

        Samples are produced one at a time. This covers txt2img (optionally
        with per-region mask seeds) and, when ``request.image`` is set, img2img
        through k-diffusion. ``callback(step, total_steps)`` receives progress
        via a CallbackDelegate. Any hypernetwork selected by
        ``request.module``, and the seeded RNG state, are reset afterwards even
        if sampling fails.

        Raises:
            ValueError: if ``request.sampler`` is ``"ddim_img2img"``, for which
                there is no sampling path.
        """
        if request.module is not None and request.module != "vanilla":
            module = self.premodules[request.module]
            CrossAttention.set_hypernetwork(module)

        try:
            if request.seed is not None:
                seed_everything(request.seed)

            main_noise, start_code, t_enc = self._initial_latents(request)

            if request.sampler.startswith("k_"):
                sampler = "k-diffusion"
            elif request.sampler == "ddim_img2img":
                # No code path exists for it; it used to fail later with an
                # UnboundLocalError on `samples`.
                raise ValueError(f"Unsupported sampler: {request.sampler}")
            else:
                sampler = "normal"

            prompt_condition, uc = self._conditioning(request)

            shape = [
                request.latent_channels,
                request.height // request.downsampling_factor,
                request.width // request.downsampling_factor,
            ]
            c_dele = CallbackDelegate(callback, request.steps * request.n_samples)

            # handle images one at a time because batches eat absurd VRAM
            sampless = []
            for noise_chunk, code_chunk in zip(main_noise.chunk(request.n_samples), start_code.chunk(request.n_samples)):
                with self.ema_manager():
                    if sampler == "normal":
                        c_dele.update()
                        samples, _ = self.sampler_map[request.sampler](
                            S=request.steps,
                            conditioning=prompt_condition,
                            batch_size=1,
                            shape=shape,
                            verbose=False,
                            unconditional_guidance_scale=request.scale,
                            unconditional_conditioning=uc,
                            eta=request.ddim_eta,
                            dynamic_threshold=request.dynamic_threshold,
                            x_T=code_chunk,
                            callback=c_dele.update,
                        )
                    else:
                        sigmas = self.k_model.get_sigmas(request.steps)
                        if request.image is not None:
                            # img2img: start part-way down the schedule from the
                            # encoded image plus noise at that sigma.
                            first = request.steps - t_enc - 1
                            code_chunk = code_chunk + main_noise_scale(noise_chunk, sigmas[first])
                            sigmas = sigmas[first:]
                        else:
                            code_chunk = code_chunk * sigmas[0]

                        extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
                        c_dele.update()
                        # NOTE(review): request.seed is passed as the 4th
                        # positional argument. Upstream k-diffusion samplers take
                        # extra_args there, so this presumably relies on a forked
                        # k_diffusion whose samplers accept a seed. Confirm.
                        samples = self.sampler_map[request.sampler](
                            self.k_model,
                            code_chunk,
                            sigmas,
                            request.seed,
                            callback=c_dele.update,
                            extra_args=extra_args,
                        )

                sampless.append(samples)
                torch_gc()

            images = self._decode_samples(sampless)
            torch_gc()
            return images
        finally:
            # Undo per-request global state even when sampling raised, so the
            # next request does not inherit this one's seed or hypernetwork.
            if request.seed is not None:
                torch.seed()
                np.random.seed()
                random.seed()
            #set hypernetwork to none after generation
            CrossAttention.set_hypernetwork(None)

    def _initial_latents(self, request):
        """Return ``(main_noise, start_code, t_enc)`` covering all samples of ``request``.

        img2img (``request.image`` set): forces ``request.steps = 50``, maps
        plms/ddim to k_lms (mutating ``request``), encodes the image, adds
        ``request.noise``-scaled noise and returns ``t_enc``.
        txt2img: ``start_code`` is the per-sample noise itself and ``t_enc`` is None.
        """
        n = request.n_samples
        # Per-sample noise seeds. With no request seed the noise is unseeded;
        # this used to crash with a TypeError (None + int / range(None)).
        if request.seed is None:
            seeds = [None] * n
        else:
            seeds = [request.seed + offset for offset in range(n)]

        if request.image is not None:
            request.steps = 50
            #request.sampler = "ddim_img2img" #enforce ddim for now
            if request.sampler in ("plms", "ddim"):
                request.sampler = "k_lms"

            self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
            start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
            start_code = self.model.get_first_stage_encoding(start_code)
            start_code = torch.repeat_interleave(start_code, n, dim=0)

            noise_args = (request.latent_channels, request.height, request.width, request.downsampling_factor, self.device)
            main_noise = []
            start_noise = []
            for seed in seeds:
                main_noise.append(sample_start_noise(seed, *noise_args))
                start_noise.append(sample_start_noise(seed, *noise_args))
            main_noise = torch.cat(main_noise, dim=0)
            start_noise = torch.cat(start_noise, dim=0)

            start_code = start_code + (start_noise * request.noise)
            t_enc = int(request.strength * request.steps)
            return main_noise, start_code, t_enc

        # Masks do not depend on the sample index, so convert them once.
        latent_masks = []
        if request.masks is not None:
            latent_masks = [
                (maskobj["seed"], self._latent_mask(maskobj["mask"], request.latent_channels))
                for maskobj in request.masks
            ]

        main_noise = []
        for seed_offset, seed in enumerate(seeds):
            if request.masks is not None:
                # With masks every sample shares the base noise of request.seed;
                # only the masked regions vary per sample.
                noise_x = sample_start_noise_special(request.seed, request, self.device)
            else:
                noise_x = sample_start_noise_special(seed, request, self.device)

            for mask_seed, mask in latent_masks:
                # interpolate start noise
                region_noise = sample_start_noise_special(mask_seed + seed_offset, request, self.device)
                noise_x = (noise_x * (1 - mask)) + (region_noise * mask)

            main_noise.append(noise_x)

        main_noise = torch.cat(main_noise, dim=0)
        return main_noise, main_noise, None

    def _latent_mask(self, mask, latent_channels):
        """Convert an HWC mask image into a binary (latent_channels, H, W) tensor on ``self.device``.

        Only the first channel is used. Pixels below 50% intensity become 1
        (that region takes the mask's own noise) and the rest become 0.
        NOTE(review): H, W must broadcast against the latent noise, so the mask
        presumably arrives at latent resolution. Confirm with the caller.
        """
        mask = torch.from_numpy(np.asarray(mask)).clone().to(self.device).permute(2, 0, 1)
        mask = (mask.float() / 255.0)[0].unsqueeze(0)
        mask = torch.repeat_interleave(mask, latent_channels, dim=0)
        return (mask < 0.5).float()

    def _conditioning(self, request):
        """Return ``(prompt_condition, uc)`` for ``request`` (``uc`` is None when scale == 1).

        The unconditional prompt is ``request.uc``. Without it, "Tags: lowres"
        is used when ``config.quality_hack == "1"``, else the empty prompt.
        Both tensors are padded to equal length.
        """
        prompt_condition = prompt_mixing(self.model, request.prompt, 1)
        if hasattr(self, "prior") and request.mitigate:
            prompt_condition = self.prior(prompt_condition)

        uc = None
        if request.scale != 1.0:
            if request.uc is not None:
                uc = prompt_mixing(self.model, request.uc, 1)
            elif self.config.quality_hack == "1":
                uc = prompt_mixing(self.model, "Tags: lowres", 1)
            else:
                uc = self.model.get_learned_conditioning([""])
            prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
        return prompt_condition, uc

    def _decode_samples(self, sampless):
        """Decode latent batches into a flat list of uint8 HWC numpy images."""
        images = []
        for samples in sampless:
            with torch.autocast("cuda", enabled=self.config.amp):
                x_samples_ddim = self.model.decode_first_stage(samples.float())
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples_ddim:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                images.append(np.ascontiguousarray(x_sample.astype(np.uint8)))
        return images

    @torch.no_grad()
    def sample_two_stages(self, request, callback = None):
        """Two-pass generation: sample, 2x Lanczos upscale, then DDIM img2img refine.

        ``request`` is wrapped in a DotMap. Returns a list of uint8 HWC numpy
        arrays. Seeded RNG state is reset afterwards, even if generation fails.
        """
        request = DotMap(request)
        if request.seed is not None:
            seed_everything(request.seed)

        try:
            sampler = self.plms if request.plms else self.ddim

            start_code = None
            if request.fixed_code:
                start_code = torch.randn([
                    request.n_samples,
                    request.latent_channels,
                    request.height // request.downsampling_factor,
                    request.width // request.downsampling_factor,
                ], device=self.device)

            def conditioning():
                # Prompt conditioning plus, when guidance is on, padded empty-prompt uc.
                cond = prompt_mixing(self.model, request.prompt, request.n_samples)
                uncond = None
                if request.scale != 1.0:
                    uncond = self.model.get_learned_conditioning(request.n_samples * [""])
                    cond, uncond = fix_cond_shapes(self.model, cond, uncond)
                return cond, uncond

            prompt_condition, uc = conditioning()

            shape = [
                request.latent_channels,
                request.height // request.downsampling_factor,
                request.width // request.downsampling_factor,
            ]
            c_dele = CallbackDelegate(callback, request.steps * request.n_samples)

            # Stage one: plain sampling at the requested resolution.
            with torch.autocast("cuda", enabled=self.config.amp):
                with self.ema_manager():
                    samples, _ = sampler.sample(
                        S=request.steps,
                        conditioning=prompt_condition,
                        batch_size=request.n_samples,
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=request.scale,
                        unconditional_conditioning=uc,
                        eta=request.ddim_eta,
                        dynamic_threshold=request.dynamic_threshold,
                        x_T=start_code,
                        callback=c_dele.update,
                    )

                    x_samples_ddim = self.model.decode_first_stage(samples)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    # pil_upscale takes one CHW image. Upscale each sample and
                    # re-batch; the former squeeze(0) + single call crashed
                    # for n_samples > 1.
                    x_samples_ddim = torch.cat([pil_upscale(x, scale=2) for x in x_samples_ddim], dim=0)

            if request.stage_two_seed is not None:
                torch.manual_seed(request.stage_two_seed)
                np.random.seed(request.stage_two_seed)

            # Stage two: DDIM img2img on the upscaled images.
            with torch.autocast("cuda", enabled=self.config.amp):
                with self.ema_manager():
                    init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
                    self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
                    t_enc = int(request.strength * request.steps)

                    print("init latent shape:")
                    print(init_latent.shape)

                    init_latent = init_latent + (torch.randn_like(init_latent) * request.noise)

                    prompt_condition, uc = conditioning()

                    # encode (scaled latent)
                    z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=None)
                    # decode it
                    samples = self.ddim.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale,
                                               unconditional_conditioning=uc,)

                    x_samples_ddim = self.model.decode_first_stage(samples)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            images = []
            for x_sample in x_samples_ddim:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                images.append(np.ascontiguousarray(x_sample.astype(np.uint8)))
            return images
        finally:
            if request.seed is not None:
                torch.seed()
                np.random.seed()
                random.seed()

    @torch.no_grad()
    def sample_from_image(self, request):
        """Placeholder, not implemented: always returns None."""
        return None


def main_noise_scale(noise, sigma):
    """Scale ``noise`` by ``sigma`` (used for the img2img start point)."""
    return noise * sigma
|
||
|
|
||
|
class DalleMiniModel(nn.Module):
    """Wrapper exposing a min-dalle "mega" model (fp16, CUDA) through the common ``sample`` API."""

    def __init__(self, config):
        nn.Module.__init__(self)
        # Imported inside __init__, so min_dalle is needed only when this backend is built.
        from min_dalle import MinDalle

        self.config = config
        self.model = MinDalle(
            models_root=config.model_path,
            dtype=torch.float16,
            device='cuda',
            is_mega=True,
            is_reusable=True,
        )

    @torch.no_grad()
    def sample(self, request, callback = None):
        """Generate an image grid for ``request.prompt`` and return it as a contiguous uint8 numpy array.

        ``request.seed`` of None is passed to min-dalle as -1. When a seed was
        given, torch/numpy global RNGs are re-seeded randomly afterwards.
        """
        seed = -1 if request.seed is None else request.seed

        generated = self.model.generate_images(
            text=request.prompt,
            seed=seed,
            grid_size=request.grid_size,
            is_seamless=False,
            temperature=request.temp,
            top_k=request.top_k,
            supercondition_factor=request.scale,
            is_verbose=False,
        )
        result = np.ascontiguousarray(generated.to('cpu').numpy().astype(np.uint8))

        if request.seed is not None:
            torch.seed()
            np.random.seed()

        return result
|
||
|
|
||
|
def apply_temp(logits, temperature):
    """Return ``logits`` divided by ``temperature`` (temperature scaling for sampling)."""
    return logits / temperature
|
||
|
|
||
|
@torch.no_grad()
def generate(forward, prompt_tokens, tokenizer, tokens_to_generate=50, ds=False, ops_list=None, hypernetwork=None, non_deterministic=False, fully_deterministic=False):
    """Sample up to ``tokens_to_generate`` tokens from a causal LM and read off a verdict.

    Args:
        forward: model call returning ``(logits, kv)``. It is called with
            ``past_key_values=/use_cache=`` when ``ds`` is true, otherwise with
            ``cache=/kv=/hypernetwork=``.
        prompt_tokens: (batch, seq) long tensor; batch must equal ``len(ops_list)``.
        tokenizer: object with ``decode``.
        ops_list: one dict of logit ops per batch row (default ``[{"temp": 0.9}]``).
        non_deterministic: re-seed torch's global RNG randomly first.
        fully_deterministic: re-seed with 69 before sampling each row.

    Returns:
        ``("safe", "none")`` when end token 48585 follows token 3363.
        ``("notsafe", text)`` when it follows any other token except 1400, or
        comes first; ``text`` is the decoded generation after the last
        "Output: ". After token 1400 the end token is ignored and generation
        continues. ``("unknown", decoded)`` if no verdict is reached. Only
        batch row 0 is inspected for the verdict.
    """
    # None default instead of a list literal so callers never share a mutable default.
    if ops_list is None:
        ops_list = [{"temp": 0.9}]

    in_tokens = prompt_tokens
    context = prompt_tokens
    generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
    kv = None
    if non_deterministic:
        torch.seed()
    # NOTE(review): only "temp" is registered. An ops entry such as "rep_pen"
    # (special-cased below) raises KeyError until it is added here.
    op_map = {
        "temp": apply_temp,
    }

    for _ in range(tokens_to_generate):
        if ds:
            logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
        else:
            logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
        logits = logits[:, -1, :]  # logits of the last position in the sequence
        logits = torch.log_softmax(logits, dim=-1)

        # Apply each batch row's own sequence of logit operations.
        batch = []
        for i, ops in enumerate(ops_list):
            item = logits[i, ...].unsqueeze(0)
            ctx = context[i, ...].unsqueeze(0)
            for op, value in ops.items():
                if op == "rep_pen":
                    item = op_map[op](ctx, item, **value)
                else:
                    item = op_map[op](item, value)
            batch.append(item)

        logits = torch.cat(batch, dim=0)
        logits = torch.softmax(logits, dim=-1)

        # fully_deterministic makes it deterministic across the batch
        if fully_deterministic:
            logit_list = []
            for logit in logits.split(1, dim=0):
                torch.manual_seed(69)
                logit_list.append(torch.multinomial(logit, 1))
            logits = torch.cat(logit_list, dim=0)
        else:
            logits = torch.multinomial(logits, 1)

        if logits[0, 0] == 48585:
            # On the very first step nothing has been generated yet, and
            # generated[0, -1] used to raise IndexError. "No previous token" is
            # now treated like any token other than 1400/3363.
            previous = generated[0, -1].item() if generated.shape[1] > 0 else None
            if previous == 1400:
                pass
            elif previous == 3363:
                return "safe", "none"
            else:
                return "notsafe", tokenizer.decode(generated.squeeze()).split("Output: ")[-1]

        generated = torch.cat([generated, logits], dim=-1)
        context = torch.cat([context, logits], dim=-1)
        in_tokens = logits

    return "unknown", tokenizer.decode(generated.squeeze())
|
||
|
|
||
|
|
||
|
class BasedformerModel(nn.Module):
    """Prompt safety classifier backed by a basedformer LM (fp16, CUDA, DeepSpeed-converted)."""

    def __init__(self, config):
        nn.Module.__init__(self)
        # Imported inside __init__, so these packages are needed only for this backend.
        from basedformer import lm_utils
        from transformers import GPT2TokenizerFast

        self.config = config
        self.model = lm_utils.load_from_path(config.model_path).half().cuda()
        self.model = self.model.convert_to_ds()
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    @torch.no_grad()
    def sample(self, request, callback = None):
        """Classify ``request.prompt``; return ``(verdict, corrected_text)`` from ``generate``."""
        tokens = self.tokenizer.encode("Input: " + request.prompt, return_tensors='pt').cuda().long()
        # Token 49527 is appended after the prompt as a separator before generation.
        separator = torch.tensor([[49527]], dtype=torch.long).cuda()
        tokens = torch.cat([tokens, separator], dim=1)
        verdict, corrected = generate(self.model.module, tokens, self.tokenizer, tokens_to_generate=150, ds=True)
        return verdict, corrected
|
||
|
|
||
|
class EmbedderModel(nn.Module):
    """Tag suggester: prefix search over known tags plus nearest-neighbour search on tag embeddings."""

    def __init__(self, config=None):
        nn.Module.__init__(self)
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer('./models/sentence-transformers_all-MiniLM-L6-v2').cuda()
        # `with` closes the file (it used to leak the handle); JSON is UTF-8,
        # so decode it explicitly instead of using the locale encoding.
        with open("models/tags.json", encoding="utf-8") as tags_file:
            self.tags = [tuple(x) for x in json.load(tags_file)]
        self.knn = self.load_knn("models/tags.index")
        print("Loaded tag suggestion model using phrase embeddings")

    def load_knn(self, filename):
        """Read the faiss index at ``filename``, or build an IndexFlatL2 over all tag names and save it there.

        A RuntimeError from ``faiss.read_index`` is treated as "index missing
        or unreadable".
        """
        import faiss
        try:
            return faiss.read_index(filename)
        except RuntimeError:
            print(f"Generating tag embedding index for {len(self.tags)} tags.")
            # 384 = embedding width of the sentence model (presumably MiniLM-L6).
            i = faiss.IndexFlatL2(384)
            i.add(self([name for name, count in self.tags]))
            faiss.write_index(i, filename)
            return i

    def __call__(self, sentences):
        """Embed ``sentences`` with the sentence-transformer model (no autograd)."""
        with torch.no_grad():
            sentence_embeddings = self.model.encode(sentences)
        return sentence_embeddings

    def get_top_k(self, text):
        """Suggest up to 10 ``[tag, count, score]`` entries for ``text``, ordered by count (descending).

        Combines:
          * up to 5 most frequent tags starting with ``text`` (score 0), and
          * up to 15 nearest neighbours of ``text`` in the embedding index that
            pass the score filter, excluding tags already found by prefix.
        """
        # Prefix matches via bisect.
        # NOTE(review): requires self.tags sorted by tag name; presumably
        # tags.json is stored sorted. Confirm.
        lo = bisect.bisect_left(self.tags, (text,))
        # Upper bound just past every string with this prefix. The old '\xff'
        # bound dropped tags continuing with characters above U+00FF
        # (e.g. CJK).
        hi = bisect.bisect_left(self.tags, (text + '\U0010ffff',), lo=lo)
        found = [[tag, count, 0] for tag, count in self.tags[lo:hi] if tag.startswith(text)]

        embedding = self([text])
        k = 15
        distances, ids = self.knn.search(embedding, k)
        results = []
        for tag_id, score in zip(ids[0], distances[0]):
            if tag_id < 0:
                # faiss pads with -1 when fewer than k results exist. Without
                # this check, self.tags[-1] was suggested spuriously.
                continue
            tag, count = self.tags[tag_id]
            results.append([tag, count, score])

        found.sort(key=lambda x: x[1], reverse=True)
        found = found[:5]
        found_tags = {x[0] for x in found}
        results = [r for r in results if r[0] not in found_tags]

        # NOTE(review): the index built above is IndexFlatL2, whose scores are
        # distances (smaller = closer). Yet they are sorted descending and
        # filtered with "> 0.5" as if they were similarities. Confirm which
        # metric models/tags.index uses before trusting this filter.
        results = sorted(results, key=lambda x: x[2], reverse=True)
        #filter results for >0.5 score unless it has the search text in it and score is >0.4
        results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])]
        results = found + results

        #max 10 results
        results = results[:10]
        results = sorted(results, key=lambda x: x[1], reverse=True)
        return results
|