From 2b14041d4bda0b05f6b5277b37b845e719fcf3dc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 13 Jun 2023 02:40:58 -0400 Subject: [PATCH] Remove useless code. --- comfy/k_diffusion/augmentation.py | 105 ------------ comfy/k_diffusion/config.py | 110 ------------ comfy/k_diffusion/evaluation.py | 134 --------------- comfy/k_diffusion/gns.py | 99 ----------- comfy/k_diffusion/layers.py | 246 --------------------------- comfy/k_diffusion/models/__init__.py | 1 - comfy/k_diffusion/models/image_v1.py | 156 ----------------- comfy/k_diffusion/utils.py | 19 --- 8 files changed, 870 deletions(-) delete mode 100644 comfy/k_diffusion/augmentation.py delete mode 100644 comfy/k_diffusion/config.py delete mode 100644 comfy/k_diffusion/evaluation.py delete mode 100644 comfy/k_diffusion/gns.py delete mode 100644 comfy/k_diffusion/layers.py delete mode 100644 comfy/k_diffusion/models/__init__.py delete mode 100644 comfy/k_diffusion/models/image_v1.py diff --git a/comfy/k_diffusion/augmentation.py b/comfy/k_diffusion/augmentation.py deleted file mode 100644 index 7dd17c68..00000000 --- a/comfy/k_diffusion/augmentation.py +++ /dev/null @@ -1,105 +0,0 @@ -from functools import reduce -import math -import operator - -import numpy as np -from skimage import transform -import torch -from torch import nn - - -def translate2d(tx, ty): - mat = [[1, 0, tx], - [0, 1, ty], - [0, 0, 1]] - return torch.tensor(mat, dtype=torch.float32) - - -def scale2d(sx, sy): - mat = [[sx, 0, 0], - [ 0, sy, 0], - [ 0, 0, 1]] - return torch.tensor(mat, dtype=torch.float32) - - -def rotate2d(theta): - mat = [[torch.cos(theta), torch.sin(-theta), 0], - [torch.sin(theta), torch.cos(theta), 0], - [ 0, 0, 1]] - return torch.tensor(mat, dtype=torch.float32) - - -class KarrasAugmentationPipeline: - def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8): - self.a_prob = a_prob - self.a_scale = a_scale - self.a_aniso = a_aniso - self.a_trans = a_trans - - def __call__(self, image): - h, w = image.size - mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)] - - # x-flip - a0 = torch.randint(2, []).float() - mats.append(scale2d(1 - 2 * a0, 1)) - # y-flip - do = (torch.rand([]) < self.a_prob).float() - a1 = torch.randint(2, []).float() * do - mats.append(scale2d(1, 1 - 2 * a1)) - # scaling - do = (torch.rand([]) < self.a_prob).float() - a2 = torch.randn([]) * do - mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2)) - # rotation - do = (torch.rand([]) < self.a_prob).float() - a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do - mats.append(rotate2d(-a3)) - # anisotropy - do = (torch.rand([]) < self.a_prob).float() - a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do - a5 = torch.randn([]) * do - mats.append(rotate2d(a4)) - mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5)) - mats.append(rotate2d(-a4)) - # translation - do = (torch.rand([]) < self.a_prob).float() - a6 = torch.randn([]) * do - a7 = torch.randn([]) * do - mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7)) - - # form the transformation matrix and conditioning vector - mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5)) - mat = reduce(operator.matmul, mats) - cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]) - - # apply the transformation - image_orig = np.array(image, dtype=np.float32) / 255 - if image_orig.ndim == 2: - image_orig = image_orig[..., None] - tf = transform.AffineTransform(mat.numpy()) - image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True) - image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1 - image = torch.as_tensor(image).movedim(2, 0) * 2 - 1 - return image, image_orig, cond - - -class KarrasAugmentWrapper(nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs): - if aug_cond is None: - aug_cond = input.new_zeros([input.shape[0], 9]) - if mapping_cond is None: - mapping_cond = aug_cond - else: - mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1) - return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs) - - def set_skip_stages(self, skip_stages): - return self.inner_model.set_skip_stages(skip_stages) - - def set_patch_size(self, patch_size): - return self.inner_model.set_patch_size(patch_size) diff --git a/comfy/k_diffusion/config.py b/comfy/k_diffusion/config.py deleted file mode 100644 index 4b504d6d..00000000 --- a/comfy/k_diffusion/config.py +++ /dev/null @@ -1,110 +0,0 @@ -from functools import partial -import json -import math -import warnings - -from jsonmerge import merge - -from . import augmentation, layers, models, utils - - -def load_config(file): - defaults = { - 'model': { - 'sigma_data': 1., - 'patch_size': 1, - 'dropout_rate': 0., - 'augment_wrapper': True, - 'augment_prob': 0., - 'mapping_cond_dim': 0, - 'unet_cond_dim': 0, - 'cross_cond_dim': 0, - 'cross_attn_depths': None, - 'skip_stages': 0, - 'has_variance': False, - }, - 'dataset': { - 'type': 'imagefolder', - }, - 'optimizer': { - 'type': 'adamw', - 'lr': 1e-4, - 'betas': [0.95, 0.999], - 'eps': 1e-6, - 'weight_decay': 1e-3, - }, - 'lr_sched': { - 'type': 'inverse', - 'inv_gamma': 20000., - 'power': 1., - 'warmup': 0.99, - }, - 'ema_sched': { - 'type': 'inverse', - 'power': 0.6667, - 'max_value': 0.9999 - }, - } - config = json.load(file) - return merge(defaults, config) - - -def make_model(config): - config = config['model'] - assert config['type'] == 'image_v1' - model = models.ImageDenoiserModelV1( - config['input_channels'], - config['mapping_out'], - config['depths'], - config['channels'], - config['self_attn_depths'], - config['cross_attn_depths'], - patch_size=config['patch_size'], - dropout_rate=config['dropout_rate'], - mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0), - unet_cond_dim=config['unet_cond_dim'], - cross_cond_dim=config['cross_cond_dim'], - skip_stages=config['skip_stages'], - has_variance=config['has_variance'], - ) - if config['augment_wrapper']: - model = augmentation.KarrasAugmentWrapper(model) - return model - - -def make_denoiser_wrapper(config): - config = config['model'] - sigma_data = config.get('sigma_data', 1.) - has_variance = config.get('has_variance', False) - if not has_variance: - return partial(layers.Denoiser, sigma_data=sigma_data) - return partial(layers.DenoiserWithVariance, sigma_data=sigma_data) - - -def make_sample_density(config): - sd_config = config['sigma_sample_density'] - sigma_data = config['sigma_data'] - if sd_config['type'] == 'lognormal': - loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] - scale = sd_config['std'] if 'std' in sd_config else sd_config['scale'] - return partial(utils.rand_log_normal, loc=loc, scale=scale) - if sd_config['type'] == 'loglogistic': - loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data) - scale = sd_config['scale'] if 'scale' in sd_config else 0.5 - min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. - max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') - return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) - if sd_config['type'] == 'loguniform': - min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min'] - max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max'] - return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value) - if sd_config['type'] == 'v-diffusion': - min_value = sd_config['min_value'] if 'min_value' in sd_config else 0. - max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf') - return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value) - if sd_config['type'] == 'split-lognormal': - loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc'] - scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1'] - scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2'] - return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2) - raise ValueError('Unknown sample density type') diff --git a/comfy/k_diffusion/evaluation.py b/comfy/k_diffusion/evaluation.py deleted file mode 100644 index 2c34bbf1..00000000 --- a/comfy/k_diffusion/evaluation.py +++ /dev/null @@ -1,134 +0,0 @@ -import math -import os -from pathlib import Path - -from cleanfid.inception_torchscript import InceptionV3W -import clip -from resize_right import resize -import torch -from torch import nn -from torch.nn import functional as F -from torchvision import transforms -from tqdm.auto import trange - -from . import utils - - -class InceptionV3FeatureExtractor(nn.Module): - def __init__(self, device='cpu'): - super().__init__() - path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion' - url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' - digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4' - utils.download_file(path / 'inception-2015-12-05.pt', url, digest) - self.model = InceptionV3W(str(path), resize_inside=False).to(device) - self.size = (299, 299) - - def forward(self, x): - if x.shape[2:4] != self.size: - x = resize(x, out_shape=self.size, pad_mode='reflect') - if x.shape[1] == 1: - x = torch.cat([x] * 3, dim=1) - x = (x * 127.5 + 127.5).clamp(0, 255) - return self.model(x) - - -class CLIPFeatureExtractor(nn.Module): - def __init__(self, name='ViT-L/14@336px', device='cpu'): - super().__init__() - self.model = clip.load(name, device=device)[0].eval().requires_grad_(False) - self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)) - self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution) - - def forward(self, x): - if x.shape[2:4] != self.size: - x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1) - x = self.normalize(x) - x = self.model.encode_image(x).float() - x = F.normalize(x) * x.shape[1] ** 0.5 - return x - - -def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size): - n_per_proc = math.ceil(n / accelerator.num_processes) - feats_all = [] - try: - for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process): - cur_batch_size = min(n - i, batch_size) - samples = sample_fn(cur_batch_size)[:cur_batch_size] - feats_all.append(accelerator.gather(extractor_fn(samples))) - except StopIteration: - pass - return torch.cat(feats_all)[:n] - - -def polynomial_kernel(x, y): - d = x.shape[-1] - dot = x @ y.transpose(-2, -1) - return (dot / d + 1) ** 3 - - -def squared_mmd(x, y, kernel=polynomial_kernel): - m = x.shape[-2] - n = y.shape[-2] - kxx = kernel(x, x) - kyy = kernel(y, y) - kxy = kernel(x, y) - kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1) - kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1) - kxy_sum = kxy.sum([-1, -2]) - term_1 = kxx_sum / m / (m - 1) - term_2 = kyy_sum / n / (n - 1) - term_3 = kxy_sum * 2 / m / n - return term_1 + term_2 - term_3 - - -@utils.tf32_mode(matmul=False) -def kid(x, y, max_size=5000): - x_size, y_size = x.shape[0], y.shape[0] - n_partitions = math.ceil(max(x_size / max_size, y_size / max_size)) - total_mmd = x.new_zeros([]) - for i in range(n_partitions): - cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)] - cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)] - total_mmd = total_mmd + squared_mmd(cur_x, cur_y) - return total_mmd / n_partitions - - -class _MatrixSquareRootEig(torch.autograd.Function): - @staticmethod - def forward(ctx, a): - vals, vecs = torch.linalg.eigh(a) - ctx.save_for_backward(vals, vecs) - return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1) - - @staticmethod - def backward(ctx, grad_output): - vals, vecs = ctx.saved_tensors - d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1) - vecs_t = vecs.transpose(-2, -1) - return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t - - -def sqrtm_eig(a): - if a.ndim < 2: - raise RuntimeError('tensor of matrices must have at least 2 dimensions') - if a.shape[-2] != a.shape[-1]: - raise RuntimeError('tensor must be batches of square matrices') - return _MatrixSquareRootEig.apply(a) - - -@utils.tf32_mode(matmul=False) -def fid(x, y, eps=1e-8): - x_mean = x.mean(dim=0) - y_mean = y.mean(dim=0) - mean_term = (x_mean - y_mean).pow(2).sum() - x_cov = torch.cov(x.T) - y_cov = torch.cov(y.T) - eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps - x_cov = x_cov + eps_eye - y_cov = y_cov + eps_eye - x_cov_sqrt = sqrtm_eig(x_cov) - cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt)) - return mean_term + cov_term diff --git a/comfy/k_diffusion/gns.py b/comfy/k_diffusion/gns.py deleted file mode 100644 index dcb7b8d8..00000000 --- a/comfy/k_diffusion/gns.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -from torch import nn - - -class DDPGradientStatsHook: - def __init__(self, ddp_module): - try: - ddp_module.register_comm_hook(self, self._hook_fn) - except AttributeError: - raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules') - self._clear_state() - - def _clear_state(self): - self.bucket_sq_norms_small_batch = [] - self.bucket_sq_norms_large_batch = [] - - @staticmethod - def _hook_fn(self, bucket): - buf = bucket.buffer() - self.bucket_sq_norms_small_batch.append(buf.pow(2).sum()) - fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future() - def callback(fut): - buf = fut.value()[0] - self.bucket_sq_norms_large_batch.append(buf.pow(2).sum()) - return buf - return fut.then(callback) - - def get_stats(self): - sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch) - sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch) - self._clear_state() - stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch]) - torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG) - return stats[0].item(), stats[1].item() - - -class GradientNoiseScale: - """Calculates the gradient noise scale (1 / SNR), or critical batch size, - from _An Empirical Model of Large-Batch Training_, - https://arxiv.org/abs/1812.06162). - - Args: - beta (float): The decay factor for the exponential moving averages used to - calculate the gradient noise scale. - Default: 0.9998 - eps (float): Added for numerical stability. - Default: 1e-8 - """ - - def __init__(self, beta=0.9998, eps=1e-8): - self.beta = beta - self.eps = eps - self.ema_sq_norm = 0. - self.ema_var = 0. - self.beta_cumprod = 1. - self.gradient_noise_scale = float('nan') - - def state_dict(self): - """Returns the state of the object as a :class:`dict`.""" - return dict(self.__dict__.items()) - - def load_state_dict(self, state_dict): - """Loads the object's state. - Args: - state_dict (dict): object state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch): - """Updates the state with a new batch's gradient statistics, and returns the - current gradient noise scale. - - Args: - sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or - per sample gradients. - sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or - per sample gradients. - n_small_batch (int): The batch size of the individual microbatch or per sample - gradients (1 if per sample). - n_large_batch (int): The total batch size of the mean of the microbatch or - per sample gradients. - """ - est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch) - est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch) - self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm - self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var - self.beta_cumprod *= self.beta - self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps) - return self.gradient_noise_scale - - def get_gns(self): - """Returns the current gradient noise scale.""" - return self.gradient_noise_scale - - def get_stats(self): - """Returns the current (debiased) estimates of the squared mean gradient - and gradient variance.""" - return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod) diff --git a/comfy/k_diffusion/layers.py b/comfy/k_diffusion/layers.py deleted file mode 100644 index cdeba0ad..00000000 --- a/comfy/k_diffusion/layers.py +++ /dev/null @@ -1,246 +0,0 @@ -import math - -from einops import rearrange, repeat -import torch -from torch import nn -from torch.nn import functional as F - -from . import utils - -# Karras et al. preconditioned denoiser - -class Denoiser(nn.Module): - """A Karras et al. preconditioner for denoising diffusion models.""" - - def __init__(self, inner_model, sigma_data=1.): - super().__init__() - self.inner_model = inner_model - self.sigma_data = sigma_data - - def get_scalings(self, sigma): - c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - return c_skip, c_out, c_in - - def loss(self, input, noise, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - noised_input = input + noise * utils.append_dims(sigma, input.ndim) - model_output = self.inner_model(noised_input * c_in, sigma, **kwargs) - target = (input - c_skip * noised_input) / c_out - return (model_output - target).pow(2).flatten(1).mean(1) - - def forward(self, input, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip - - -class DenoiserWithVariance(Denoiser): - def loss(self, input, noise, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - noised_input = input + noise * utils.append_dims(sigma, input.ndim) - model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs) - logvar = utils.append_dims(logvar, model_output.ndim) - target = (input - c_skip * noised_input) / c_out - losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2 - return losses.flatten(1).mean(1) - - -# Residual blocks - -class ResidualBlock(nn.Module): - def __init__(self, *main, skip=None): - super().__init__() - self.main = nn.Sequential(*main) - self.skip = skip if skip else nn.Identity() - - def forward(self, input): - return self.main(input) + self.skip(input) - - -# Noise level (and other) conditioning - -class ConditionedModule(nn.Module): - pass - - -class UnconditionedModule(ConditionedModule): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, input, cond=None): - return self.module(input) - - -class ConditionedSequential(nn.Sequential, ConditionedModule): - def forward(self, input, cond): - for module in self: - if isinstance(module, ConditionedModule): - input = module(input, cond) - else: - input = module(input) - return input - - -class ConditionedResidualBlock(ConditionedModule): - def __init__(self, *main, skip=None): - super().__init__() - self.main = ConditionedSequential(*main) - self.skip = skip if skip else nn.Identity() - - def forward(self, input, cond): - skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input) - return self.main(input, cond) + skip - - -class AdaGN(ConditionedModule): - def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'): - super().__init__() - self.num_groups = num_groups - self.eps = eps - self.cond_key = cond_key - self.mapper = nn.Linear(feats_in, c_out * 2) - - def forward(self, input, cond): - weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1) - input = F.group_norm(input, self.num_groups, eps=self.eps) - return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1) - - -# Attention - -class SelfAttention2d(ConditionedModule): - def __init__(self, c_in, n_head, norm, dropout_rate=0.): - super().__init__() - assert c_in % n_head == 0 - self.norm_in = norm(c_in) - self.n_head = n_head - self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1) - self.out_proj = nn.Conv2d(c_in, c_in, 1) - self.dropout = nn.Dropout(dropout_rate) - - def forward(self, input, cond): - n, c, h, w = input.shape - qkv = self.qkv_proj(self.norm_in(input, cond)) - qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3) - q, k, v = qkv.chunk(3, dim=1) - scale = k.shape[3] ** -0.25 - att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) - att = self.dropout(att) - y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w]) - return input + self.out_proj(y) - - -class CrossAttention2d(ConditionedModule): - def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0., - cond_key='cross', cond_key_padding='cross_padding'): - super().__init__() - assert c_dec % n_head == 0 - self.cond_key = cond_key - self.cond_key_padding = cond_key_padding - self.norm_enc = nn.LayerNorm(c_enc) - self.norm_dec = norm_dec(c_dec) - self.n_head = n_head - self.q_proj = nn.Conv2d(c_dec, c_dec, 1) - self.kv_proj = nn.Linear(c_enc, c_dec * 2) - self.out_proj = nn.Conv2d(c_dec, c_dec, 1) - self.dropout = nn.Dropout(dropout_rate) - - def forward(self, input, cond): - n, c, h, w = input.shape - q = self.q_proj(self.norm_dec(input, cond)) - q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3) - kv = self.kv_proj(self.norm_enc(cond[self.cond_key])) - kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2) - k, v = kv.chunk(2, dim=1) - scale = k.shape[3] ** -0.25 - att = ((q * scale) @ (k.transpose(2, 3) * scale)) - att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000 - att = att.softmax(3) - att = self.dropout(att) - y = (att @ v).transpose(2, 3) - y = y.contiguous().view([n, c, h, w]) - return input + self.out_proj(y) - - -# Downsampling/upsampling - -_kernels = { - 'linear': - [1 / 8, 3 / 8, 3 / 8, 1 / 8], - 'cubic': - [-0.01171875, -0.03515625, 0.11328125, 0.43359375, - 0.43359375, 0.11328125, -0.03515625, -0.01171875], - 'lanczos3': - [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, - -0.066637322306633, 0.13550527393817902, 0.44638532400131226, - 0.44638532400131226, 0.13550527393817902, -0.066637322306633, - -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] -} -_kernels['bilinear'] = _kernels['linear'] -_kernels['bicubic'] = _kernels['cubic'] - - -class Downsample2d(nn.Module): - def __init__(self, kernel='linear', pad_mode='reflect'): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([_kernels[kernel]]) - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer('kernel', kernel_1d.T @ kernel_1d) - - def forward(self, x): - x = F.pad(x, (self.pad,) * 4, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) - indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) - return F.conv2d(x, weight, stride=2) - - -class Upsample2d(nn.Module): - def __init__(self, kernel='linear', pad_mode='reflect'): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([_kernels[kernel]]) * 2 - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer('kernel', kernel_1d.T @ kernel_1d) - - def forward(self, x): - x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) - indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) - return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) - - -# Embeddings - -class FourierFeatures(nn.Module): - def __init__(self, in_features, out_features, std=1.): - super().__init__() - assert out_features % 2 == 0 - self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std) - - def forward(self, input): - f = 2 * math.pi * input @ self.weight.T - return torch.cat([f.cos(), f.sin()], dim=-1) - - -# U-Nets - -class UNet(ConditionedModule): - def __init__(self, d_blocks, u_blocks, skip_stages=0): - super().__init__() - self.d_blocks = nn.ModuleList(d_blocks) - self.u_blocks = nn.ModuleList(u_blocks) - self.skip_stages = skip_stages - - def forward(self, input, cond): - skips = [] - for block in self.d_blocks[self.skip_stages:]: - input = block(input, cond) - skips.append(input) - for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))): - input = block(input, cond, skip if i > 0 else None) - return input diff --git a/comfy/k_diffusion/models/__init__.py b/comfy/k_diffusion/models/__init__.py deleted file mode 100644 index 82608ff1..00000000 --- a/comfy/k_diffusion/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .image_v1 import ImageDenoiserModelV1 diff --git a/comfy/k_diffusion/models/image_v1.py b/comfy/k_diffusion/models/image_v1.py deleted file mode 100644 index 9ffd5f2c..00000000 --- a/comfy/k_diffusion/models/image_v1.py +++ /dev/null @@ -1,156 +0,0 @@ -import math - -import torch -from torch import nn -from torch.nn import functional as F - -from .. import layers, utils - - -def orthogonal_(module): - nn.init.orthogonal_(module.weight) - return module - - -class ResConvBlock(layers.ConditionedResidualBlock): - def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.): - skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False)) - super().__init__( - layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)), - nn.GELU(), - nn.Conv2d(c_in, c_mid, 3, padding=1), - nn.Dropout2d(dropout_rate, inplace=True), - layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)), - nn.GELU(), - nn.Conv2d(c_mid, c_out, 3, padding=1), - nn.Dropout2d(dropout_rate, inplace=True), - skip=skip) - - -class DBlock(layers.ConditionedSequential): - def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0): - modules = [nn.Identity()] - for i in range(n_layers): - my_c_in = c_in if i == 0 else c_mid - my_c_out = c_mid if i < n_layers - 1 else c_out - modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) - if self_attn: - norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) - modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) - if cross_attn: - norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) - modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) - super().__init__(*modules) - self.set_downsample(downsample) - - def set_downsample(self, downsample): - self[0] = layers.Downsample2d() if downsample else nn.Identity() - return self - - -class UBlock(layers.ConditionedSequential): - def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0): - modules = [] - for i in range(n_layers): - my_c_in = c_in if i == 0 else c_mid - my_c_out = c_mid if i < n_layers - 1 else c_out - modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) - if self_attn: - norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) - modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) - if cross_attn: - norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) - modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate)) - modules.append(nn.Identity()) - super().__init__(*modules) - self.set_upsample(upsample) - - def forward(self, input, cond, skip=None): - if skip is not None: - input = torch.cat([input, skip], dim=1) - return super().forward(input, cond) - - def set_upsample(self, upsample): - self[-1] = layers.Upsample2d() if upsample else nn.Identity() - return self - - -class MappingNet(nn.Sequential): - def __init__(self, feats_in, feats_out, n_layers=2): - layers = [] - for i in range(n_layers): - layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out))) - layers.append(nn.GELU()) - super().__init__(*layers) - - -class ImageDenoiserModelV1(nn.Module): - def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False): - super().__init__() - self.c_in = c_in - self.channels = channels - self.unet_cond_dim = unet_cond_dim - self.patch_size = patch_size - self.has_variance = has_variance - self.timestep_embed = layers.FourierFeatures(1, feats_in) - if mapping_cond_dim > 0: - self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False) - self.mapping = MappingNet(feats_in, feats_in) - self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1) - self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) - nn.init.zeros_(self.proj_out.weight) - nn.init.zeros_(self.proj_out.bias) - if cross_cond_dim == 0: - cross_attn_depths = [False] * len(self_attn_depths) - d_blocks, u_blocks = [], [] - for i in range(len(depths)): - my_c_in = channels[max(0, i - 1)] - d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) - for i in range(len(depths)): - my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i] - my_c_out = channels[max(0, i - 1)] - u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate)) - self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages) - - def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False): - c_noise = sigma.log() / 4 - timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2)) - mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond) - mapping_out = self.mapping(timestep_embed + mapping_cond_embed) - cond = {'cond': mapping_out} - if unet_cond is not None: - input = torch.cat([input, unet_cond], dim=1) - if cross_cond is not None: - cond['cross'] = cross_cond - cond['cross_padding'] = cross_cond_padding - if self.patch_size > 1: - input = F.pixel_unshuffle(input, self.patch_size) - input = self.proj_in(input) - input = self.u_net(input, cond) - input = self.proj_out(input) - if self.has_variance: - input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1) - if self.patch_size > 1: - input = F.pixel_shuffle(input, self.patch_size) - if self.has_variance and return_variance: - return input, logvar - return input - - def set_skip_stages(self, skip_stages): - self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1) - self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1) - nn.init.zeros_(self.proj_out.weight) - nn.init.zeros_(self.proj_out.bias) - self.u_net.skip_stages = skip_stages - for i, block in enumerate(self.u_net.d_blocks): - block.set_downsample(i > skip_stages) - for i, block in enumerate(reversed(self.u_net.u_blocks)): - block.set_upsample(i > skip_stages) - return self - - def set_patch_size(self, patch_size): - self.patch_size = patch_size - self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1) - self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1) - nn.init.zeros_(self.proj_out.weight) - nn.init.zeros_(self.proj_out.bias) diff --git a/comfy/k_diffusion/utils.py b/comfy/k_diffusion/utils.py index ce6014be..a644df2f 100644 --- a/comfy/k_diffusion/utils.py +++ b/comfy/k_diffusion/utils.py @@ -10,25 +10,6 @@ from PIL import Image import torch from torch import nn, optim from torch.utils import data -from torchvision.transforms import functional as TF - - -def from_pil_image(x): - """Converts from a PIL image to a tensor.""" - x = TF.to_tensor(x) - if x.ndim == 2: - x = x[..., None] - return x * 2 - 1 - - -def to_pil_image(x): - """Converts from a tensor to a PIL image.""" - if x.ndim == 4: - assert x.shape[0] == 1 - x = x[0] - if x.shape[0] == 1: - x = x[0] - return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2) def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):