mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Implement noise augmentation for SD 4X upscale model.
This commit is contained in:
parent
ef4f6037cb
commit
8c6493578b
@ -498,7 +498,7 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
if isinstance(self.num_classes, int):
|
if isinstance(self.num_classes, int):
|
||||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
|
||||||
elif self.num_classes == "continuous":
|
elif self.num_classes == "continuous":
|
||||||
print("setting up linear c_adm embedding layer")
|
print("setting up linear c_adm embedding layer")
|
||||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||||
|
@ -41,8 +41,12 @@ class AbstractLowScaleModel(nn.Module):
|
|||||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
||||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
||||||
|
|
||||||
def q_sample(self, x_start, t, noise=None):
|
def q_sample(self, x_start, t, noise=None, seed=None):
|
||||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
if noise is None:
|
||||||
|
if seed is None:
|
||||||
|
noise = torch.randn_like(x_start)
|
||||||
|
else:
|
||||||
|
noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
|
||||||
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
|
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
|
||||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
|
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
|
||||||
|
|
||||||
@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
|
|||||||
super().__init__(noise_schedule_config=noise_schedule_config)
|
super().__init__(noise_schedule_config=noise_schedule_config)
|
||||||
self.max_noise_level = max_noise_level
|
self.max_noise_level = max_noise_level
|
||||||
|
|
||||||
def forward(self, x, noise_level=None):
|
def forward(self, x, noise_level=None, seed=None):
|
||||||
if noise_level is None:
|
if noise_level is None:
|
||||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||||
else:
|
else:
|
||||||
assert isinstance(noise_level, torch.Tensor)
|
assert isinstance(noise_level, torch.Tensor)
|
||||||
z = self.q_sample(x, noise_level)
|
z = self.q_sample(x, noise_level, seed=seed)
|
||||||
return z, noise_level
|
return z, noise_level
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
@ -78,7 +78,8 @@ class BaseModel(torch.nn.Module):
|
|||||||
extra_conds = {}
|
extra_conds = {}
|
||||||
for o in kwargs:
|
for o in kwargs:
|
||||||
extra = kwargs[o]
|
extra = kwargs[o]
|
||||||
if hasattr(extra, "to"):
|
if hasattr(extra, "dtype"):
|
||||||
|
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||||
extra = extra.to(dtype)
|
extra = extra.to(dtype)
|
||||||
extra_conds[o] = extra
|
extra_conds[o] = extra
|
||||||
|
|
||||||
@ -368,20 +369,31 @@ class Stable_Zero123(BaseModel):
|
|||||||
class SD_X4Upscaler(BaseModel):
|
class SD_X4Upscaler(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||||
super().__init__(model_config, model_type, device=device)
|
super().__init__(model_config, model_type, device=device)
|
||||||
|
self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
|
|
||||||
image = kwargs.get("concat_image", None)
|
image = kwargs.get("concat_image", None)
|
||||||
noise = kwargs.get("noise", None)
|
noise = kwargs.get("noise", None)
|
||||||
|
noise_augment = kwargs.get("noise_augmentation", 0.0)
|
||||||
|
device = kwargs["device"]
|
||||||
|
seed = kwargs["seed"] - 10
|
||||||
|
|
||||||
|
noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
|
||||||
|
|
||||||
if image is None:
|
if image is None:
|
||||||
image = torch.zeros_like(noise)[:,:3]
|
image = torch.zeros_like(noise)[:,:3]
|
||||||
|
|
||||||
if image.shape[1:] != noise.shape[1:]:
|
if image.shape[1:] != noise.shape[1:]:
|
||||||
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||||
|
|
||||||
|
noise_level = torch.tensor([noise_level], device=device)
|
||||||
|
if noise_augment > 0:
|
||||||
|
image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)
|
||||||
|
|
||||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||||
|
|
||||||
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
||||||
|
out['y'] = comfy.conds.CONDRegular(noise_level)
|
||||||
return out
|
return out
|
||||||
|
@ -603,8 +603,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
latent_image = model.process_latent_in(latent_image)
|
latent_image = model.process_latent_in(latent_image)
|
||||||
|
|
||||||
if hasattr(model, 'extra_conds'):
|
if hasattr(model, 'extra_conds'):
|
||||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
|
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
|
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||||
|
|
||||||
#make sure each cond area has an opposite one with the same area
|
#make sure each cond area has an opposite one with the same area
|
||||||
for c in positive:
|
for c in positive:
|
||||||
|
@ -290,6 +290,7 @@ class SD_X4Upscaler(SD20):
|
|||||||
|
|
||||||
unet_extra_config = {
|
unet_extra_config = {
|
||||||
"disable_self_attentions": [True, True, True, False],
|
"disable_self_attentions": [True, True, True, False],
|
||||||
|
"num_classes": 1000,
|
||||||
"num_heads": 8,
|
"num_heads": 8,
|
||||||
"num_head_channels": -1,
|
"num_head_channels": -1,
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ class SD_4XUpscale_Conditioning:
|
|||||||
"positive": ("CONDITIONING",),
|
"positive": ("CONDITIONING",),
|
||||||
"negative": ("CONDITIONING",),
|
"negative": ("CONDITIONING",),
|
||||||
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
# "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), #TODO
|
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||||
RETURN_NAMES = ("positive", "negative", "latent")
|
RETURN_NAMES = ("positive", "negative", "latent")
|
||||||
@ -18,7 +18,7 @@ class SD_4XUpscale_Conditioning:
|
|||||||
|
|
||||||
CATEGORY = "conditioning/upscale_diffusion"
|
CATEGORY = "conditioning/upscale_diffusion"
|
||||||
|
|
||||||
def encode(self, images, positive, negative, scale_ratio):
|
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
|
||||||
width = max(1, round(images.shape[-2] * scale_ratio))
|
width = max(1, round(images.shape[-2] * scale_ratio))
|
||||||
height = max(1, round(images.shape[-3] * scale_ratio))
|
height = max(1, round(images.shape[-3] * scale_ratio))
|
||||||
|
|
||||||
@ -30,11 +30,13 @@ class SD_4XUpscale_Conditioning:
|
|||||||
for t in positive:
|
for t in positive:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
n[1]['concat_image'] = pixels
|
n[1]['concat_image'] = pixels
|
||||||
|
n[1]['noise_augmentation'] = noise_augmentation
|
||||||
out_cp.append(n)
|
out_cp.append(n)
|
||||||
|
|
||||||
for t in negative:
|
for t in negative:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
n[1]['concat_image'] = pixels
|
n[1]['concat_image'] = pixels
|
||||||
|
n[1]['noise_augmentation'] = noise_augmentation
|
||||||
out_cn.append(n)
|
out_cn.append(n)
|
||||||
|
|
||||||
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
|
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
|
||||||
|
Loading…
Reference in New Issue
Block a user