# --- Repository web-viewer residue (not part of the module) ---
# You cannot select more than 25 topics
# Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# 194 lines | 5.9 KiB | Python
import logging
import os
import platform
import socket
import sys
import time
import traceback
import zlib
from pathlib import Path

import torch
from dotmap import DotMap

from hydra_node.models import StableDiffusionModel, DalleMiniModel, BasedformerModel, EmbedderModel
from hydra_node import lowvram
from ldm.modules.attention import CrossAttention, HyperLogic

# Registry mapping a model name to the class that implements it.
# init_config_model() indexes this with the value of the MODEL environment
# variable and calls the class with the config object.
model_map = {
    "stable-diffusion": StableDiffusionModel,
    "dalle-mini": DalleMiniModel,
    "basedformer": BasedformerModel,
    "embedder": EmbedderModel,
}

def no_init(loading_code):
    """Run ``loading_code`` with ``reset_parameters`` disabled on common layers.

    While ``loading_code`` runs, ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` is replaced by a no-op.
    The original methods are put back afterwards, including when
    ``loading_code`` raises, so a failed load cannot leave torch patched for
    the rest of the process.

    Args:
        loading_code: Zero-argument callable to execute.

    Returns:
        Whatever ``loading_code`` returns.

    Raises:
        Any exception raised by ``loading_code`` (after the patch is undone).
    """
    def _skip_reset(self):
        return

    patched = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
    saved = {cls: cls.reset_parameters for cls in patched}
    for cls in patched:
        cls.reset_parameters = _skip_reset

    try:
        return loading_code()
    finally:
        # Always undo the monkey-patch, even on failure.
        for cls, method in saved.items():
            cls.reset_parameters = method


def crc32(filename, chunksize=65536):
    """Compute the CRC-32 checksum of the contents of the given filename.

    The file is read in binary mode, ``chunksize`` bytes at a time, so the
    whole file is never held in memory at once.

    Args:
        filename: Path of the file to checksum.
        chunksize: Number of bytes requested per read.

    Returns:
        The checksum as an 8-character, zero-padded, upper-case hex string.
    """
    checksum = 0
    with open(filename, "rb") as f:
        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: f.read(chunksize), b""):
            checksum = zlib.crc32(chunk, checksum)
    return '%08X' % (checksum & 0xFFFFFFFF)


def load_modules(path):
    """Load every module file found directly inside the directory ``path``.

    Each regular file is loaded onto the CPU with ``load_module`` and stored
    under its file stem (file name without the final suffix). Entries that are
    not regular files, such as sub-directories, are skipped rather than passed
    to ``load_module``. Files are visited in sorted order, so if two files
    share a stem the later one in that order wins deterministically.

    Args:
        path: Directory containing the module files.

    Returns:
        A dict mapping file stem to the loaded module, or ``None`` when
        ``path`` is not a directory.
    """
    path = Path(path)
    if not path.is_dir():
        # Kept as None (not {}) so existing callers see the same value.
        return None

    modules = {}
    for file in sorted(path.iterdir()):
        if not file.is_file():
            continue
        modules[file.stem] = load_module(file, "cpu")
        print(f"Loaded module {file.stem}")

    return modules


def load_module(path, device):
    """Load one module file into a set of ``HyperLogic`` pairs.

    A pair of ``HyperLogic`` instances is created on ``device`` for each of
    the widths 768, 1280, 640 and 320. The file at ``path`` is read with
    ``torch.load``; for every key in it, elements 0 and 1 of the stored value
    are loaded into the first and second network of the matching pair. Widths
    absent from the file keep their freshly constructed networks.

    Args:
        path: Path of the module file.
        device: Device the networks are moved to and tensors are mapped to.

    Returns:
        A dict mapping width to a ``(HyperLogic, HyperLogic)`` tuple.

    Raises:
        FileNotFoundError: If ``path`` is not an existing regular file.
        KeyError: If the file contains a key other than the four widths.
    """
    path = Path(path)
    if not path.is_file():
        # Fail before building any network instead of printing and then
        # crashing inside torch.load.
        raise FileNotFoundError("Module path {} is not a file".format(path))

    network = {
        width: (HyperLogic(width).to(device), HyperLogic(width).to(device))
        for width in (768, 1280, 640, 320)
    }

    # SECURITY: torch.load unpickles the file; only load trusted module files.
    # map_location keeps tensors saved from another device (e.g. a GPU)
    # loadable on the requested one.
    state_dict = torch.load(path, map_location=device)
    for key, states in state_dict.items():
        first, second = network[key]
        first.load_state_dict(states[0])
        second.load_state_dict(states[1])

    return network


def _read_cpu_model_name():
    """Return the CPU name, preferring the 'model name' line of /proc/cpuinfo."""
    if os.path.exists('/proc/cpuinfo'):
        with open("/proc/cpuinfo", 'r') as cpuinfo:
            for line in cpuinfo:
                if 'model name' in line:
                    return line.rstrip().split(': ')[-1]
    # No /proc/cpuinfo, or it has no 'model name' entry.
    return platform.processor()


