wrap flux in diffusion model patcher

This commit is contained in:
kabachuha 2025-03-25 20:41:23 +03:00
parent 8edc1f44c1
commit 33528f31be


@@ -6,6 +6,7 @@ import torch
 from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
+import comfy.patcher_extension
 from .layers import (
     DoubleStreamBlock,
@@ -188,7 +189,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -205,3 +206,10 @@
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
+
+    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
+        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+            self._forward,
+            self,
+            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+        ).execute(x, timestep, context, y, guidance=guidance, control=control, transformer_options=transformer_options, **kwargs)
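
With this change, every Flux forward pass can be intercepted by wrappers registered under WrappersMP.DIFFUSION_MODEL in transformer_options. As a minimal sketch of how such a wrapper would be attached from a custom node (assuming upstream ComfyUI's ModelPatcher.add_wrapper_with_key helper and the wrapper(executor, *args, **kwargs) calling convention; the logging wrapper itself is purely illustrative):

    from comfy.patcher_extension import WrappersMP

    # Hypothetical wrapper: inspects the latent, then delegates to the wrapped
    # Flux._forward (or to the next wrapper in the chain) via the executor.
    def log_latent_wrapper(executor, x, timestep, *args, **kwargs):
        print(f"Flux forward: latent shape {tuple(x.shape)}, timestep {timestep}")
        return executor(x, timestep, *args, **kwargs)

    # model is a ModelPatcher (e.g. the MODEL output of a loader node); cloning
    # keeps the wrapper local to this workflow branch. The key this registers
    # under is the same one get_all_wrappers() looks up in the new forward().
    patched = model.clone()
    patched.add_wrapper_with_key(WrappersMP.DIFFUSION_MODEL, "log_latent", log_latent_wrapper)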