diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py
index a52e0102b..c0b420e79 100644
--- a/comfy/diffusers_load.py
+++ b/comfy/diffusers_load.py
@@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     vae = None
     if output_vae:
-        vae = comfy.sd.VAE(ckpt_path=vae_path)
+        sd = comfy.utils.load_torch_file(vae_path)
+        vae = comfy.sd.VAE(sd=sd)

     return (unet, clip, vae)
diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py
index 1fb7ed879..d2f1d74a9 100644
--- a/comfy/ldm/models/autoencoder.py
+++ b/comfy/ldm/models/autoencoder.py
@@ -2,67 +2,66 @@ import torch
 # import pytorch_lightning as pl
 import torch.nn.functional as F
 from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union

-from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
 from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

 from comfy.ldm.util import instantiate_from_config
 from comfy.ldm.modules.ema import LitEma

-# class AutoencoderKL(pl.LightningModule):
-class AutoencoderKL(torch.nn.Module):
-    def __init__(self,
-                 ddconfig,
-                 lossconfig,
-                 embed_dim,
-                 ckpt_path=None,
-                 ignore_keys=[],
-                 image_key="image",
-                 colorize_nlabels=None,
-                 monitor=None,
-                 ema_decay=None,
-                 learn_logvar=False
-                 ):
+class DiagonalGaussianRegularizer(torch.nn.Module):
+    def __init__(self, sample: bool = True):
         super().__init__()
-        self.learn_logvar = learn_logvar
-        self.image_key = image_key
-        self.encoder = Encoder(**ddconfig)
-        self.decoder = Decoder(**ddconfig)
-        self.loss = instantiate_from_config(lossconfig)
-        assert ddconfig["double_z"]
-        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
-        self.embed_dim = embed_dim
-        if colorize_nlabels is not None:
-            assert type(colorize_nlabels)==int
-            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        self.sample = sample
+
+    def get_trainable_parameters(self) -> Any:
+        yield from ()
+
+    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+        log = dict()
+        posterior = DiagonalGaussianDistribution(z)
+        if self.sample:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        kl_loss = posterior.kl()
+        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+        log["kl_loss"] = kl_loss
+        return z, log
+
+
+class AbstractAutoencoder(torch.nn.Module):
+    """
+    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
+    unCLIP models, etc. Hence, it is fairly general, and specific features
+    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
+    """
+
+    def __init__(
+        self,
+        ema_decay: Union[None, float] = None,
+        monitor: Union[None, str] = None,
+        input_key: str = "jpg",
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.input_key = input_key
+        self.use_ema = ema_decay is not None
         if monitor is not None:
             self.monitor = monitor
-        self.use_ema = ema_decay is not None
         if self.use_ema:
-            self.ema_decay = ema_decay
-            assert 0. < ema_decay < 1.
             self.model_ema = LitEma(self, decay=ema_decay)
-            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

-        if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+    def get_input(self, batch) -> Any:
+        raise NotImplementedError()

-    def init_from_ckpt(self, path, ignore_keys=list()):
-        if path.lower().endswith(".safetensors"):
-            import safetensors.torch
-            sd = safetensors.torch.load_file(path, device="cpu")
-        else:
-            sd = torch.load(path, map_location="cpu")["state_dict"]
-        keys = list(sd.keys())
-        for k in keys:
-            for ik in ignore_keys:
-                if k.startswith(ik):
-                    print("Deleting key {} from state_dict.".format(k))
-                    del sd[k]
-        self.load_state_dict(sd, strict=False)
-        print(f"Restored from {path}")
+    def on_train_batch_end(self, *args, **kwargs):
+        # for EMA computation
+        if self.use_ema:
+            self.model_ema(self)

     @contextmanager
     def ema_scope(self, context=None):
@@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                print(f"{context}: Switched to EMA weights")
+                logpy.info(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    print(f"{context}: Restored training weights")
+                    logpy.info(f"{context}: Restored training weights")

-    def on_train_batch_end(self, *args, **kwargs):
-        if self.use_ema:
-            self.model_ema(self)
+    def encode(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("encode()-method of abstract base class called")

-    def encode(self, x):
-        h = self.encoder(x)
-        moments = self.quant_conv(h)
-        posterior = DiagonalGaussianDistribution(moments)
-        return posterior
+    def decode(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("decode()-method of abstract base class called")

-    def decode(self, z):
-        z = self.post_quant_conv(z)
-        dec = self.decoder(z)
-        return dec
+    def instantiate_optimizer_from_config(self, params, lr, cfg):
+        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
+        return get_obj_from_str(cfg["target"])(
+            params, lr=lr, **cfg.get("params", dict())
+        )

-    def forward(self, input, sample_posterior=True):
-        posterior = self.encode(input)
-        if sample_posterior:
-            z = posterior.sample()
-        else:
-            z = posterior.mode()
-        dec = self.decode(z)
-        return dec, posterior
+    def configure_optimizers(self) -> Any:
+        raise NotImplementedError()

-    def get_input(self, batch, k):
-        x = batch[k]
-        if len(x.shape) == 3:
-            x = x[..., None]
-        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
-        return x

-    def training_step(self, batch, batch_idx, optimizer_idx):
-        inputs = self.get_input(batch, self.image_key)
-        reconstructions, posterior = self(inputs)
+class AutoencodingEngine(AbstractAutoencoder):
+    """
+    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
+    (we also restore them explicitly as special cases for legacy reasons).
+    Regularizations such as KL or VQ are moved to the regularizer class.
+ """ - if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return aeloss + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + regularizer_config: Dict, + **kwargs, + ): + super().__init__(*args, **kwargs) - if optimizer_idx == 1: - # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - - self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, postfix=""): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( - self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) - if self.learn_logvar: - print(f"{self.__class__.__name__}: Learning logvar") - ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] + self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) + self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) + self.regularization: AbstractRegularizer = instantiate_from_config( + regularizer_config + ) def get_last_layer(self): - return self.decoder.conv_out.weight + return self.decoder.get_last_layer() - @torch.no_grad() - def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["samples"] = self.decode(torch.randn_like(posterior.sample())) - log["reconstructions"] = xrec - if log_ema or self.use_ema: - with self.ema_scope(): - xrec_ema, posterior_ema = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec_ema.shape[1] > 3 - xrec_ema = self.to_rgb(xrec_ema) - log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) - log["reconstructions_ema"] = xrec_ema - log["inputs"] = x - return log + def encode( + 
+        self,
+        x: torch.Tensor,
+        return_reg_log: bool = False,
+        unregularized: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+        z = self.encoder(x)
+        if unregularized:
+            return z, dict()
+        z, reg_log = self.regularization(z)
+        if return_reg_log:
+            return z, reg_log
+        return z

-    def to_rgb(self, x):
-        assert self.image_key == "segmentation"
-        if not hasattr(self, "colorize"):
-            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
-        x = F.conv2d(x, weight=self.colorize)
-        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
+        x = self.decoder(z, **kwargs)
         return x

+    def forward(
+        self, x: torch.Tensor, **additional_decode_kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+        z, reg_log = self.encode(x, return_reg_log=True)
+        dec = self.decode(z, **additional_decode_kwargs)
+        return z, dec, reg_log

-class IdentityFirstStage(torch.nn.Module):
-    def __init__(self, *args, vq_interface=False, **kwargs):
-        self.vq_interface = vq_interface
-        super().__init__()

-    def encode(self, x, *args, **kwargs):
-        return x
+class AutoencodingEngineLegacy(AutoencodingEngine):
+    def __init__(self, embed_dim: int, **kwargs):
+        self.max_batch_size = kwargs.pop("max_batch_size", None)
+        ddconfig = kwargs.pop("ddconfig")
+        super().__init__(
+            encoder_config={
+                "target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
+                "params": ddconfig,
+            },
+            decoder_config={
+                "target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
+                "params": ddconfig,
+            },
+            **kwargs,
+        )
+        self.quant_conv = torch.nn.Conv2d(
+            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
+            (1 + ddconfig["double_z"]) * embed_dim,
+            1,
+        )
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim

-    def decode(self, x, *args, **kwargs):
-        return x
+    def get_autoencoder_params(self) -> list:
+        params = super().get_autoencoder_params()
+        return params

-    def quantize(self, x, *args, **kwargs):
-        if self.vq_interface:
-            return x, None, [None, None, None]
-        return x
+    def encode(
+        self, x: torch.Tensor, return_reg_log: bool = False
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+        if self.max_batch_size is None:
+            z = self.encoder(x)
+            z = self.quant_conv(z)
+        else:
+            N = x.shape[0]
+            bs = self.max_batch_size
+            n_batches = int(math.ceil(N / bs))
+            z = list()
+            for i_batch in range(n_batches):
+                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
+                z_batch = self.quant_conv(z_batch)
+                z.append(z_batch)
+            z = torch.cat(z, 0)

-    def forward(self, x, *args, **kwargs):
-        return x
+        z, reg_log = self.regularization(z)
+        if return_reg_log:
+            return z, reg_log
+        return z

+    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
+        if self.max_batch_size is None:
+            dec = self.post_quant_conv(z)
+            dec = self.decoder(dec, **decoder_kwargs)
+        else:
+            N = z.shape[0]
+            bs = self.max_batch_size
+            n_batches = int(math.ceil(N / bs))
+            dec = list()
+            for i_batch in range(n_batches):
+                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
+                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
+                dec.append(dec_batch)
+            dec = torch.cat(dec, 0)
+
+        return dec
+
+
+class AutoencoderKL(AutoencodingEngineLegacy):
+    def __init__(self, **kwargs):
+        if "lossconfig" in kwargs:
+            kwargs["loss_config"] = kwargs.pop("lossconfig")
+        super().__init__(
+            regularizer_config={
+                "target": (
+                    "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
+                )
+            },
+            **kwargs,
+        )
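Note (illustrative, not part of the patch): with the regularizer moved into the autoencoder, `first_stage_model.encode()` now returns a latent tensor directly instead of a `DiagonalGaussianDistribution`, which is why the `.sample()` calls disappear from comfy/sd.py further down. A minimal sketch of the new contract, assuming the module layout shown in this patch:

    import torch
    from comfy.ldm.models.autoencoder import AutoencoderKL

    # Default SD1.x/SD2.x VAE parameters, taken from comfy/sd.py below.
    ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3,
                'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2,
                'attn_resolutions': [], 'dropout': 0.0}
    model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4).eval()

    with torch.no_grad():
        x = torch.zeros(1, 3, 64, 64)                        # dummy image batch in [-1, 1]
        z = model.encode(x)                                  # latent tensor (1, 4, 8, 8); no .sample() needed
        z, reg_log = model.encode(x, return_reg_log=True)    # reg_log carries "kl_loss" from the regularizer
        recon = model.decode(z)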
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 852d367c9..f23417fd2 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -541,7 +541,10 @@ class Decoder(nn.Module):
     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                  resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
-                 attn_type="vanilla", **ignorekwargs):
+                 conv_out_op=comfy.ops.Conv2d,
+                 resnet_op=ResnetBlock,
+                 attn_op=AttnBlock,
+                 **ignorekwargs):
         super().__init__()
         if use_linear_attn: attn_type = "linear"
         self.ch = ch
@@ -570,12 +573,12 @@ class Decoder(nn.Module):

         # middle
         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+        self.mid.block_1 = resnet_op(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
                                        dropout=dropout)
-        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
-        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+        self.mid.attn_1 = attn_op(block_in)
+        self.mid.block_2 = resnet_op(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
                                        dropout=dropout)
@@ -587,13 +590,13 @@
             attn = nn.ModuleList()
             block_out = ch*ch_mult[i_level]
             for i_block in range(self.num_res_blocks+1):
-                block.append(ResnetBlock(in_channels=block_in,
+                block.append(resnet_op(in_channels=block_in,
                                          out_channels=block_out,
                                          temb_channels=self.temb_ch,
                                          dropout=dropout))
                 block_in = block_out
                 if curr_res in attn_resolutions:
-                    attn.append(make_attn(block_in, attn_type=attn_type))
+                    attn.append(attn_op(block_in))
             up = nn.Module()
             up.block = block
             up.attn = attn
@@ -604,13 +607,13 @@

         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = comfy.ops.Conv2d(block_in,
+        self.conv_out = conv_out_op(block_in,
                                         out_ch,
                                         kernel_size=3,
                                         stride=1,
                                         padding=1)

-    def forward(self, z):
+    def forward(self, z, **kwargs):
         #assert z.shape[1:] == self.z_shape[1:]
         self.last_z_shape = z.shape
@@ -621,16 +624,16 @@
         h = self.conv_in(z)

         # middle
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
+        h = self.mid.block_1(h, temb, **kwargs)
+        h = self.mid.attn_1(h, **kwargs)
+        h = self.mid.block_2(h, temb, **kwargs)

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(self.num_res_blocks+1):
-                h = self.up[i_level].block[i_block](h, temb)
+                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                 if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h)
+                    h = self.up[i_level].attn[i_block](h, **kwargs)
             if i_level != 0:
                 h = self.up[i_level].upsample(h)
@@ -640,7 +643,7 @@
         h = self.norm_out(h)
         h = nonlinearity(h)
-        h = self.conv_out(h)
+        h = self.conv_out(h, **kwargs)
         if self.tanh_out:
             h = torch.tanh(h)
         return h
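Note (illustrative, not part of the patch): the Decoder now takes the block classes it builds from as constructor arguments and threads forward() kwargs down to every block, so a caller can swap in custom ops without subclassing. A hedged sketch, assuming the classes named in the hunk above are importable from the same module:

    import comfy.ops
    from comfy.ldm.modules.diffusionmodules.model import Decoder, ResnetBlock, AttnBlock

    # The defaults are spelled out explicitly here; a caller could pass its own
    # conv/resnet/attention implementations with compatible signatures instead.
    decoder = Decoder(
        ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=[],
        in_channels=3, resolution=256, z_channels=4,
        conv_out_op=comfy.ops.Conv2d,
        resnet_op=ResnetBlock,
        attn_op=AttnBlock,
    )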
diff --git a/comfy/sd.py b/comfy/sd.py
index fd8b94df8..48ee5721b 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -4,7 +4,7 @@ import math
 from comfy import model_management
 from .ldm.util import instantiate_from_config
-from .ldm.models.autoencoder import AutoencoderKL
+from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 import yaml
 import comfy.utils
@@ -140,21 +140,24 @@ class CLIP:
         return self.patcher.get_key_patches()

 class VAE:
-    def __init__(self, ckpt_path=None, device=None, config=None):
+    def __init__(self, sd=None, device=None, config=None):
+        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
+            sd = diffusers_convert.convert_vae_state_dict(sd)
+
         if config is None:
             #default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-            self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
+            self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
             self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()
-        if ckpt_path is not None:
-            sd = comfy.utils.load_torch_file(ckpt_path)
-            if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
-                sd = diffusers_convert.convert_vae_state_dict(sd)
-            m, u = self.first_stage_model.load_state_dict(sd, strict=False)
-            if len(m) > 0:
-                print("Missing VAE keys", m)
+
+        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
+        if len(m) > 0:
+            print("Missing VAE keys", m)
+
+        if len(u) > 0:
+            print("Leftover VAE keys", u)

         if device is None:
             device = model_management.vae_device()
@@ -183,7 +186,7 @@
         steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = comfy.utils.ProgressBar(steps)

-        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float()
+        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
         samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@@ -229,7 +232,7 @@
             samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float()
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()

         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@@ -375,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         model.load_model_weights(state_dict, "model.diffusion_model.")

     if output_vae:
-        w = WeightsLoader()
-        vae = VAE(config=vae_config)
-        w.first_stage_model = vae.first_stage_model
-        load_model_weights(w, state_dict)
+        vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
+        vae = VAE(sd=vae_sd, config=vae_config)

     if output_clip:
         w = WeightsLoader()
@@ -427,10 +428,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
-        vae = VAE()
-        w = WeightsLoader()
-        w.first_stage_model = vae.first_stage_model
-        load_model_weights(w, sd)
+        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
+        vae = VAE(sd=vae_sd)

     if output_clip:
         w = WeightsLoader()
diff --git a/comfy/utils.py b/comfy/utils.py
index df016ef9e..a1807aa1d 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace):
         state_dict[keys_to_replace[x]] = state_dict.pop(x)
     return state_dict

-def state_dict_prefix_replace(state_dict, replace_prefix):
+def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
+    if filter_keys:
+        out = {}
+    else:
+        out = state_dict
     for rp in replace_prefix:
         replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
         for x in replace:
-            state_dict[x[1]] = state_dict.pop(x[0])
-    return state_dict
+            w = state_dict.pop(x[0])
+            out[x[1]] = w
+    return out

 def transformers_convert(sd, prefix_from, prefix_to, number):
diff --git a/nodes.py b/nodes.py
index d0e9ffcc5..0dbc2be32 100644
--- a/nodes.py
+++ b/nodes.py
@@ -584,7 +584,8 @@ class VAELoader:
     #TODO: scale factor?
     def load_vae(self, vae_name):
         vae_path = folder_paths.get_full_path("vae", vae_name)
-        vae = comfy.sd.VAE(ckpt_path=vae_path)
+        sd = comfy.utils.load_torch_file(vae_path)
+        vae = comfy.sd.VAE(sd=sd)
         return (vae,)

 class ControlNetLoader:
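Note (illustrative, not part of the patch): callers of comfy.sd.VAE now load the state dict themselves and pass it in, and the checkpoint loaders slice out the VAE weights with the new filter_keys flag. A hedged sketch of both, with a hypothetical file path:

    import comfy.sd
    import comfy.utils

    # New calling convention (was: comfy.sd.VAE(ckpt_path=...)).
    sd = comfy.utils.load_torch_file("/path/to/vae.safetensors")   # hypothetical path
    vae = comfy.sd.VAE(sd=sd)

    # filter_keys=False (default): keys are renamed in place and unmatched keys are kept.
    # filter_keys=True: only keys that matched a prefix are returned, which is how the
    # checkpoint loaders above extract the "first_stage_model." weights for the VAE.
    full_sd = {"first_stage_model.decoder.x": 1, "model.diffusion_model.y": 2}
    vae_sd = comfy.utils.state_dict_prefix_replace(full_sd, {"first_stage_model.": ""}, filter_keys=True)
    # vae_sd == {"decoder.x": 1}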