From 13b0ff8a6fe0dfc3b32b4219d1aff1ec3e9a323a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 28 Oct 2024 21:58:52 -0400 Subject: [PATCH] Update SD3 code. --- comfy/ldm/modules/diffusionmodules/mmdit.py | 156 ++++++++++++++++---- comfy/model_detection.py | 5 + 2 files changed, 130 insertions(+), 31 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index a160b2f43..43a269fa0 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -1,6 +1,6 @@ import logging import math -from typing import Dict, Optional +from typing import Dict, Optional, List import numpy as np import torch @@ -415,6 +415,7 @@ class DismantledBlock(nn.Module): scale_mod_only: bool = False, swiglu: bool = False, qk_norm: Optional[str] = None, + x_block_self_attn: bool = False, dtype=None, device=None, operations=None, @@ -438,6 +439,24 @@ class DismantledBlock(nn.Module): device=device, operations=operations ) + if x_block_self_attn: + assert not pre_only + assert not scale_mod_only + self.x_block_self_attn = True + self.attn2 = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_mode=attn_mode, + pre_only=False, + qk_norm=qk_norm, + rmsnorm=rmsnorm, + dtype=dtype, + device=device, + operations=operations + ) + else: + self.x_block_self_attn = False if not pre_only: if not rmsnorm: self.norm2 = operations.LayerNorm( @@ -464,7 +483,11 @@ class DismantledBlock(nn.Module): multiple_of=256, ) self.scale_mod_only = scale_mod_only - if not scale_mod_only: + if x_block_self_attn: + assert not pre_only + assert not scale_mod_only + n_mods = 9 + elif not scale_mod_only: n_mods = 6 if not pre_only else 2 else: n_mods = 4 if not pre_only else 1 @@ -525,14 +548,64 @@ class DismantledBlock(nn.Module): ) return x + def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert self.x_block_self_attn + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + shift_msa2, + scale_msa2, + gate_msa2, + ) = self.adaLN_modulation(c).chunk(9, dim=1) + x_norm = self.norm1(x) + qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa)) + qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2)) + return qkv, qkv2, ( + x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + gate_msa2, + ) + + def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2): + assert not self.pre_only + attn1 = self.attn.post_attention(attn) + attn2 = self.attn2.post_attention(attn2) + out1 = gate_msa.unsqueeze(1) * attn1 + out2 = gate_msa2.unsqueeze(1) * attn2 + x = x + out1 + x = x + out2 + x = x + gate_mlp.unsqueeze(1) * self.mlp( + modulate(self.norm2(x), shift_mlp, scale_mlp) + ) + return x + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: assert not self.pre_only - qkv, intermediates = self.pre_attention(x, c) - attn = optimized_attention( - qkv[0], qkv[1], qkv[2], - heads=self.attn.num_heads, - ) - return self.post_attention(attn, *intermediates) + if self.x_block_self_attn: + qkv, qkv2, intermediates = self.pre_attention_x(x, c) + attn, _ = optimized_attention( + qkv[0], qkv[1], qkv[2], + num_heads=self.attn.num_heads, + ) + attn2, _ = optimized_attention( + qkv2[0], qkv2[1], qkv2[2], + num_heads=self.attn2.num_heads, + ) + return self.post_attention_x(attn, attn2, *intermediates) + else: + qkv, intermediates = self.pre_attention(x, c) + attn = optimized_attention( + qkv[0], 
qkv[1], qkv[2], + heads=self.attn.num_heads, + ) + return self.post_attention(attn, *intermediates) def block_mixing(*args, use_checkpoint=True, **kwargs): @@ -547,7 +620,10 @@ def block_mixing(*args, use_checkpoint=True, **kwargs): def _block_mixing(context, x, context_block, x_block, c): context_qkv, context_intermediates = context_block.pre_attention(context, c) - x_qkv, x_intermediates = x_block.pre_attention(x, c) + if x_block.x_block_self_attn: + x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c) + else: + x_qkv, x_intermediates = x_block.pre_attention(x, c) o = [] for t in range(3): @@ -568,7 +644,14 @@ def _block_mixing(context, x, context_block, x_block, c): else: context = None - x = x_block.post_attention(x_attn, *x_intermediates) + if x_block.x_block_self_attn: + attn2 = optimized_attention( + x_qkv2[0], x_qkv2[1], x_qkv2[2], + heads=x_block.attn2.num_heads, + ) + x = x_block.post_attention_x(x_attn, attn2, *x_intermediates) + else: + x = x_block.post_attention(x_attn, *x_intermediates) return context, x @@ -583,8 +666,13 @@ class JointBlock(nn.Module): super().__init__() pre_only = kwargs.pop("pre_only") qk_norm = kwargs.pop("qk_norm", None) + x_block_self_attn = kwargs.pop("x_block_self_attn", False) self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs) - self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs) + self.x_block = DismantledBlock(*args, + pre_only=False, + qk_norm=qk_norm, + x_block_self_attn=x_block_self_attn, + **kwargs) def forward(self, *args, **kwargs): return block_mixing( @@ -699,9 +787,12 @@ class MMDiT(nn.Module): qk_norm: Optional[str] = None, qkv_bias: bool = True, context_processor_layers = None, + x_block_self_attn: bool = False, + x_block_self_attn_layers: Optional[List[int]] = [], context_size = 4096, num_blocks = None, final_layer = True, + skip_blocks = False, dtype = None, #TODO device = None, operations = None, @@ -716,6 +807,7 @@ class MMDiT(nn.Module): self.pos_embed_scaling_factor = pos_embed_scaling_factor self.pos_embed_offset = pos_embed_offset self.pos_embed_max_size = pos_embed_max_size + self.x_block_self_attn_layers = x_block_self_attn_layers # hidden_size = default(hidden_size, 64 * depth) # num_heads = default(num_heads, hidden_size // 64) @@ -773,26 +865,28 @@ class MMDiT(nn.Module): self.pos_embed = None self.use_checkpoint = use_checkpoint - self.joint_blocks = nn.ModuleList( - [ - JointBlock( - self.hidden_size, - num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - attn_mode=attn_mode, - pre_only=(i == num_blocks - 1) and final_layer, - rmsnorm=rmsnorm, - scale_mod_only=scale_mod_only, - swiglu=swiglu, - qk_norm=qk_norm, - dtype=dtype, - device=device, - operations=operations - ) - for i in range(num_blocks) - ] - ) + if not skip_blocks: + self.joint_blocks = nn.ModuleList( + [ + JointBlock( + self.hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + attn_mode=attn_mode, + pre_only=(i == num_blocks - 1) and final_layer, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn, + dtype=dtype, + device=device, + operations=operations, + ) + for i in range(num_blocks) + ] + ) if final_layer: self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 5d2abe1bf..8435de3ec 100644 --- 
a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix): context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix) if context_processor in state_dict_keys: unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.') + unet_config["x_block_self_attn_layers"] = [] + for key in state_dict_keys: + if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'): + layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')] + unet_config["x_block_self_attn_layers"].append(int(layer)) return unet_config if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
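
For readers following the new x_block_self_attn path in the mmdit.py hunks above: the sketch below is a minimal, self-contained re-implementation of the data flow that pre_attention_x / post_attention_x introduce, assuming plain PyTorch scaled-dot-product attention in place of comfy's optimized_attention. The class name ToyDualAttnBlock, its qkv/qkv2/proj/proj2 layers, and the demo at the bottom are illustrative inventions and not part of the patch; only the 9-way adaLN chunk order and the gated residual sums are taken from the diff itself.

# Illustrative sketch only -- mirrors the x_block_self_attn data flow added by this patch,
# not the actual DismantledBlock implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    # Same shift/scale convention as mmdit.py: applied per sample after LayerNorm.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class ToyDualAttnBlock(nn.Module):
    """Two parallel self-attention branches driven by a 9-chunk adaLN modulation."""

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.num_heads = num_heads
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.qkv = nn.Linear(hidden_size, hidden_size * 3)    # stands in for attn
        self.qkv2 = nn.Linear(hidden_size, hidden_size * 3)   # stands in for attn2
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.proj2 = nn.Linear(hidden_size, hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, int(hidden_size * mlp_ratio)),
            nn.GELU(approximate="tanh"),
            nn.Linear(int(hidden_size * mlp_ratio), hidden_size),
        )
        # 9 modulation vectors instead of the usual 6: the extra shift/scale/gate
        # triple (shift_msa2, scale_msa2, gate_msa2) drives the second branch.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 9 * hidden_size))

    def attention(self, qkv_layer, proj_layer, x):
        b, n, c = x.shape
        q, k, v = qkv_layer(x).reshape(b, n, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        out = F.scaled_dot_product_attention(q, k, v)          # (b, heads, n, head_dim)
        return proj_layer(out.transpose(1, 2).reshape(b, n, c))

    def forward(self, x, c):
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp,
         shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation(c).chunk(9, dim=1)
        x_norm = self.norm1(x)
        # Both branches read the same normalized x but get independent modulations,
        # matching pre_attention_x in the patch.
        attn1 = self.attention(self.qkv, self.proj, modulate(x_norm, shift_msa, scale_msa))
        attn2 = self.attention(self.qkv2, self.proj2, modulate(x_norm, shift_msa2, scale_msa2))
        # Gated residual sums, matching post_attention_x.
        x = x + gate_msa.unsqueeze(1) * attn1 + gate_msa2.unsqueeze(1) * attn2
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


if __name__ == "__main__":
    block = ToyDualAttnBlock(hidden_size=64, num_heads=4)
    x = torch.randn(2, 16, 64)   # (batch, tokens, hidden)
    c = torch.randn(2, 64)       # pooled conditioning vector
    print(block(x, c).shape)     # torch.Size([2, 16, 64])

The design point the sketch tries to make visible: the second attention branch shares norm1(x) with the first but receives its own shift/scale/gate triple from the conditioning vector, which is why the modulation MLP grows from 6 chunks to 9 and why the branch can be switched on per block without touching the context stream.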
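The model_detection.py hunk infers which joint blocks carry the extra attention purely from checkpoint keys. The snippet below is a rough, hypothetical equivalent of that loop, using a regex and sorting for readability (the patch itself uses startswith/endswith slicing and leaves the list unsorted); find_x_block_self_attn_layers and the made-up key list are illustrative only.

# Hypothetical stand-in for the detection loop above: any key of the form
# "<prefix>joint_blocks.<i>.x_block.attn2.qkv.weight" marks block <i> as having
# the second self-attention, and the indices end up in
# unet_config["x_block_self_attn_layers"].
import re

def find_x_block_self_attn_layers(state_dict_keys, key_prefix=""):
    pattern = re.compile(
        re.escape(key_prefix) + r"joint_blocks\.(\d+)\.x_block\.attn2\.qkv\.weight$"
    )
    matches = (pattern.match(k) for k in state_dict_keys)
    return sorted(int(m.group(1)) for m in matches if m)

# Made-up key set: 24 joint blocks, with attn2 weights present only in blocks 0..12.
keys = [f"joint_blocks.{i}.x_block.attn.qkv.weight" for i in range(24)]
keys += [f"joint_blocks.{i}.x_block.attn2.qkv.weight" for i in range(13)]
print(find_x_block_self_attn_layers(keys))  # [0, 1, 2, ..., 12]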