From 30c0c81351a14e6820c98ee22c24f3edc9062e55 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 29 Oct 2024 00:48:32 -0400
Subject: [PATCH] Add a way to patch blocks in SD3.

---
 comfy/ldm/modules/diffusionmodules/mmdit.py | 31 +++++++++++++++------
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index 43a269fa..6f8f506c 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -949,7 +949,9 @@ class MMDiT(nn.Module):
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
         if self.register_length > 0:
             context = torch.cat(
                 (
@@ -961,14 +963,25 @@ class MMDiT(nn.Module):
         # context is B, L', D
         # x is B, L, D
 
+        blocks_replace = patches_replace.get("dit", {})
         blocks = len(self.joint_blocks)
         for i in range(blocks):
-            context, x = self.joint_blocks[i](
-                context,
-                x,
-                c=c_mod,
-                use_checkpoint=self.use_checkpoint,
-            )
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
+                context = out["txt"]
+                x = out["img"]
+            else:
+                context, x = self.joint_blocks[i](
+                    context,
+                    x,
+                    c=c_mod,
+                    use_checkpoint=self.use_checkpoint,
+                )
             if control is not None:
                 control_o = control.get("output")
                 if i < len(control_o):
@@ -986,6 +999,7 @@ class MMDiT(nn.Module):
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
         """
         Forward pass of DiT.
@@ -1007,7 +1021,7 @@
         if context is not None:
             context = self.context_embedder(context)
 
-        x = self.forward_core_with_concat(x, c, context, control)
+        x = self.forward_core_with_concat(x, c, context, control, transformer_options)
 
         x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]
@@ -1021,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y, control=control)
+        return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)

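For illustration, a minimal caller-side sketch of registering a replace-patch against the hook this patch adds follows. The function name my_double_block_patch is hypothetical; the "dit" category, the ("double_block", i) key format, and the "img"/"txt"/"vec"/"original_block" dict layout are taken directly from forward_core_with_concat above.

def my_double_block_patch(args, extra):
    # Hypothetical example patch. args holds the block inputs:
    # "img" (image tokens x), "txt" (context tokens) and "vec"
    # (the modulation vector c_mod).
    original_block = extra["original_block"]
    out = original_block(args)  # runs the unpatched self.joint_blocks[i]
    # out["img"] and out["txt"] can be inspected or modified here before
    # the model writes them back to x and context.
    return out

transformer_options = {
    "patches_replace": {
        "dit": {
            ("double_block", 0): my_double_block_patch,  # patch only joint block 0
        },
    },
}

Blocks with no registered key fall through to the original call path in the else branch, and transformer_options defaults to an empty dict everywhere it was added, so callers that do not pass it see unchanged behavior.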