def _parse_clip_contexts(raw):
    """Convert the CLIP_CONTEXTS value to an int in 1..10, defaulting to 1."""
    try:
        contexts = int(raw)
    except (TypeError, ValueError):
        return 1
    return contexts if 1 <= contexts <= 10 else 1


def init_config_model():
    """Build the runtime config from the environment and load the model.

    Environment variables read: SAVE_FILES, DTYPE, DEVICE, AMP, DEV, MODEL
    (required), MODEL_PATH, ENABLE_EMA, BASEDFORMER, PENULTIMATE, VAE_PATH,
    MODULE_PATH, PRIOR_PATH, DEFAULT_CONFIG, QUALITY_HACK, CLIP_CONTEXTS and
    MODEL_ALIAS. Most values are stored on the config as the raw strings;
    only AMP ("1"/"0" become True/False) and CLIP_CONTEXTS (int clamped to
    1..10, default 1) are converted.

    Returns:
        A tuple ``(model, config, model_hash)``. ``model_hash`` is the CRC-32
        of the checkpoint file for "stable-diffusion" and ``None`` otherwise.

    Raises:
        KeyError: If the MODEL environment variable is not set.
        SystemExit: With code 4 when constructing the model fails.
    """
    config = DotMap()
    config.savefiles = os.getenv("SAVE_FILES", False)
    config.dtype = os.getenv("DTYPE", "float16")
    config.device = os.getenv("DEVICE", "cuda")
    config.amp = os.getenv("AMP", False)
    if config.amp == "1":
        config.amp = True
    elif config.amp == "0":
        config.amp = False

    # An unset DEV now means "not dev" instead of raising KeyError.
    config.is_dev = "_dev" if os.getenv("DEV") == "True" else ""

    # Setup logger
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    if not logger.handlers:
        # Attach the handler once so repeated calls do not duplicate output.
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter(
            "%(asctime)s %(levelname)s %(filename)s(%(process)d) - %(message)s"
        ))
        logger.addHandler(stream_handler)
    config.logger = logger

    # Gather node information
    config.cuda_dev = torch.cuda.current_device()
    config.cpu_id = _read_cpu_model_name()
    config.gpu_id = torch.cuda.get_device_name(config.cuda_dev)
    config.node_id = platform.node()

    # Report on our CUDA memory and model.
    total_memory = torch.cuda.get_device_properties(config.cuda_dev).total_memory
    gb_gpu = int(total_memory / (1000 * 1000 * 1000))
    logger.info(f"CPU: {config.cpu_id}")
    logger.info(f"GPU: {config.gpu_id}")
    logger.info(f"GPU RAM: {gb_gpu}gb")

    config.model_name = os.environ['MODEL']
    logger.info(f"MODEL: {config.model_name}")

    # Resolve where we get our model and data from.
    config.model_path = os.getenv('MODEL_PATH', None)
    config.enable_ema = os.getenv('ENABLE_EMA', "1")
    config.basedformer = os.getenv('BASEDFORMER', "0")
    config.penultimate = os.getenv('PENULTIMATE', "0")
    config.vae_path = os.getenv('VAE_PATH', None)
    config.module_path = os.getenv('MODULE_PATH', None)
    config.prior_path = os.getenv('PRIOR_PATH', None)
    config.default_config = os.getenv('DEFAULT_CONFIG', None)
    config.quality_hack = os.getenv('QUALITY_HACK', "0")
    config.clip_contexts = _parse_clip_contexts(os.getenv('CLIP_CONTEXTS', "1"))

    # Misc settings
    config.model_alias = os.getenv('MODEL_ALIAS')

    # Instantiate our actual model.
    load_time = time.time()
    model_hash = None

    try:
        if config.model_name != "dalle-mini":
            model = no_init(lambda: model_map[config.model_name](config))
        else:
            model = model_map[config.model_name](config)
    except Exception as e:
        traceback.print_exc()
        logger.error(f"Failed to load model: {str(e)}")
        # exit gunicorn
        sys.exit(4)

    if config.model_name == "stable-diffusion":
        folder = Path(config.model_path)
        if (folder / "pruned.ckpt").is_file():
            model_path = folder / "pruned.ckpt"
        else:
            model_path = folder / "model.ckpt"
        model_hash = crc32(model_path)

        # Load Modules
        # NOTE(review): indentation was lost in the source this was recovered
        # from; module loading is assumed to apply to stable-diffusion only.
        # Confirm against the upstream file.
        if config.module_path is not None:
            # attach it to the model
            model.premodules = load_modules(config.module_path)

        # NOTE(review): placement reconstructed; the low-VRAM setup is assumed
        # to apply to stable-diffusion only. Confirm against upstream.
        lowvram.setup_for_low_vram(model.model, True)

    config.model = model

    time_load = time.time() - load_time
    logger.info(f"Models loaded in {time_load:.2f}s")

    return model, config, model_hash