From 4e6b83a80a82b90bb9e61ed61e4b6a607b71aa3a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 25 Feb 2023 00:55:42 -0500 Subject: [PATCH] Add a T2IAdapterLoader node to load T2I-Adapter models. They are loaded as CONTROL_NET objects because they are similar. --- comfy/sd.py | 92 ++++++++++++- comfy/t2i_adapter/adapter.py | 125 ++++++++++++++++++ .../t2i_adapter/put_t2i_adapter_models_here | 0 nodes.py | 17 +++ 4 files changed, 233 insertions(+), 1 deletion(-) create mode 100644 comfy/t2i_adapter/adapter.py create mode 100644 models/t2i_adapter/put_t2i_adapter_models_here diff --git a/comfy/sd.py b/comfy/sd.py index 11136775..42860710 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -8,6 +8,7 @@ from ldm.util import instantiate_from_config from ldm.models.autoencoder import AutoencoderKL from omegaconf import OmegaConf from .cldm import cldm +from .t2i_adapter import adapter from . import utils @@ -388,7 +389,7 @@ class ControlNet: self.control_model = model_management.load_if_low_vram(self.control_model) control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) self.control_model = model_management.unload_if_low_vram(self.control_model) - out = {'input':[], 'middle':[], 'output': []} + out = {'middle':[], 'output': []} autocast_enabled = torch.is_autocast_enabled() for i in range(len(control)): @@ -504,6 +505,95 @@ def load_controlnet(ckpt_path, model=None): control = ControlNet(control_model) return control +class T2IAdapter: + def __init__(self, t2i_model, channels_in, device="cuda"): + self.t2i_model = t2i_model + self.channels_in = channels_in + self.strength = 1.0 + self.device = device + self.previous_controlnet = None + self.control_input = None + self.cond_hint_original = None + self.cond_hint = None + + def get_control(self, x_noisy, t, cond_txt): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt) + + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) + if self.channels_in == 1 and self.cond_hint.shape[1] > 1: + self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + self.t2i_model.to(self.device) + self.control_input = self.t2i_model(self.cond_hint) + self.t2i_model.cpu() + + output_dtype = x_noisy.dtype + out = {'input':[]} + + for i in range(len(self.control_input)): + key = 'input' + x = self.control_input[i] * self.strength + if x.dtype != output_dtype and not autocast_enabled: + x = x.to(output_dtype) + + if control_prev is not None and key in control_prev: + index = len(control_prev[key]) - i * 3 - 3 + prev = control_prev[key][index] + if prev is not None: + x += prev + out[key].insert(0, None) + out[key].insert(0, None) + out[key].insert(0, x) + + if control_prev is not None and 'input' in control_prev: + for i in range(len(out['input'])): + if out['input'][i] is None: + out['input'][i] = control_prev['input'][i] + if control_prev is not None and 'middle' in control_prev: + out['middle'] = control_prev['middle'] + if control_prev is not None and 'output' in control_prev: + out['output'] = control_prev['output'] + return out + + def set_cond_hint(self, cond_hint, strength=1.0): + self.cond_hint_original = cond_hint + self.strength = strength + return self + + def set_previous_controlnet(self, controlnet): + self.previous_controlnet = controlnet + return self + + def copy(self): + c = T2IAdapter(self.t2i_model, self.channels_in) + c.cond_hint_original = self.cond_hint_original + c.strength = self.strength + return c + + def cleanup(self): + if self.previous_controlnet is not None: + self.previous_controlnet.cleanup() + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + + def get_control_models(self): + out = [] + if self.previous_controlnet is not None: + out += self.previous_controlnet.get_control_models() + return out + +def load_t2i_adapter(ckpt_path, model=None): + t2i_data = load_torch_file(ckpt_path) + cin = t2i_data['conv_in.weight'].shape[1] + model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False) + model_ad.load_state_dict(t2i_data) + return T2IAdapter(model_ad, cin // 64) def load_clip(ckpt_path, embedding_directory=None): clip_data = load_torch_file(ckpt_path) diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py new file mode 100644 index 00000000..d059ba91 --- /dev/null +++ b/comfy/t2i_adapter/adapter.py @@ -0,0 +1,125 @@ +#taken from https://github.com/TencentARC/T2I-Adapter + +import torch +import torch.nn as nn +import torch.nn.functional as F +from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResnetBlock(nn.Module): + def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True): + super().__init__() + ps = ksize//2 + if in_c != out_c or sk==False: + self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps) + else: + # print('n_in') + self.in_conv = None + self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1) + self.act = nn.ReLU() + self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps) + if sk==False: + self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps) + else: + self.skep = None + + self.down = down + if self.down == True: + self.down_opt = Downsample(in_c, use_conv=use_conv) + + def forward(self, x): + if self.down == True: + x = self.down_opt(x) + if self.in_conv is not None: # edit + x = self.in_conv(x) + + h = self.block1(x) + h = self.act(h) + h = self.block2(h) + if self.skep is not None: + return h + self.skep(x) + else: + return h + x + + +class Adapter(nn.Module): + def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True): + super(Adapter, self).__init__() + self.unshuffle = nn.PixelUnshuffle(8) + self.channels = channels + self.nums_rb = nums_rb + self.body = [] + for i in range(len(channels)): + for j in range(nums_rb): + if (i!=0) and (j==0): + self.body.append(ResnetBlock(channels[i-1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv)) + else: + self.body.append(ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv)) + self.body = nn.ModuleList(self.body) + self.conv_in = nn.Conv2d(cin,channels[0], 3, 1, 1) + + def forward(self, x): + # unshuffle + x = self.unshuffle(x) + # extract features + features = [] + x = self.conv_in(x) + for i in range(len(self.channels)): + for j in range(self.nums_rb): + idx = i*self.nums_rb +j + x = self.body[idx](x) + features.append(x) + + return features diff --git a/models/t2i_adapter/put_t2i_adapter_models_here b/models/t2i_adapter/put_t2i_adapter_models_here new file mode 100644 index 00000000..e69de29b diff --git a/nodes.py b/nodes.py index 38d0ce6e..75440be9 100644 --- a/nodes.py +++ b/nodes.py @@ -292,6 +292,22 @@ class ControlNetApply: c.append(n) return (c, ) +class T2IAdapterLoader: + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") + t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter") + @classmethod + def INPUT_TYPES(s): + return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}} + + RETURN_TYPES = ("CONTROL_NET",) + FUNCTION = "load_t2i_adapter" + + CATEGORY = "loaders" + + def load_t2i_adapter(self, t2i_adapter_name): + t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name) + t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path) + return (t2i_adapter,) class CLIPLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") @@ -804,6 +820,7 @@ NODE_CLASS_MAPPINGS = { "ControlNetApply": ControlNetApply, "ControlNetLoader": ControlNetLoader, "DiffControlNetLoader": DiffControlNetLoader, + "T2IAdapterLoader": T2IAdapterLoader, "VAEDecodeTiled": VAEDecodeTiled, }