diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 76af967e6..6b269125c 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -159,7 +159,8 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
+
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
@@ -244,7 +245,7 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
 
-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
         mod, _ = self.modulation(vec)
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index ef4ba4106..5026f6f65 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -6,6 +6,7 @@ import torch
 from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
+import comfy.patcher_extension
 
 from .layers import (
     DoubleStreamBlock,
@@ -130,7 +131,7 @@ class Flux(nn.Module):
                                                    txt=args["txt"],
                                                    vec=args["vec"],
                                                    pe=args["pe"],
-                                                   attn_mask=args.get("attn_mask"))
+                                                   attn_mask=args.get("attn_mask"),transformer_options=transformer_options)
                     return out
 
                 out = blocks_replace[("double_block", i)]({"img": img,
@@ -146,7 +147,7 @@ class Flux(nn.Module):
                                  txt=txt,
                                  vec=vec,
                                  pe=pe,
-                                 attn_mask=attn_mask)
+                                 attn_mask=attn_mask, transformer_options=transformer_options)
 
             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -164,7 +165,7 @@ class Flux(nn.Module):
                     out["img"] = block(args["img"],
                                        vec=args["vec"],
                                        pe=args["pe"],
-                                       attn_mask=args.get("attn_mask"))
+                                       attn_mask=args.get("attn_mask"),transformer_options=transformer_options)
                     return out
 
                 out = blocks_replace[("single_block", i)]({"img": img,
@@ -174,7 +175,7 @@ class Flux(nn.Module):
                                                           {"original_block": block_wrap})
                 img = out["img"]
             else:
-                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
 
             if control is not None: # Controlnet
                 control_o = control.get("output")
@@ -188,7 +189,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -205,3 +206,10 @@ class Flux(nn.Module):
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
+
+    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, y, guidance=guidance, control=control, transformer_options=transformer_options, **kwargs)
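
For reviewers, a minimal sketch of what this change enables: a DIFFUSION_MODEL wrapper that runs around Flux._forward through the WrapperExecutor built in the new forward(). It assumes the comfy.patcher_extension wrapper convention, where each wrapper receives the executor for the rest of the chain plus the original call arguments, and it assumes ModelPatcher.add_wrapper_with_key() as the registration helper; the wrapper name and key below are illustrative only, not part of this diff.

import comfy.patcher_extension
from comfy.patcher_extension import WrappersMP

def log_timestep_wrapper(executor, x, timestep, context, y, *args, **kwargs):
    # Assumed wrapper convention: receive the executor for the remaining chain,
    # do work before/after, and call it with the original arguments.
    print("Flux diffusion model called, timestep:", timestep)
    out = executor(x, timestep, context, y, *args, **kwargs)  # continues into Flux._forward
    return out

def register_on(model_patcher):
    # Hypothetical registration helper: attaches the wrapper so that
    # get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
    # picks it up when Flux.forward builds its WrapperExecutor.
    model_patcher.add_wrapper_with_key(WrappersMP.DIFFUSION_MODEL, "log_timestep", log_timestep_wrapper)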