Add a way to patch blocks in SD3.

This commit is contained in:
comfyanonymous 2024-10-29 00:48:32 -04:00
parent 13b0ff8a6f
commit 30c0c81351

View File

@ -949,7 +949,9 @@ class MMDiT(nn.Module):
c_mod: torch.Tensor, c_mod: torch.Tensor,
context: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None,
control = None, control = None,
transformer_options = {},
) -> torch.Tensor: ) -> torch.Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if self.register_length > 0: if self.register_length > 0:
context = torch.cat( context = torch.cat(
( (
@ -961,8 +963,19 @@ class MMDiT(nn.Module):
# context is B, L', D # context is B, L', D
# x is B, L, D # x is B, L, D
blocks_replace = patches_replace.get("dit", {})
blocks = len(self.joint_blocks) blocks = len(self.joint_blocks)
for i in range(blocks): for i in range(blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
context, x = self.joint_blocks[i]( context, x = self.joint_blocks[i](
context, context,
x, x,
@ -986,6 +999,7 @@ class MMDiT(nn.Module):
y: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None,
control = None, control = None,
transformer_options = {},
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Forward pass of DiT. Forward pass of DiT.
@ -1007,7 +1021,7 @@ class MMDiT(nn.Module):
if context is not None: if context is not None:
context = self.context_embedder(context) context = self.context_embedder(context)
x = self.forward_core_with_concat(x, c, context, control) x = self.forward_core_with_concat(x, c, context, control, transformer_options)
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
return x[:,:,:hw[-2],:hw[-1]] return x[:,:,:hw[-2],:hw[-1]]
@ -1021,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
context: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None,
control = None, control = None,
transformer_options = {},
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
return super().forward(x, timesteps, context=context, y=y, control=control) return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)