From 3c3988df45826808210b9964dbaf85055f80e695 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Mar 2025 08:26:36 -0400 Subject: [PATCH 1/6] Show a better error message if the VAE is invalid. --- comfy/sd.py | 8 ++++++++ nodes.py | 1 + 2 files changed, 9 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index fd98585a..51fe425a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -440,6 +440,10 @@ class VAE: self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + def throw_exception_if_invalid(self): + if self.first_stage_model is None: + raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") + def vae_encode_crop_pixels(self, pixels): downscale_ratio = self.spacial_compression_encode() @@ -495,6 +499,7 @@ class VAE: return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) def decode(self, samples_in): + self.throw_exception_if_invalid() pixel_samples = None try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) @@ -525,6 +530,7 @@ class VAE: return pixel_samples def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used) dims = samples.ndim - 2 @@ -553,6 +559,7 @@ class VAE: return output.movedim(1, -1) def encode(self, pixel_samples): + self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) if self.latent_dim == 3 and pixel_samples.ndim < 5: @@ -585,6 +592,7 @@ class VAE: return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) diff --git a/nodes.py b/nodes.py index 63791e20..71d1b8dd 100644 --- a/nodes.py +++ b/nodes.py @@ -770,6 +770,7 @@ class VAELoader: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) + vae.throw_exception_if_invalid() return (vae,) class ControlNetLoader: From 55a1b09ddc9f81b6406710e69df3ec2eaa4880ac Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Mar 2025 08:27:49 -0400 Subject: [PATCH 2/6] Allow loading diffusion model files with the "Load Checkpoint" node. --- comfy/sd.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 51fe425a..3d72a04d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -907,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: - return None + logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") + diffusion_model = load_diffusion_model_state_dict(sd, model_options={}) + if diffusion_model is None: + return None + return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' + unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: From fd5297131f81d03966adf3f2250d4502f34a8828 Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Sun, 16 Mar 2025 18:02:25 +0800 Subject: [PATCH 3/6] Guard the edge cases of noise term in er_sde (#7265) --- comfy/k_diffusion/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index a28a30ac..5b8d8000 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1419,6 +1419,6 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None old_denoised_d = denoised_d if s_noise != 0 and sigmas[i + 1] > 0: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt() + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0) old_denoised = denoised return x From 2e24a15905122b4f310ac590265cea83aac96b15 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 16 Mar 2025 05:02:45 -0500 Subject: [PATCH 4/6] Call unpatch_hooks at the start of ModelPatcher.partially_unload (#7253) * Call unpatch_hooks at the start of ModelPatcher.partially_unload * Only call unpatch_hooks in partially_unload if lowvram is possible --- comfy/model_patcher.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e291158c..b7cb12df 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -747,6 +747,7 @@ class ModelPatcher: def partially_unload(self, device_to, memory_to_free=0): with self.use_ejected(): + hooks_unpatched = False memory_freed = 0 patch_counter = 0 unload_list = self._load_list() @@ -770,6 +771,10 @@ class ModelPatcher: move_weight = False break + if not hooks_unpatched: + self.unpatch_hooks() + hooks_unpatched = True + if bk.inplace_update: comfy.utils.copy_to_param(self.model, key, bk.weight) else: From e8e990d6b8b5c813c87d1aeaed3e5110c7aba166 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 16 Mar 2025 06:29:12 -0400 Subject: [PATCH 5/6] Cleanup code. --- comfy/ldm/flux/math.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 36b67931..c0cbd291 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -10,8 +10,8 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: q_shape = q.shape k_shape = k.shape - q = q.float().reshape(*q.shape[:-1], -1, 1, 2) - k = k.float().reshape(*k.shape[:-1], -1, 1, 2) + q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) + k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) @@ -36,8 +36,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) From 6dc7b0bfe3cd44302444f0f34db0e62b86764482 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Mar 2025 05:53:54 -0400 Subject: [PATCH 6/6] Add support for giant dinov2 image encoder. --- comfy/clip_vision.py | 11 +- comfy/image_encoders/dino2.py | 141 ++++++++++++++++++++++++++ comfy/image_encoders/dino2_giant.json | 21 ++++ 3 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 comfy/image_encoders/dino2.py create mode 100644 comfy/image_encoders/dino2_giant.json diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 297b3bca..25baf5ca 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -9,6 +9,7 @@ import comfy.model_patcher import comfy.model_management import comfy.utils import comfy.clip_model +import comfy.image_encoders.dino2 class Output: def __getitem__(self, key): @@ -34,6 +35,11 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) +IMAGE_ENCODERS = { + "clip_vision": comfy.clip_model.CLIPVisionModelProjection, + "dinov2": comfy.image_encoders.dino2.Dinov2Model, +} + class ClipVisionModel(): def __init__(self, json_config): with open(json_config) as f: @@ -42,10 +48,11 @@ class ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) + model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision")) self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) - self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast) + self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) @@ -111,6 +118,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") + elif "embeddings.patch_embeddings.projection.weight" in sd: + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") else: return None diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py new file mode 100644 index 00000000..130ed6fd --- /dev/null +++ b/comfy/image_encoders/dino2.py @@ -0,0 +1,141 @@ +import torch +from comfy.text_encoders.bert import BertAttention +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention_for_device + + +class Dino2AttentionOutput(torch.nn.Module): + def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations): + super().__init__() + self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device) + + def forward(self, x): + return self.dense(x) + + +class Dino2AttentionBlock(torch.nn.Module): + def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): + super().__init__() + self.attention = BertAttention(embed_dim, heads, dtype, device, operations) + self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations) + + def forward(self, x, mask, optimized_attention): + return self.output(self.attention(x, mask, optimized_attention)) + + +class LayerScale(torch.nn.Module): + def __init__(self, dim, dtype, device, operations): + super().__init__() + self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + + def forward(self, x): + return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype) + + +class SwiGLUFFN(torch.nn.Module): + def __init__(self, dim, dtype, device, operations): + super().__init__() + in_features = out_features = dim + hidden_features = int(dim * 4) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype) + self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype) + + def forward(self, x): + x = self.weights_in(x) + x1, x2 = x.chunk(2, dim=-1) + x = torch.nn.functional.silu(x1) * x2 + return self.weights_out(x) + + +class Dino2Block(torch.nn.Module): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations): + super().__init__() + self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) + self.layer_scale1 = LayerScale(dim, dtype, device, operations) + self.layer_scale2 = LayerScale(dim, dtype, device, operations) + self.mlp = SwiGLUFFN(dim, dtype, device, operations) + self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + + def forward(self, x, optimized_attention): + x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention)) + x = x + self.layer_scale2(self.mlp(self.norm2(x))) + return x + + +class Dino2Encoder(torch.nn.Module): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations): + super().__init__() + self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)]) + + def forward(self, x, intermediate_output=None): + optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) + + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layer) + intermediate_output + + intermediate = None + for i, l in enumerate(self.layer): + x = l(x, optimized_attention) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class Dino2PatchEmbeddings(torch.nn.Module): + def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None): + super().__init__() + self.projection = operations.Conv2d( + in_channels=num_channels, + out_channels=dim, + kernel_size=patch_size, + stride=patch_size, + bias=True, + dtype=dtype, + device=device + ) + + def forward(self, pixel_values): + return self.projection(pixel_values).flatten(2).transpose(1, 2) + + +class Dino2Embeddings(torch.nn.Module): + def __init__(self, dim, dtype, device, operations): + super().__init__() + patch_size = 14 + image_size = 518 + + self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations) + self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device)) + self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) + self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + + def forward(self, pixel_values): + x = self.patch_embeddings(pixel_values) + # TODO: mask_token? + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) + return x + + +class Dinov2Model(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + num_layers = config_dict["num_hidden_layers"] + dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + layer_norm_eps = config_dict["layer_norm_eps"] + + self.embeddings = Dino2Embeddings(dim, dtype, device, operations) + self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations) + self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + + def forward(self, pixel_values, attention_mask=None, intermediate_output=None): + x = self.embeddings(pixel_values) + x, i = self.encoder(x, intermediate_output=intermediate_output) + x = self.layernorm(x) + pooled_output = x[:, 0, :] + return x, i, pooled_output, None diff --git a/comfy/image_encoders/dino2_giant.json b/comfy/image_encoders/dino2_giant.json new file mode 100644 index 00000000..f6076a4d --- /dev/null +++ b/comfy/image_encoders/dino2_giant.json @@ -0,0 +1,21 @@ +{ + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 1536, + "image_size": 518, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "model_type": "dinov2", + "num_attention_heads": 24, + "num_channels": 3, + "num_hidden_layers": 40, + "patch_size": 14, + "qkv_bias": true, + "use_swiglu_ffn": true, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] +}