import os import bisect import json from re import S import torch import torch.nn as nn from pathlib import Path from omegaconf import OmegaConf from dotmap import DotMap import numpy as np from torch import autocast from einops import rearrange, repeat from torchvision.utils import make_grid from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.modules.attention import CrossAttention, HyperLogic from PIL import Image import k_diffusion as K import contextlib import random import base64 from .lowvram import setup_for_low_vram from . import lautocast class CallbackDelegate: total_steps = 0 current_step = -1 callback = None def __init__(self, callback, total_steps) -> None: self.callback = callback self.total_steps = total_steps def update(self, n = None): self.current_step += 1 if self.callback: self.callback(self.current_step, self.total_steps) return n def seed_everything(seed: int): torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) def pil_upscale(image, scale=1): device = image.device dtype = image.dtype image = Image.fromarray((image.cpu().permute(1,2,0).numpy().astype(np.float32) * 255.).astype(np.uint8)) if scale > 1: image = image.resize((int(image.width * scale), int(image.height * scale)), resample=Image.LANCZOS) image = np.array(image) image = image.astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) image = 2.*image - 1. image = repeat(image, '1 ... -> b ...', b=1) return image.to(device) def fix_batch(tensor, bs): return torch.stack([tensor.squeeze(0)]*bs, dim=0) def torch_gc(): if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() # make uc and prompt shapes match via padding for long prompts # finetune null_cond = None def fix_cond_shapes(model, prompt_condition, uc): global null_cond if null_cond is None: null_cond = model.get_learned_conditioning([""]) while prompt_condition.shape[1] > uc.shape[1]: uc = torch.cat((uc, null_cond.repeat((uc.shape[0], 1, 1))), axis=1) while prompt_condition.shape[1] < uc.shape[1]: prompt_condition = torch.cat((prompt_condition, null_cond.repeat((prompt_condition.shape[0], 1, 1))), axis=1) return prompt_condition, uc # mix conditioning vectors for prompts # @aero def prompt_mixing(model, prompt_body, batch_size): if "|" in prompt_body: prompt_parts = prompt_body.split("|") prompt_total_power = 0 prompt_sum = None for prompt_part in prompt_parts: prompt_power = 1 if ":" in prompt_part: prompt_sub_parts = prompt_part.split(":") try: prompt_power = float(prompt_sub_parts[1]) prompt_part = prompt_sub_parts[0] except: print("Error parsing prompt power! Assuming 1") prompt_vector = model.get_learned_conditioning([prompt_part]) if prompt_sum is None: prompt_sum = prompt_vector * prompt_power else: prompt_sum, prompt_vector = fix_cond_shapes(model, prompt_sum, prompt_vector) prompt_sum = prompt_sum + (prompt_vector * prompt_power) prompt_total_power = prompt_total_power + prompt_power return fix_batch(prompt_sum / prompt_total_power, batch_size) else: return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size) def sample_start_noise(seed, C, H, W, f, device="cuda"): if seed: gen = torch.Generator(device=device) gen.manual_seed(seed) noise = torch.randn([C, (H) // f, (W) // f], generator=gen, device=device).unsqueeze(0) else: noise = torch.randn([C, (H) // f, (W) // f], device=device).unsqueeze(0) return noise def sample_start_noise_special(seed, request, device="cuda"): if seed: gen = torch.Generator(device=device) gen.manual_seed(seed) noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], generator=gen, device=device).unsqueeze(0) else: noise = torch.randn([request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor], device=device).unsqueeze(0) return noise @torch.no_grad() def encode_image(image, model): if isinstance(image, Image.Image): image = np.array(image) image = torch.from_numpy(image) if isinstance(image, np.ndarray): image = torch.from_numpy(image) #dtype = image.dtype image = image.to(torch.float32) #gets image as numpy array and returns as tensor def preprocess_vqgan(x): x = x / 255.0 x = 2.*x - 1. return x image = image.permute(2, 0, 1).unsqueeze(0).float().cuda() image = preprocess_vqgan(image) image = model.encode(image).sample() #image = image.to(dtype) return image @torch.no_grad() def decode_image(image, model): def custom_to_pil(x): x = x.detach().float().cpu() x = torch.clamp(x, -1., 1.) x = (x + 1.)/2. x = x.permute(0, 2, 3, 1)#.numpy() #x = (255*x).astype(np.uint8) #x = Image.fromarray(x) #if not x.mode == "RGB": # x = x.convert("RGB") return x image = model.decode(image) image = custom_to_pil(image) return image class VectorAdjustPrior(nn.Module): def __init__(self, hidden_size, inter_dim=64): super().__init__() self.vector_proj = nn.Linear(hidden_size*2, inter_dim, bias=True) self.out_proj = nn.Linear(hidden_size+inter_dim, hidden_size, bias=True) def forward(self, z): b, s = z.shape[0:2] x1 = torch.mean(z, dim=1).repeat(s, 1) x2 = z.reshape(b*s, -1) x = torch.cat((x1, x2), dim=1) x = self.vector_proj(x) x = torch.cat((x2, x), dim=1) x = self.out_proj(x) x = x.reshape(b, s, -1) return x @classmethod def load_model(cls, model_path, hidden_size=768, inter_dim=64): model = cls(hidden_size=hidden_size, inter_dim=inter_dim) model.load_state_dict(torch.load(model_path)["state_dict"]) return model class StableInterface(nn.Module): def __init__(self, model, thresholder = None): super().__init__() self.inner_model = model self.sigma_to_t = model.sigma_to_t self.thresholder = thresholder self.get_sigmas = model.get_sigmas @torch.no_grad() def forward(self, x, sigma, uncond, cond, cond_scale): x_two = torch.cat([x] * 2) sigma_two = torch.cat([sigma] * 2) cond_full = torch.cat([uncond, cond]) uncond, cond = self.inner_model(x_two, sigma_two, cond=cond_full).chunk(2) x_0 = uncond + (cond - uncond) * cond_scale if self.thresholder is not None: x_0 = self.thresholder(x_0) return x_0 class StableDiffusionModel(nn.Module): def __init__(self, config): nn.Module.__init__(self) self.config = config self.premodules = None if Path(self.config.model_path).is_dir(): config.logger.info(f"Loading model from folder {self.config.model_path}") model, model_config = self.from_folder(config.model_path) elif Path(self.config.model_path).is_file(): config.logger.info(f"Loading model from file {self.config.model_path}") model, model_config = self.from_file(config.model_path) else: raise Exception("Invalid model path!") if config.dtype == "float16": typex = torch.float16 else: typex = torch.float32 print("----------------------------------") if lautocast.medvram == True: print("passed lowvram setup for medvram (6G)") setup_for_low_vram(model, True) elif lautocast.lowvram == True: print("setup for lowvram (4G and lower)") setup_for_low_vram(model, False) else: print("loading model") model.to(config.device) print("----------------------------------") self.model = model.to(typex) # self.model = model.to(config.device).to(typex) if self.config.vae_path: ckpt=torch.load(self.config.vae_path, map_location="cpu") loss = [] for i in ckpt["state_dict"].keys(): if i[0:4] == "loss": loss.append(i) for i in loss: del ckpt["state_dict"][i] model.first_stage_model = model.first_stage_model.float() model.first_stage_model.load_state_dict(ckpt["state_dict"]) model.first_stage_model = model.first_stage_model.float() del ckpt del loss config.logger.info(f"Using VAE from {self.config.vae_path}") if self.config.penultimate == "1": model.cond_stage_model.return_layer = -2 model.cond_stage_model.do_final_ln = True config.logger.info(f"CLIP: Using penultimate layer") if self.config.clip_contexts > 1: model.cond_stage_model.clip_extend = True model.cond_stage_model.max_clip_extend = 75 * self.config.clip_contexts model.cond_stage_model.inference_mode = True self.k_model = K.external.CompVisDenoiser(model) self.k_model = StableInterface(self.k_model) self.device = config.device self.model_config = model_config self.plms = PLMSSampler(model) self.ddim = DDIMSampler(model) self.ema_manager = self.model.ema_scope if self.config.enable_ema == "0": self.ema_manager = contextlib.nullcontext config.logger.info("Disabling EMA") else: config.logger.info(f"Using EMA") self.sampler_map = { 'plms': self.plms.sample, 'ddim': self.ddim.sample, 'k_euler': K.sampling.sample_euler, 'k_euler_ancestral': K.sampling.sample_euler_ancestral, 'k_heun': K.sampling.sample_heun, 'k_dpm_2': K.sampling.sample_dpm_2, 'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral, 'k_lms': K.sampling.sample_lms, } if config.prior_path: self.prior = VectorAdjustPrior.load_model(config.prior_path).to(self.device) self.copied_ema = False @property def get_default_config(self): dict_config = { 'steps': 30, 'sampler': "k_euler_ancestral", 'n_samples': 1, 'image': None, 'fixed_code': False, 'ddim_eta': 0.0, 'height': 512, 'width': 512, 'latent_channels': 4, 'downsampling_factor': 8, 'scale': 12.0, 'dynamic_threshold': None, 'seed': None, 'stage_two_seed': None, 'module': None, 'masks': None, 'output': None, } return DotMap(dict_config) def from_folder(self, folder): folder = Path(folder) model_config = OmegaConf.load(folder / "config.yaml") if (folder / "pruned.ckpt").is_file(): model_path = folder / "pruned.ckpt" else: model_path = folder / "model.ckpt" model = self.load_model_from_config(model_config, model_path) return model, model_config def from_path(self, file): default_config = Path(self.config.default_config) if not default_config.is_file(): raise Exception("Default config to load not found! Either give a folder on MODEL_PATH or specify a config to use with this checkpoint on DEFAULT_CONFIG") model_config = OmegaConf.load(default_config) model = self.load_model_from_config(model_config, file) return model, model_config def load_model_from_config(self, config, ckpt, verbose=False): print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd.get('state_dict', pl_sd) model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") print(m) if len(u) > 0 and verbose: print("unexpected keys:") print(u) model.eval() return model @torch.no_grad() @torch.autocast("cuda", enabled=False, dtype=lautocast.dtype) def sample(self, request, callback = None): if request.module is not None: if request.module == "vanilla": pass else: module = self.premodules[request.module] CrossAttention.set_hypernetwork(module) if request.seed is not None: seed_everything(request.seed) if request.image is not None: request.steps = 50 #request.sampler = "ddim_img2img" #enforce ddim for now if request.sampler == "plms": request.sampler = "k_lms" if request.sampler == "ddim": request.sampler = "k_lms" self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False) start_code = encode_image(request.image, self.model.first_stage_model).to(self.device) start_code = self.model.get_first_stage_encoding(start_code) start_code = torch.repeat_interleave(start_code, request.n_samples, dim=0) main_noise = [] start_noise = [] for seed in range(request.seed, request.seed+request.n_samples): main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device)) start_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device)) main_noise = torch.cat(main_noise, dim=0) start_noise = torch.cat(start_noise, dim=0) start_code = start_code + (start_noise * request.noise) t_enc = int(request.strength * request.steps) if request.sampler.startswith("k_"): sampler = "k-diffusion" elif request.sampler == 'ddim_img2img': sampler = 'img2img' else: sampler = "normal" if request.image is None: main_noise = [] for seed_offset in range(request.n_samples): if request.masks is not None: noise_x = sample_start_noise_special(request.seed, request, self.device) else: noise_x = sample_start_noise_special(request.seed+seed_offset, request, self.device) if request.masks is not None: for maskobj in request.masks: mask_seed = maskobj["seed"] mask = maskobj["mask"] mask = np.asarray(mask) mask = torch.from_numpy(mask).clone().to(self.device).permute(2, 0, 1) mask = mask.float() / 255.0 # convert RGB or grayscale image into 4-channel mask = mask[0].unsqueeze(0) mask = torch.repeat_interleave(mask, request.latent_channels, dim=0) mask = (mask < 0.5).float() # interpolate start noise noise_x = (noise_x * (1-mask)) + (sample_start_noise_special(mask_seed+seed_offset, request, self.device) * mask) main_noise.append(noise_x) main_noise = torch.cat(main_noise, dim=0) start_code = main_noise prompt = [request.prompt] prompt_condition = prompt_mixing(self.model, prompt[0], 1) if hasattr(self, "prior") and request.mitigate: prompt_condition = self.prior(prompt_condition) uc = None if request.scale != 1.0: if request.uc is not None: uc = [request.uc] uc = prompt_mixing(self.model, uc[0], 1) else: if self.config.quality_hack == "1": uc = ["Tags: lowres"] uc = prompt_mixing(self.model, uc[0], 1) else: uc = self.model.get_learned_conditioning([""]) prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc) shape = [ request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor ] c_dele = CallbackDelegate(callback, request.steps * request.n_samples) # handle images one at a time because batches eat absurd VRAM sampless = [] for main_noise, start_code in zip(main_noise.chunk(request.n_samples), start_code.chunk(request.n_samples)): if sampler == "normal": with self.ema_manager(): c_dele.update() samples, _ = self.sampler_map[request.sampler]( S=request.steps, conditioning=prompt_condition, batch_size=1, shape=shape, verbose=False, unconditional_guidance_scale=request.scale, unconditional_conditioning=uc, eta=request.ddim_eta, dynamic_threshold=request.dynamic_threshold, x_T=start_code, callback=c_dele.update ) elif sampler == "k-diffusion": with self.ema_manager(): sigmas = self.k_model.get_sigmas(request.steps) if request.image is not None: noise = main_noise * sigmas[request.steps - t_enc - 1] start_code = start_code + noise sigmas = sigmas[request.steps - t_enc - 1:] else: start_code = start_code * sigmas[0] extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale} c_dele.update() samples = self.sampler_map[request.sampler]( self.k_model, start_code, sigmas, request.seed, callback=c_dele.update, extra_args=extra_args ) sampless.append(samples) torch_gc() images = [] for samples in sampless: with torch.autocast("cuda", enabled=self.config.amp): x_samples_ddim = self.model.decode_first_stage(samples.float()) #x_samples_ddim = decode_image(samples, self.model.first_stage_model) #x_samples_ddim = self.model.first_stage_model.decode(samples.float()) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples_ddim: x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = x_sample.astype(np.uint8) x_sample = np.ascontiguousarray(x_sample) images.append(x_sample) torch_gc() if request.seed is not None: torch.seed() np.random.seed() #set hypernetwork to none after generation CrossAttention.set_hypernetwork(None) return images @torch.no_grad() def sample_two_stages(self, request, callback = None): request = DotMap(request) if request.seed is not None: seed_everything(request.seed) if request.plms: sampler = self.plms else: sampler = self.ddim start_code = None if request.fixed_code: start_code = torch.randn([ request.n_samples, request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor, ], device=self.device) prompt = [request.prompt] * request.n_samples prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples) uc = None if request.scale != 1.0: uc = self.model.get_learned_conditioning(request.n_samples * [""]) prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc) shape = [ request.latent_channels, request.height // request.downsampling_factor, request.width // request.downsampling_factor ] c_dele = CallbackDelegate(callback, request.steps * request.n_samples) with torch.autocast("cuda", enabled=self.config.amp): with self.ema_manager(): samples, _ = sampler.sample( S=request.steps, conditioning=prompt_condition, batch_size=request.n_samples, shape=shape, verbose=False, unconditional_guidance_scale=request.scale, unconditional_conditioning=uc, eta=request.ddim_eta, dynamic_threshold=request.dynamic_threshold, x_T=start_code, callback=c_dele.update ) x_samples_ddim = self.model.decode_first_stage(samples) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).squeeze(0) x_samples_ddim = pil_upscale(x_samples_ddim, scale=2) if request.stage_two_seed is not None: torch.manual_seed(request.stage_two_seed) np.random.seed(request.stage_two_seed) with torch.autocast("cuda", enabled=self.config.amp): with self.ema_manager(): init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim)) self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False) t_enc = int(request.strength * request.steps) print("init latent shape:") print(init_latent.shape) init_latent = init_latent + (torch.randn_like(init_latent) * request.noise) prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples) uc = None if request.scale != 1.0: uc = self.model.get_learned_conditioning(request.n_samples * [""]) prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc) # encode (scaled latent) start_code_terped=None z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped) # decode it samples = self.ddim.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale, unconditional_conditioning=uc,) x_samples_ddim = self.model.decode_first_stage(samples) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) images = [] for x_sample in x_samples_ddim: x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = x_sample.astype(np.uint8) x_sample = np.ascontiguousarray(x_sample) images.append(x_sample) if request.seed is not None: torch.seed() np.random.seed() return images @torch.no_grad() def sample_from_image(self, request): return class DalleMiniModel(nn.Module): def __init__(self, config): nn.Module.__init__(self) from min_dalle import MinDalle self.config = config self.model = MinDalle( models_root=config.model_path, dtype=torch.float16, device='cuda', is_mega=True, is_reusable=True ) @torch.no_grad() def sample(self, request, callback = None): if request.seed is not None: seed = request.seed else: seed = -1 images = self.model.generate_images( text=request.prompt, seed=seed, grid_size=request.grid_size, is_seamless=False, temperature=request.temp, top_k=request.top_k, supercondition_factor=request.scale, is_verbose=False ) images = images.to('cpu').numpy() images = images.astype(np.uint8) images = np.ascontiguousarray(images) if request.seed is not None: torch.seed() np.random.seed() return images def apply_temp(logits, temperature): logits = logits / temperature return logits @torch.no_grad() def generate(forward, prompt_tokens, tokenizer, tokens_to_generate=50, ds=False, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False): in_tokens = prompt_tokens context = prompt_tokens generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device) kv = None if non_deterministic: torch.seed() #soft_required = ["top_k", "top_p"] op_map = { "temp": apply_temp, } for _ in range(tokens_to_generate): if ds: logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True) else: logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork) logits = logits[:, -1, :] #get the last token in the seq logits = torch.log_softmax(logits, dim=-1) batch = [] for i, ops in enumerate(ops_list): item = logits[i, ...].unsqueeze(0) ctx = context[i, ...].unsqueeze(0) for op, value in ops.items(): if op == "rep_pen": item = op_map[op](ctx, item, **value) else: item = op_map[op](item, value) batch.append(item) logits = torch.cat(batch, dim=0) logits = torch.softmax(logits, dim=-1) #fully_deterministic makes it deterministic across the batch if fully_deterministic: logits = logits.split(1, dim=0) logit_list = [] for logit in logits: torch.manual_seed(69) logit_list.append(torch.multinomial(logit, 1)) logits = torch.cat(logit_list, dim=0) else: logits = torch.multinomial(logits, 1) if logits[0, 0] == 48585: if generated[0, -1] == 1400: pass elif generated[0, -1] == 3363: return "safe", "none" else: return "notsafe", tokenizer.decode(generated.squeeze()).split("Output: ")[-1] generated = torch.cat([generated, logits], dim=-1) context = torch.cat([context, logits], dim=-1) in_tokens = logits return "unknown", tokenizer.decode(generated.squeeze()) class BasedformerModel(nn.Module): def __init__(self, config): nn.Module.__init__(self) from basedformer import lm_utils from transformers import GPT2TokenizerFast self.config = config self.model = lm_utils.load_from_path(config.model_path).half().cuda() self.model = self.model.convert_to_ds() self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") @torch.no_grad() def sample(self, request, callback = None): prompt = request.prompt prompt = self.tokenizer.encode("Input: " + prompt, return_tensors='pt').cuda().long() prompt = torch.cat([prompt, torch.tensor([[49527]], dtype=torch.long).cuda()], dim=1) is_safe, corrected = generate(self.model.module, prompt, self.tokenizer, tokens_to_generate=150, ds=True) return is_safe, corrected class EmbedderModel(nn.Module): def __init__(self, config=None): nn.Module.__init__(self) from sentence_transformers import SentenceTransformer self.model = SentenceTransformer('./models/sentence-transformers_all-MiniLM-L6-v2').cuda() self.tags = [tuple(x) for x in json.load(open("models/tags.json"))] self.knn = self.load_knn("models/tags.index") print("Loaded tag suggestion model using phrase embeddings") def load_knn(self, filename): import faiss try: return faiss.read_index(filename) except RuntimeError: print(f"Generating tag embedding index for {len(self.tags)} tags.") i = faiss.IndexFlatL2(384) i.add(self([name for name, count in self.tags])) faiss.write_index(i, filename) return i def __call__(self, sentences): with torch.no_grad(): sentence_embeddings = self.model.encode(sentences) return sentence_embeddings def get_top_k(self, text): #check if text is a substring in tag_count.keys() found = [] a = bisect.bisect_left(self.tags, (text,)) b = bisect.bisect_left(self.tags, (text + '\xff',), lo=a) for tag, count in self.tags[a:b]: if len(tag) >= len(text) and tag.startswith(text): found.append([tag, count, 0]) results = [] embedding = self([text]) k = 15 D, I = self.knn.search(embedding, k) D, I = D.squeeze(), I.squeeze() for id, prob in zip(I, D): tag, count = self.tags[id] results.append([tag, count, prob]) found.sort(key=lambda x: x[1], reverse=True) found = found[:5] # found = heapq.nlargest(5, found, key=lambda x: x[1]) results_tags = [x[0] for x in found] for result in results.copy(): if result[0] in results_tags: results.remove(result) results = sorted(results, key=lambda x: x[2], reverse=True) #filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4 results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])] if found: results = found + results #max 10 results results = results[:10] results = sorted(results, key=lambda x: x[1], reverse=True) return results