You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
222 lines
6.5 KiB
Python
222 lines
6.5 KiB
Python
import traceback
|
|
from dotmap import DotMap
|
|
import math
|
|
from io import BytesIO
|
|
import base64
|
|
import random
|
|
|
|
# Request defaults for the stable-diffusion model. sanitize_input() copies a
# value into the request only when the client did not send that key.
v1pp_defaults = dict(
    steps=50,
    sampler="plms",
    image=None,
    fixed_code=False,
    ddim_eta=0.0,
    height=512,
    width=512,
    latent_channels=4,
    downsampling_factor=8,
    scale=7.0,
    dynamic_threshold=None,
    seed=None,
    stage_two_seed=None,
    module=None,
    masks=None,
)

# Values sanitize_input() always writes into a stable-diffusion request,
# overriding anything the client sent.
v1pp_forced_defaults = dict(
    latent_channels=4,
    downsampling_factor=8,
)

# Request defaults for the dalle-mini model (only fill in missing keys).
dalle_mini_defaults = dict(
    temp=1.0,
    top_k=256,
    scale=16,
    grid_size=4,
)

# dalle-mini has no forced values.
dalle_mini_forced_defaults = dict()

# (defaults, forced_defaults) pair per model, keyed by config.model_name.
defaults = {
    'stable-diffusion': (v1pp_defaults, v1pp_forced_defaults),
    'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
    'basedformer': ({}, {}),
    'embedder': ({}, {}),
}

# Sampler names accepted by sanitize_stable_diffusion().
samplers = [
    "plms", "ddim",
    "k_euler", "k_euler_ancestral",
    "k_heun",
    "k_dpm_2", "k_dpm_2_ancestral",
    "k_lms"
]
|
|
|
|
def closest_multiple(num, mult):
    """Round ``num`` to the nearest multiple of ``mult``.

    ``num`` is first truncated with ``int()``. For a positive ``mult``, a value
    exactly halfway between two multiples rounds up to the larger one.

    Integer floor division / modulo are used instead of true division so the
    result stays exact for integers beyond float precision (|num| > 2**53).

    Raises:
        ZeroDivisionError: if ``mult`` is 0.
    """
    num_int = int(num)
    lower = (num_int // mult) * mult
    # Already a multiple -> both candidates coincide; otherwise step once.
    upper = lower if num_int % mult == 0 else lower + mult
    return lower if (num_int - lower) < (upper - num_int) else upper
|
|
|
|
def sanitize_stable_diffusion(request, config):
|
|
if request.steps > 50:
|
|
return False, "steps must be smaller than 50"
|
|
|
|
if request.width * request.height == 0:
|
|
return False, "width and height must be non-zero"
|
|
|
|
if request.width <= 0:
|
|
return False, "width must be positive"
|
|
|
|
if request.height <= 0:
|
|
return False, "height must be positive"
|
|
|
|
if request.steps <= 0:
|
|
return False, "steps must be positive"
|
|
|
|
if request.ddim_eta < 0:
|
|
return False, "ddim_eta shouldn't be negative"
|
|
|
|
if request.scale < 1.0:
|
|
return False, "scale should be at least 1.0"
|
|
|
|
if request.dynamic_threshold is not None and request.dynamic_threshold < 0:
|
|
return False, "dynamic_threshold shouldn't be negative"
|
|
|
|
if request.width * request.height >= 1024*1025:
|
|
return False, "width and height must be less than 1024*1025"
|
|
|
|
if request.strength < 0.0 or request.strength >= 1.0:
|
|
return False, "strength should be more than 0.0 and less than 1.0"
|
|
|
|
if request.noise < 0.0 or request.noise > 1.0:
|
|
return False, "noise should be more than 0.0 and less than 1.0"
|
|
|
|
if request.advanced:
|
|
request.width = closest_multiple(request.width // 2, 64)
|
|
request.height = closest_multiple(request.height // 2, 64)
|
|
|
|
if request.sampler not in samplers:
|
|
return False, "sampler should be one of {}".format(samplers)
|
|
|
|
if request.seed is None:
|
|
state = random.getstate()
|
|
request.seed = random.randint(0, 2**32)
|
|
random.setstate(state)
|
|
|
|
if request.module is not None:
|
|
if request.module not in config.model.premodules and request.module != "vanilla":
|
|
return False, "module should be one of: " + ", ".join(config.model.premodules)
|
|
|
|
max_gens = 100
|
|
if 0:
|
|
num_gen_tiers = [(1024*512, 4), (640*640, 6), (704*512, 8), (512*512, 16), (384*640, 18)]
|
|
pixel_count = request.width * request.height
|
|
for tier in num_gen_tiers:
|
|
if pixel_count <= tier[0]:
|
|
max_gens = tier[1]
|
|
else:
|
|
break
|
|
if request.n_samples > max_gens:
|
|
return False, f"requested more ({request.n_samples}) images than possible at this resolution"
|
|
|
|
if request.image is not None:
|
|
#decode from base64
|
|
try:
|
|
request.image = base64.b64decode(request.image.encode('utf-8'))
|
|
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
return False, "image is not valid base64"
|
|
#check if image is valid
|
|
try:
|
|
from PIL import Image
|
|
image = Image.open(BytesIO(request.image))
|
|
image.verify()
|
|
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
return False, "image is not valid"
|
|
|
|
#image is valid, load it again(still check again, verify() can't be sure as it doesn't decode.)
|
|
try:
|
|
image = Image.open(BytesIO(request.image))
|
|
image = image.convert('RGB')
|
|
image = image.resize((request.width, request.height), resample=Image.Resampling.LANCZOS)
|
|
request.image = image
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
return False, "Error while opening and cleaning image"
|
|
|
|
if request.masks is not None:
|
|
masks = request.masks
|
|
for x in range(len(masks)):
|
|
image = masks[x]["mask"]
|
|
try:
|
|
image_bytes = base64.b64decode(image.encode('utf-8'))
|
|
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
return False, "image is not valid base64"
|
|
|
|
try:
|
|
from PIL import Image
|
|
image = Image.open(BytesIO(image_bytes))
|
|
image.verify()
|
|
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
return False, "image is not valid"
|
|
|
|
#image is valid, load it again(still check again, verify() can't be sure as it doesn't decode.)
|
|
try:
|
|
image = Image.open(BytesIO(image_bytes))
|
|
#image = image.convert('RGB')
|
|
image = image.resize((request.width//request.downsampling_factor, request.height//request.downsampling_factor), resample=Image.Resampling.LANCZOS)
|
|
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
return False, "Error while opening and cleaning image"
|
|
|
|
masks[x]["mask"] = image
|
|
|
|
return True, request
|
|
|
|
def sanitize_dalle_mini(request):
    """Validate a dalle-mini request.

    There are no model-specific checks, so the request is accepted unchanged.
    Returns ``(True, request)``, matching the other sanitizers' contract.
    """
    accepted = True
    return accepted, request
|
|
|
|
def sanitize_basedformer(request):
    """Validate a basedformer request.

    There are no model-specific checks, so the request is accepted unchanged.
    Returns ``(True, request)``, matching the other sanitizers' contract.
    """
    accepted = True
    return accepted, request
|
|
|
|
def sanitize_embedder(request):
    """Validate an embedder request.

    There are no model-specific checks, so the request is accepted unchanged.
    Returns ``(True, request)``, matching the other sanitizers' contract.
    """
    accepted = True
    return accepted, request
|
|
|
|
def sanitize_input(config, request):
    """Sanitize the input data and set defaults.

    Wraps ``request`` in a DotMap, copies in each per-model default whose key
    the client did not send, then writes the forced defaults unconditionally,
    and finally delegates to the model-specific sanitizer.

    Args:
        config: node config; ``config.model_name`` selects the defaults and
            the sanitizer.
        request: mapping of request parameters.

    Returns:
        The ``(ok, request_or_error_message)`` tuple from the model-specific
        sanitizer, or ``None`` if the model has defaults but no sanitizer.

    Raises:
        KeyError: if ``config.model_name`` has no entry in ``defaults``.
    """
    request = DotMap(request)
    model_name = config.model_name
    fallback_values, pinned_values = defaults[model_name]

    for key, value in fallback_values.items():
        if key not in request:
            request[key] = value
    for key, value in pinned_values.items():
        request[key] = value

    if model_name == 'stable-diffusion':
        # The only sanitizer that also needs the config (premodule lookup).
        return sanitize_stable_diffusion(request, config)

    simple_sanitizers = {
        'dalle-mini': sanitize_dalle_mini,
        'basedformer': sanitize_basedformer,
        "embedder": sanitize_embedder,
    }
    sanitizer = simple_sanitizers.get(model_name)
    if sanitizer is None:
        return None
    return sanitizer(request)
|