From 7aceb9f91c1c2b860c1a65ac93a64b3bad575794 Mon Sep 17 00:00:00 2001
From: FeepingCreature <540727+FeepingCreature@users.noreply.github.com>
Date: Fri, 14 Mar 2025 08:22:41 +0100
Subject: [PATCH 1/8] Add --use-flash-attention flag. (#7223)

* Add --use-flash-attention flag.

This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention.
---
 comfy/cli_args.py              |  1 +
 comfy/ldm/modules/attention.py | 60 ++++++++++++++++++++++++++++++++++
 comfy/model_management.py      |  3 ++
 3 files changed, 64 insertions(+)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index a864205b..91c1fe70 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 2758f950..3e5089a6 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
         logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         exit(-1)

+if model_management.flash_attention_enabled():
+    try:
+        from flash_attn import flash_attn_func
+    except ModuleNotFoundError:
+        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
+        exit(-1)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -496,6 +503,56 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     return out


+@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                       dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+
+
+@flash_attn_wrapper.register_fake
+def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+    # Output shape is the same as q
+    return q.new_empty(q.shape)
+
+
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    try:
+        assert mask is None
+        out = flash_attn_wrapper(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False,
+        ).transpose(1, 2)
+    except Exception as e:
+        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    if not skip_output_reshape:
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    return out
+
+
 optimized_attention = attention_basic

 if model_management.sage_attention_enabled():
@@ -504,6 +561,9 @@ if model_management.sage_attention_enabled():
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
+elif model_management.flash_attention_enabled():
+    logging.info("Using Flash Attention")
+    optimized_attention = attention_flash
 elif model_management.pytorch_attention_enabled():
     logging.info("Using pytorch attention")
     optimized_attention = attention_pytorch
diff --git a/comfy/model_management.py b/comfy/model_management.py
index b6f4e2d1..2a9b022b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -930,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
 def sage_attention_enabled():
     return args.use_sage_attention

+def flash_attention_enabled():
+    return args.use_flash_attention
+
 def xformers_enabled():
     global directml_enabled
     global cpu_state

From 9c98c6358be2c7896de1547490bc87c9ad7a1ecb Mon Sep 17 00:00:00 2001
From: FeepingCreature <540727+FeepingCreature@users.noreply.github.com>
Date: Fri, 14 Mar 2025 14:51:26 +0100
Subject: [PATCH 2/8] Tolerate missing `@torch.library.custom_op` (#7234)

This can happen on Pytorch versions older than 2.4.
---
 comfy/ldm/modules/attention.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 3e5089a6..7908d131 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -503,16 +503,23 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     return out


-@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
-def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                       dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
-    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+try:
+    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)


-@flash_attn_wrapper.register_fake
-def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
-    # Output shape is the same as q
-    return q.new_empty(q.shape)
+    @flash_attn_wrapper.register_fake
+    def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+        # Output shape is the same as q
+        return q.new_empty(q.shape)
+except AttributeError as error:
+    FLASH_ATTN_ERROR = error
+
+    def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                           dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+        assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"


 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):

From 6a0daa79b6a8ed99b6859fb1c143081eef9e7aa0 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 14 Mar 2025 10:55:19 -0400
Subject: [PATCH 3/8] Make the SkipLayerGuidanceDIT node work on WAN.
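
Illustrative sketch, not part of the diff below: the change threads transformer_options through WanModel.forward and routes every block through the optional patches_replace["dit"] hook, which is what SkipLayerGuidanceDIT relies on. A patch callback compatible with the calling convention added here could look like the following; the function name and block index are hypothetical, and in practice the entry is registered through a ModelPatcher rather than built by hand:

    def my_dit_patch(args, extra_args):
        # args carries {"img": x, "txt": context, "vec": e0, "pe": freqs}
        out = extra_args["original_block"](args)  # run the wrapped WAN block unchanged
        # modify out["img"] here, e.g. skip or rescale this block's contribution
        return out

    # hypothetical registration for block 5 of the WAN DiT:
    # transformer_options["patches_replace"]["dit"][("double_block", 5)] = my_dit_patch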
--- comfy/ldm/wan/model.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index e78d846b..9966b20a 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -384,6 +384,7 @@ class WanModel(torch.nn.Module): context, clip_fea=None, freqs=None, + transformer_options={}, ): r""" Forward pass through the diffusion model @@ -429,8 +430,18 @@ class WanModel(torch.nn.Module): freqs=freqs, context=context) - for block in self.blocks: - x = block(x, **kwargs) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context) # head x = self.head(x, e) @@ -439,7 +450,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def forward(self, x, timestep, context, clip_fea=None, **kwargs): + def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) patch_size = self.patch_size @@ -453,7 +464,7 @@ class WanModel(torch.nn.Module): img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) freqs = self.rope_embedder(img_ids).movedim(1, 2) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w] + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): r""" From a2448fc52701651d183e35fbb37924b4441f7a98 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 14 Mar 2025 18:10:37 -0400 Subject: [PATCH 4/8] Remove useless code. 
--- comfy/ldm/wan/model.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 9966b20a..9b5e5332 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -424,12 +424,6 @@ class WanModel(torch.nn.Module): context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = torch.concat([context_clip, context], dim=1) - # arguments - kwargs = dict( - e=e0, - freqs=freqs, - context=context) - patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.blocks): From c624c29d6685377faa298d4151af09e433cea875 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 14 Mar 2025 18:17:26 -0400 Subject: [PATCH 5/8] Update frontend to 1.12.9 (#7236) * Update frontend to 1.12.9 * Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e1316ccf..771e53c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.11.8 +comfyui-frontend-package==1.12.11 torch torchsde torchvision From 7ebd8087ffb9c713d308ff74f1bd14f07d569bed Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 14 Mar 2025 22:38:10 -0700 Subject: [PATCH 6/8] hotfix fe (#7244) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 771e53c2..70689bc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.12.11 +comfyui-frontend-package==1.12.14 torch torchsde torchvision From 3c3988df45826808210b9964dbaf85055f80e695 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Mar 2025 08:26:36 -0400 Subject: [PATCH 7/8] Show a better error message if the VAE is invalid. 
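
Illustrative sketch, not part of the diff below, assuming that constructing a VAE from a state dict with no recognizable VAE weights leaves first_stage_model as None:

    import comfy.sd

    vae = comfy.sd.VAE(sd={})           # nothing detected, so first_stage_model stays None
    vae.throw_exception_if_invalid()    # raises RuntimeError with the explanatory message
    # decode()/encode() run the same check first, so a missing VAE now fails with a clear
    # message instead of a less helpful error deeper in the call stack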
--- comfy/sd.py | 8 ++++++++ nodes.py | 1 + 2 files changed, 9 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index fd98585a..51fe425a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -440,6 +440,10 @@ class VAE: self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + def throw_exception_if_invalid(self): + if self.first_stage_model is None: + raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") + def vae_encode_crop_pixels(self, pixels): downscale_ratio = self.spacial_compression_encode() @@ -495,6 +499,7 @@ class VAE: return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) def decode(self, samples_in): + self.throw_exception_if_invalid() pixel_samples = None try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) @@ -525,6 +530,7 @@ class VAE: return pixel_samples def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used) dims = samples.ndim - 2 @@ -553,6 +559,7 @@ class VAE: return output.movedim(1, -1) def encode(self, pixel_samples): + self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) if self.latent_dim == 3 and pixel_samples.ndim < 5: @@ -585,6 +592,7 @@ class VAE: return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) diff --git a/nodes.py b/nodes.py index 63791e20..71d1b8dd 100644 --- a/nodes.py +++ b/nodes.py @@ -770,6 +770,7 @@ class VAELoader: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) + vae.throw_exception_if_invalid() return (vae,) class ControlNetLoader: From 55a1b09ddc9f81b6406710e69df3ec2eaa4880ac Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Mar 2025 08:27:49 -0400 Subject: [PATCH 8/8] Allow loading diffusion model files with the "Load Checkpoint" node. 
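
Illustrative sketch, not part of the diff below, of what the new fallback returns when the loaded file only contains a diffusion model (the file name is hypothetical):

    import comfy.sd
    import comfy.utils

    sd = comfy.utils.load_torch_file("diffusion_model_only.safetensors")  # hypothetical path
    out = comfy.sd.load_state_dict_guess_config(sd)
    model, clip, vae = out[0], out[1], out[2]
    # clip is None and vae is the placeholder VAE(sd={}), which only exists to raise the
    # error added in the previous patch if something actually tries to use it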
--- comfy/sd.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 51fe425a..3d72a04d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -907,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: - return None + logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") + diffusion_model = load_diffusion_model_state_dict(sd, model_options={}) + if diffusion_model is None: + return None + return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' + unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: