From 33528f31beb0ec3d1c494057d7a997a7fb6c6078 Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Tue, 25 Mar 2025 20:41:23 +0300
Subject: [PATCH 1/4] wrap flux in diffusion model patcher

---
 comfy/ldm/flux/model.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index ef4ba4106..039ca6c8a 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -6,6 +6,7 @@ import torch
 from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
+import comfy.patcher_extension
 
 from .layers import (
     DoubleStreamBlock,
@@ -188,7 +189,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -205,3 +206,10 @@ class Flux(nn.Module):
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
+
+    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self.forward_lame,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, y, guidance=guidance, control=control, transformer_options=transformer_options, **kwargs)

From 855ea059a4358ef26866fb313a50220d0c58e0b0 Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Tue, 25 Mar 2025 20:48:05 +0300
Subject: [PATCH 2/4] pass transformer options into stream blocks

---
 comfy/ldm/flux/model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 039ca6c8a..3638a668a 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -131,7 +131,7 @@ class Flux(nn.Module):
                                                    txt=args["txt"],
                                                    vec=args["vec"],
                                                    pe=args["pe"],
-                                                   attn_mask=args.get("attn_mask"))
+                                                   attn_mask=args.get("attn_mask"),transformer_options=transformer_options)
                     return out
 
                 out = blocks_replace[("double_block", i)]({"img": img,
@@ -147,7 +147,7 @@ class Flux(nn.Module):
                                  txt=txt,
                                  vec=vec,
                                  pe=pe,
-                                 attn_mask=attn_mask)
+                                 attn_mask=attn_mask, transformer_options=transformer_options)
 
             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -165,7 +165,7 @@ class Flux(nn.Module):
                     out["img"] = block(args["img"],
                                        vec=args["vec"],
                                        pe=args["pe"],
-                                       attn_mask=args.get("attn_mask"))
+                                       attn_mask=args.get("attn_mask"),transformer_options=transformer_options)
                     return out
 
                 out = blocks_replace[("single_block", i)]({"img": img,
@@ -175,7 +175,7 @@ class Flux(nn.Module):
                                                           {"original_block": block_wrap})
                 img = out["img"]
             else:
-                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
 
             if control is not None: # Controlnet
                 control_o = control.get("output")
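
Note on PATCH 2/4: after this change, forward_orig hands transformer_options
to every double/single stream block call, including the original_block
closure handed to "dit" replace patches. A minimal sketch of a consumer,
assuming the existing ModelPatcher.set_model_patch_replace API from
comfy.model_patcher; the callback body, `model` variable, and block index
are illustrative, not part of this series:

    def double_block_replace(args, extra):
        # after PATCH 2/4, extra["original_block"] closes over
        # transformer_options and forwards it into DoubleStreamBlock.forward
        out = extra["original_block"](args)
        out["img"] = out["img"] * 1.0  # no-op placeholder for custom logic
        return out

    m = model.clone()  # `model`: hypothetical ModelPatcher from a custom node
    m.set_model_patch_replace(double_block_replace, "dit", "double_block", 0)
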
From 38e97ef2fd1d496c0c24d6f39fa20b561417a9f1 Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Tue, 25 Mar 2025 20:50:50 +0300
Subject: [PATCH 3/4] add transf options argument to stream blocks

---
 comfy/ldm/flux/layers.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 76af967e6..6b269125c 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -159,7 +159,8 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
+
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
@@ -244,7 +245,7 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
 
-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
         mod, _ = self.modulation(vec)
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 

From c9f9935d8a80b71a2ab5dfd33b5bcc6e39bc32be Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Tue, 25 Mar 2025 20:59:27 +0300
Subject: [PATCH 4/4] rename original forward to _forward

---
 comfy/ldm/flux/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 3638a668a..5026f6f65 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -209,7 +209,7 @@ class Flux(nn.Module):
 
     def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
-            self.forward_lame,
+            self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
         ).execute(x, timestep, context, y, guidance=guidance, control=control, transformer_options=transformer_options, **kwargs)
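
Note on the series as a whole: with Flux.forward now routed through
comfy.patcher_extension.WrapperExecutor, an extension can wrap the entire
diffusion-model call by registering a DIFFUSION_MODEL wrapper. A minimal
sketch, assuming the ModelPatcher.add_wrapper helper from comfy.model_patcher
and the comfy.patcher_extension convention that each wrapper receives the
next executor in the chain as its first argument; the logging body and the
`model` variable are illustrative, not part of this series:

    import comfy.patcher_extension

    def diffusion_model_wrapper(executor, x, timestep, context, y, *args, **kwargs):
        # calling the executor continues down the wrapper chain and
        # ultimately reaches the real Flux._forward
        print("flux diffusion model called")
        return executor(x, timestep, context, y, *args, **kwargs)

    m = model.clone()  # `model`: hypothetical ModelPatcher from a custom node
    m.add_wrapper(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL,
                  diffusion_model_wrapper)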