Refactor of sampler code to deal more easily with different model types.

Commit 3ded1a3a04 (parent ac9c038ac2) in comfyanonymous/ComfyUI
(mirror of https://github.com/comfyanonymous/ComfyUI.git)
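At the core of the refactor, the `v_prediction` boolean threaded through model construction becomes a `ModelType` enum, and samplers branch on `model.model_type` instead of a `parameterization` string. A minimal standalone sketch of the new pattern (class bodies trimmed; only the names shown in the diff below are taken from it):

```python
from enum import Enum

class ModelType(Enum):
    EPS = 1           # the model predicts noise (epsilon)
    V_PREDICTION = 2  # the model predicts v

class BaseModel:
    # The old signature was __init__(self, model_config, v_prediction=False);
    # an enum extends to further parameterizations without stacking booleans.
    def __init__(self, model_config=None, model_type=ModelType.EPS):
        self.model_config = model_config
        self.model_type = model_type

model = BaseModel(model_type=ModelType.V_PREDICTION)
print("model_type", model.model_type.name)  # -> model_type V_PREDICTION
```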
comfy/extra_samplers/uni_pc.py

@@ -180,7 +180,6 @@ class NoiseScheduleVP:

 def model_wrapper(
     model,
-    sampling_function,
     noise_schedule,
     model_type="noise",
     model_kwargs={},
@@ -295,7 +294,7 @@ def model_wrapper(
         if t_continuous.reshape((-1,)).shape[0] == 1:
             t_continuous = t_continuous.expand((x.shape[0]))
         t_input = get_model_input_time(t_continuous)
-        output = sampling_function(model, x, t_input, **model_kwargs)
+        output = model(x, t_input, **model_kwargs)
         if model_type == "noise":
             return output
         elif model_type == "x_start":
@@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args
     else:
         timesteps = sigmas.clone()

-    for s in range(timesteps.shape[0]):
-        timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
-
-    ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
+    alphas_cumprod = model.inner_model.alphas_cumprod
+
+    for s in range(timesteps.shape[0]):
+        timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
+
+    ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

     if image is not None:
         img = image * ns.marginal_alpha(timesteps[0])
@@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args
         img = noise

     if to_zero:
-        timesteps[-1] = (1 / len(model.sigmas))
+        timesteps[-1] = (1 / len(alphas_cumprod))

     device = noise.device

-    if model.parameterization == "v":
-        model_type = "v"
-    else:
-        model_type = "noise"
+    model_type = "noise"

     model_fn = model_wrapper(
-        model.inner_model.inner_model.apply_model,
-        sampling_function,
+        model.predict_eps_discrete_timestep,
         ns,
         model_type=model_type,
         guidance_type="uncond",
comfy/k_diffusion/external.py

@@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
         t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
         return sampling.append_zero(self.t_to_sigma(t))

-    def sigma_to_t(self, sigma, quantize=None):
-        quantize = self.quantize if quantize is None else quantize
+    def sigma_to_discrete_timestep(self, sigma):
         log_sigma = sigma.log()
         dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape)
+
+    def sigma_to_t(self, sigma, quantize=None):
+        quantize = self.quantize if quantize is None else quantize
         if quantize:
-            return dists.abs().argmin(dim=0).view(sigma.shape)
+            return self.sigma_to_discrete_timestep(sigma)
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
         low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
         high_idx = low_idx + 1
         low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
@@ -85,6 +90,10 @@ class DiscreteSchedule(nn.Module):
         log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
         return log_sigma.exp()

+    def predict_eps_discrete_timestep(self, input, t, **kwargs):
+        sigma = self.t_to_sigma(t.round())
+        input = input * ((sigma ** 2 + 1.0) ** 0.5)
+        return (input - self(input, sigma, **kwargs)) / sigma

 class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
     """A wrapper for discrete schedule DDPM models that output eps (the predicted
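The new `predict_eps_discrete_timestep` bridges DDPM-style samplers and the k-diffusion denoiser wrapper. A minimal sketch of the underlying algebra, assuming a k-diffusion-style denoiser that returns an x0 estimate:

```python
import torch

# DDPM convention: x = sqrt(a)*x0 + sqrt(1-a)*eps with a = alpha_bar(t).
# With sigma = sqrt((1-a)/a), dividing x by sqrt(a) equals multiplying by
# sqrt(sigma**2 + 1), which converts to the k-diffusion convention
# x' = x0 + sigma*eps; eps is then (x' - x0) / sigma. The lambda below is a
# toy stand-in for the real call self(input, sigma, **kwargs).
def predict_eps(denoiser, x_ddpm, sigma):
    x = x_ddpm * ((sigma ** 2 + 1.0) ** 0.5)
    return (x - denoiser(x, sigma)) / sigma

x = torch.randn(1, 4, 8, 8)
eps = predict_eps(lambda xs, s: torch.zeros_like(xs), x, torch.tensor(1.5))
```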
comfy/ldm/models/diffusion/ddim.py

@@ -14,6 +14,7 @@ class DDIMSampler(object):
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
         self.device = device
+        self.parameterization = kwargs.get("parameterization", "eps")

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
@@ -261,7 +262,7 @@ class DDIMSampler(object):
         b, *_, device = *x.shape, x.device

         if denoise_function is not None:
-            model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
+            model_output = denoise_function(x, t, **extra_args)
         elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
             model_output = self.model.apply_model(x, t, c)
         else:
@@ -289,13 +290,13 @@ class DDIMSampler(object):
             model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

-        if self.model.parameterization == "v":
+        if self.parameterization == "v":
             e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
         else:
             e_t = model_output

         if score_corrector is not None:
-            assert self.model.parameterization == "eps", 'not implemented'
+            assert self.parameterization == "eps", 'not implemented'
             e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
@@ -309,7 +310,7 @@ class DDIMSampler(object):
         sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

         # current prediction for x_0
-        if self.model.parameterization != "v":
+        if self.parameterization != "v":
             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
         else:
             pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
comfy/model_base.py

@@ -4,10 +4,15 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
+from enum import Enum
 from . import utils

+class ModelType(Enum):
+    EPS = 1
+    V_PREDICTION = 2
+
 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, v_prediction=False):
+    def __init__(self, model_config, model_type=ModelType.EPS):
         super().__init__()

         unet_config = model_config.unet_config
@@ -15,16 +20,11 @@ class BaseModel(torch.nn.Module):
         self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
-        self.v_prediction = v_prediction
-        if self.v_prediction:
-            self.parameterization = "v"
-        else:
-            self.parameterization = "eps"
+        self.model_type = model_type

         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
-        print("v_prediction", v_prediction)
+        print("model_type", model_type.name)
         print("adm", self.adm_channels)

     def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
@@ -103,8 +103,8 @@ class BaseModel(torch.nn.Module):


 class SD21UNCLIP(BaseModel):
-    def __init__(self, model_config, noise_aug_config, v_prediction=True):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
+        super().__init__(model_config, model_type)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

     def encode_adm(self, **kwargs):
@@ -139,13 +139,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out

 class SDInpaint(BaseModel):
-    def __init__(self, model_config, v_prediction=False):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, model_type=ModelType.EPS):
+        super().__init__(model_config, model_type)
         self.concat_keys = ("mask", "masked_image")

 class SDXLRefiner(BaseModel):
-    def __init__(self, model_config, v_prediction=False):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, model_type=ModelType.EPS):
+        super().__init__(model_config, model_type)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
@@ -171,8 +171,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

 class SDXL(BaseModel):
-    def __init__(self, model_config, v_prediction=False):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, model_type=ModelType.EPS):
+        super().__init__(model_config, model_type)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
comfy/samplers.py

@@ -6,6 +6,7 @@ from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
 import math
+from comfy import model_base

 def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
     return abs(a*b) // math.gcd(a, b)
@@ -488,11 +489,11 @@ class KSampler:
     def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
         self.model = model
         self.model_denoise = CFGNoisePredictor(self.model)
-        if self.model.parameterization == "v":
+        if self.model.model_type == model_base.ModelType.V_PREDICTION:
             self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
         else:
             self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
-        self.model_wrap.parameterization = self.model.parameterization
         self.model_k = KSamplerX0Inpaint(self.model_wrap)
         self.device = device
         if scheduler not in self.SCHEDULERS:
@@ -614,7 +615,7 @@ class KSampler:
         elif self.sampler == "ddim":
             timesteps = []
             for s in range(sigmas.shape[0]):
-                timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
+                timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
             noise_mask = None
             if denoise_mask is not None:
                 noise_mask = 1.0 - denoise_mask
@@ -638,7 +639,7 @@ class KSampler:
                 x_T=z_enc,
                 x0=latent_image,
                 img_callback=ddim_callback,
-                denoise_function=sampling_function,
+                denoise_function=self.model_wrap.predict_eps_discrete_timestep,
                 extra_args=extra_args,
                 mask=noise_mask,
                 to_zero=sigmas[-1]==0,
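For the DDIM path, timesteps are now derived from the sigma schedule with the new rounding lookup rather than the interpolating `sigma_to_t`. A toy sketch of that lookup under an assumed schedule table:

```python
import torch

# Sketch of the nearest-neighbor lookup: pick the discrete timestep whose
# log-sigma is closest to the requested sigma. The 1000-entry table here is
# made up; the real one comes from the model's noise schedule.
log_sigmas = torch.linspace(-6.0, 2.3, 1000)

def sigma_to_discrete_timestep(sigma):
    dists = sigma.log() - log_sigmas[:, None]
    return dists.abs().argmin(dim=0).view(sigma.shape)

print(sigma_to_discrete_timestep(torch.tensor([0.5, 7.0])))
```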
comfy/sd.py (10 lines changed)

@@ -1008,11 +1008,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True
     if "noise_aug_config" in model_config_params:
         noise_aug_config = model_config_params["noise_aug_config"]

-    v_prediction = False
+    model_type = model_base.ModelType.EPS

     if "parameterization" in model_config_params:
         if model_config_params["parameterization"] == "v":
-            v_prediction = True
+            model_type = model_base.ModelType.V_PREDICTION

     clip = None
     vae = None
@@ -1032,11 +1032,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True
     model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)

     if config['model']["target"].endswith("LatentInpaintDiffusion"):
-        model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
+        model = model_base.SDInpaint(model_config, model_type=model_type)
     elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
-        model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
+        model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
     else:
-        model = model_base.BaseModel(model_config, v_prediction=v_prediction)
+        model = model_base.BaseModel(model_config, model_type=model_type)

     if fp16:
         model = model.half()
comfy/supported_models.py

@@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):

     latent_format = latent_formats.SD15

-    def v_prediction(self, state_dict, prefix=""):
+    def model_type(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
             k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
             out = state_dict[k]
             if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
-                return True
-        return False
+                return model_base.ModelType.V_PREDICTION
+        return model_base.ModelType.EPS

     def process_clip_state_dict(self, state_dict):
         state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
@@ -145,8 +145,14 @@ class SDXL(supported_models_base.BASE):

     latent_format = latent_formats.SDXL

+    def model_type(self, state_dict, prefix=""):
+        if "v_pred" in state_dict:
+            return model_base.ModelType.V_PREDICTION
+        else:
+            return model_base.ModelType.EPS
+
     def get_model(self, state_dict, prefix=""):
-        return model_base.SDXL(self)
+        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
comfy/supported_models_base.py

@@ -41,8 +41,8 @@ class BASE:
             return False
         return True

-    def v_prediction(self, state_dict, prefix=""):
-        return False
+    def model_type(self, state_dict, prefix=""):
+        return model_base.ModelType.EPS

     def inpaint_model(self):
         return self.unet_config["in_channels"] > 4
@@ -55,11 +55,11 @@ class BASE:

     def get_model(self, state_dict, prefix=""):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
+            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
         else:
-            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
+            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))

     def process_clip_state_dict(self, state_dict):
         return state_dict
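Parameterization detection now lives in a single overridable hook: `BASE.model_type` defaults to EPS, SD20 inspects weight statistics, and SDXL checks the checkpoint for a `v_pred` key. A condensed sketch of that dispatch pattern, simplified from the diff above (`state_dict` here is a plain dict stand-in):

```python
from enum import Enum

class ModelType(Enum):
    EPS = 1
    V_PREDICTION = 2

class BASE:
    def model_type(self, state_dict, prefix=""):
        return ModelType.EPS  # default when nothing marks the checkpoint

class SDXL(BASE):
    def model_type(self, state_dict, prefix=""):
        # per the diff, SDXL v-prediction checkpoints carry a "v_pred" key
        if "v_pred" in state_dict:
            return ModelType.V_PREDICTION
        return ModelType.EPS

print(SDXL().model_type({"v_pred": True}).name)  # V_PREDICTION
```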