mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00

Add a way to patch blocks in SD3.

This commit is contained in:
parent 13b0ff8a6f
commit 30c0c81351
@@ -949,7 +949,9 @@ class MMDiT(nn.Module):
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
         if self.register_length > 0:
             context = torch.cat(
                 (
@@ -961,8 +963,19 @@ class MMDiT(nn.Module):
 
         # context is B, L', D
         # x is B, L, D
+        blocks_replace = patches_replace.get("dit", {})
         blocks = len(self.joint_blocks)
         for i in range(blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
+                context = out["txt"]
+                x = out["img"]
+            else:
                 context, x = self.joint_blocks[i](
                     context,
                     x,
@@ -986,6 +999,7 @@ class MMDiT(nn.Module):
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
         """
         Forward pass of DiT.
@@ -1007,7 +1021,7 @@ class MMDiT(nn.Module):
         if context is not None:
             context = self.context_embedder(context)
 
-        x = self.forward_core_with_concat(x, c, context, control)
+        x = self.forward_core_with_concat(x, c, context, control, transformer_options)
 
         x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]
@@ -1021,7 +1035,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y, control=control)
+        return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)
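
For anyone wiring into this hook: a patch registered under patches_replace["dit"][("double_block", i)] is called with a dict of the current tensors ({"img": x, "txt": context, "vec": c_mod}) and a second dict whose "original_block" entry runs the unpatched joint block; it must return a dict with updated "txt" and "img". Below is a minimal, self-contained sketch of that calling convention. The identity_block stand-in, the double_block_patch name, and the dummy tensor shapes are illustrative assumptions, not part of this commit.

# Minimal sketch of the patch calling convention introduced above.
# identity_block stands in for self.joint_blocks[i]; all names and
# tensor shapes here are illustrative.
import torch

def identity_block(txt, img, c=None):
    # Stand-in for a real joint block: returns (context, x) unchanged.
    return txt, img

def block_wrap(args):
    # Mirrors the wrapper built inside forward_core_with_concat.
    out = {}
    out["txt"], out["img"] = identity_block(args["txt"], args["img"], c=args["vec"])
    return out

def double_block_patch(args, extra_options):
    # A user patch: run the original block, then post-process its output.
    out = extra_options["original_block"](args)
    out["img"] = out["img"] * 1.0  # hook point: e.g. scale or mix image tokens
    return out

x = torch.randn(1, 16, 64)        # image tokens, B x L x D
context = torch.randn(1, 77, 64)  # text tokens, B x L' x D
c_mod = torch.randn(1, 64)        # modulation vector

# Exactly the dispatch the patched loop performs for a replaced block:
out = double_block_patch({"img": x, "txt": context, "vec": c_mod},
                         {"original_block": block_wrap})
context, x = out["txt"], out["img"]

In practice such a patch would presumably be registered through ComfyUI's ModelPatcher (something like m.set_model_patch_replace(double_block_patch, "dit", "double_block", i) on a cloned model), which is what populates transformer_options["patches_replace"]; that registration path comes from the surrounding codebase and is not shown in this diff.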