Update SD3 code.

comfyanonymous 2024-10-28 21:58:52 -04:00
parent c320801187
commit 13b0ff8a6f
2 changed files with 130 additions and 31 deletions

comfy/ldm/modules/diffusionmodules/mmdit.py

@@ -1,6 +1,6 @@
 import logging
 import math
-from typing import Dict, Optional
+from typing import Dict, Optional, List
 
 import numpy as np
 import torch
@@ -415,6 +415,7 @@ class DismantledBlock(nn.Module):
         scale_mod_only: bool = False,
         swiglu: bool = False,
         qk_norm: Optional[str] = None,
+        x_block_self_attn: bool = False,
         dtype=None,
         device=None,
         operations=None,
@@ -438,6 +439,24 @@ class DismantledBlock(nn.Module):
             device=device,
             operations=operations
         )
+        if x_block_self_attn:
+            assert not pre_only
+            assert not scale_mod_only
+            self.x_block_self_attn = True
+            self.attn2 = SelfAttention(
+                dim=hidden_size,
+                num_heads=num_heads,
+                qkv_bias=qkv_bias,
+                attn_mode=attn_mode,
+                pre_only=False,
+                qk_norm=qk_norm,
+                rmsnorm=rmsnorm,
+                dtype=dtype,
+                device=device,
+                operations=operations
+            )
+        else:
+            self.x_block_self_attn = False
         if not pre_only:
             if not rmsnorm:
                 self.norm2 = operations.LayerNorm(
@@ -464,7 +483,11 @@ class DismantledBlock(nn.Module):
                 multiple_of=256,
             )
         self.scale_mod_only = scale_mod_only
-        if not scale_mod_only:
+        if x_block_self_attn:
+            assert not pre_only
+            assert not scale_mod_only
+            n_mods = 9
+        elif not scale_mod_only:
             n_mods = 6 if not pre_only else 2
         else:
             n_mods = 4 if not pre_only else 1
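
With the extra x-only attention enabled, the adaLN projection has to produce three shift/scale/gate triples (joint attention, MLP, and the second attention), hence n_mods = 9. Below is a minimal, self-contained sketch of that sizing and of the chunk order used by pre_attention_x further down; plain nn.Linear stands in for the operations.Linear the file actually uses, and the sizes are illustrative, not taken from a real checkpoint.

import torch
import torch.nn as nn

hidden_size, n_mods = 64, 9
# assumed layout of the adaLN projection: SiLU followed by a Linear to n_mods * hidden_size
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size))

c = torch.randn(2, hidden_size)                       # conditioning vector per sample
(shift_msa, scale_msa, gate_msa,
 shift_mlp, scale_mlp, gate_mlp,
 shift_msa2, scale_msa2, gate_msa2) = adaLN_modulation(c).chunk(n_mods, dim=1)
print(shift_msa.shape)                                # torch.Size([2, 64])
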
@@ -525,8 +548,58 @@ class DismantledBlock(nn.Module):
         )
         return x
 
+    def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        assert self.x_block_self_attn
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            shift_msa2,
+            scale_msa2,
+            gate_msa2,
+        ) = self.adaLN_modulation(c).chunk(9, dim=1)
+        x_norm = self.norm1(x)
+        qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
+        qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
+        return qkv, qkv2, (
+            x,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            gate_msa2,
+        )
+
+    def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
+        assert not self.pre_only
+        attn1 = self.attn.post_attention(attn)
+        attn2 = self.attn2.post_attention(attn2)
+        out1 = gate_msa.unsqueeze(1) * attn1
+        out2 = gate_msa2.unsqueeze(1) * attn2
+        x = x + out1
+        x = x + out2
+        x = x + gate_mlp.unsqueeze(1) * self.mlp(
+            modulate(self.norm2(x), shift_mlp, scale_mlp)
+        )
+        return x
+
     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         assert not self.pre_only
+        if self.x_block_self_attn:
+            qkv, qkv2, intermediates = self.pre_attention_x(x, c)
+            attn, _ = optimized_attention(
+                qkv[0], qkv[1], qkv[2],
+                num_heads=self.attn.num_heads,
+            )
+            attn2, _ = optimized_attention(
+                qkv2[0], qkv2[1], qkv2[2],
+                num_heads=self.attn2.num_heads,
+            )
+            return self.post_attention_x(attn, attn2, *intermediates)
+        else:
             qkv, intermediates = self.pre_attention(x, c)
             attn = optimized_attention(
                 qkv[0], qkv[1], qkv[2],
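
pre_attention_x chunks the conditioning into nine modulation tensors and builds two qkv sets from the same normalized x; post_attention_x then adds both gated attention outputs and the gated MLP branch back onto the residual stream. A minimal sketch of that residual arithmetic with random stand-ins for the attention and MLP outputs; modulate is assumed to match the file's DiT-style helper.

import torch

def modulate(x, shift, scale):
    # assumed to mirror the file's helper: scale and shift broadcast over the token axis
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

B, N, C = 2, 16, 64
x = torch.randn(B, N, C)
gate_msa, gate_msa2, gate_mlp = (torch.randn(B, C) for _ in range(3))
shift_mlp, scale_mlp = torch.randn(B, C), torch.randn(B, C)
attn_out, attn2_out = torch.randn(B, N, C), torch.randn(B, N, C)
mlp = torch.nn.Identity()          # stand-in for the block's MLP
norm2 = torch.nn.LayerNorm(C)      # stand-in for the block's second norm

x = x + gate_msa.unsqueeze(1) * attn_out     # gated joint-attention branch
x = x + gate_msa2.unsqueeze(1) * attn2_out   # gated x-only attention branch
x = x + gate_mlp.unsqueeze(1) * mlp(modulate(norm2(x), shift_mlp, scale_mlp))
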
@@ -547,6 +620,9 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
 def _block_mixing(context, x, context_block, x_block, c):
     context_qkv, context_intermediates = context_block.pre_attention(context, c)
 
+    if x_block.x_block_self_attn:
+        x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
+    else:
         x_qkv, x_intermediates = x_block.pre_attention(x, c)
 
     o = []
@@ -568,6 +644,13 @@ def _block_mixing(context, x, context_block, x_block, c):
     else:
         context = None
 
+    if x_block.x_block_self_attn:
+        attn2 = optimized_attention(
+            x_qkv2[0], x_qkv2[1], x_qkv2[2],
+            heads=x_block.attn2.num_heads,
+        )
+        x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
+    else:
         x = x_block.post_attention(x_attn, *x_intermediates)
 
     return context, x
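
In _block_mixing the joint attention still runs over the concatenated context and image tokens and is split back afterwards; when x_block_self_attn is set, a second attention is computed over the image tokens alone and both results are folded in by post_attention_x. A rough, self-contained sketch of the two paths under assumed [batch, heads, tokens, head_dim] shapes; scaled_dot_product_attention stands in for optimized_attention, which uses a different tensor layout.

import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

B, H, Nc, Nx, D = 1, 4, 8, 16, 32            # batch, heads, context tokens, image tokens, head dim
ctx_q, ctx_k, ctx_v = (torch.randn(B, H, Nc, D) for _ in range(3))
x_q, x_k, x_v = (torch.randn(B, H, Nx, D) for _ in range(3))
x_q2, x_k2, x_v2 = (torch.randn(B, H, Nx, D) for _ in range(3))   # second qkv from pre_attention_x

# joint path: context and image tokens attend together, then the result is split back
q = torch.cat([ctx_q, x_q], dim=2)
k = torch.cat([ctx_k, x_k], dim=2)
v = torch.cat([ctx_v, x_v], dim=2)
joint = sdpa(q, k, v)
context_attn, x_attn = joint[:, :, :Nc], joint[:, :, Nc:]

# extra path: self-attention restricted to the image tokens
x_attn2 = sdpa(x_q2, x_k2, x_v2)
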
@@ -583,8 +666,13 @@ class JointBlock(nn.Module):
         super().__init__()
         pre_only = kwargs.pop("pre_only")
         qk_norm = kwargs.pop("qk_norm", None)
+        x_block_self_attn = kwargs.pop("x_block_self_attn", False)
         self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
-        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
+        self.x_block = DismantledBlock(*args,
+                                       pre_only=False,
+                                       qk_norm=qk_norm,
+                                       x_block_self_attn=x_block_self_attn,
+                                       **kwargs)
 
     def forward(self, *args, **kwargs):
         return block_mixing(
@@ -699,9 +787,12 @@ class MMDiT(nn.Module):
         qk_norm: Optional[str] = None,
         qkv_bias: bool = True,
         context_processor_layers = None,
+        x_block_self_attn: bool = False,
+        x_block_self_attn_layers: Optional[List[int]] = [],
         context_size = 4096,
         num_blocks = None,
         final_layer = True,
+        skip_blocks = False,
         dtype = None, #TODO
         device = None,
         operations = None,
@@ -716,6 +807,7 @@ class MMDiT(nn.Module):
         self.pos_embed_scaling_factor = pos_embed_scaling_factor
         self.pos_embed_offset = pos_embed_offset
         self.pos_embed_max_size = pos_embed_max_size
+        self.x_block_self_attn_layers = x_block_self_attn_layers
 
         # hidden_size = default(hidden_size, 64 * depth)
         # num_heads = default(num_heads, hidden_size // 64)
@@ -773,6 +865,7 @@ class MMDiT(nn.Module):
             self.pos_embed = None
 
         self.use_checkpoint = use_checkpoint
+        if not skip_blocks:
             self.joint_blocks = nn.ModuleList(
                 [
                     JointBlock(
@@ -786,9 +879,10 @@ class MMDiT(nn.Module):
                         scale_mod_only=scale_mod_only,
                         swiglu=swiglu,
                         qk_norm=qk_norm,
+                        x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
                         dtype=dtype,
                         device=device,
-                        operations=operations
+                        operations=operations,
                     )
                     for i in range(num_blocks)
                 ]
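
The per-block flag comes from either the global x_block_self_attn switch or the new per-layer list, so only the listed joint blocks get the second attention. A tiny illustration with made-up values:

x_block_self_attn = False                 # global switch
x_block_self_attn_layers = [0, 1, 2]      # hypothetical per-checkpoint list
num_blocks = 6

flags = [(i in x_block_self_attn_layers) or x_block_self_attn for i in range(num_blocks)]
print(flags)   # [True, True, True, False, False, False]
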

comfy/model_detection.py

@@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix):
         context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
         if context_processor in state_dict_keys:
             unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
+        unet_config["x_block_self_attn_layers"] = []
+        for key in state_dict_keys:
+            if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
+                layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
+                unet_config["x_block_self_attn_layers"].append(int(layer))
         return unet_config
 
     if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
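
The detection loop above infers which joint blocks carry the extra attention purely from the checkpoint's key names: a block index goes into x_block_self_attn_layers whenever a .x_block.attn2.qkv.weight key exists for it. A standalone sketch of the same logic on a toy key list, assuming an empty key_prefix:

state_dict_keys = [
    'joint_blocks.0.x_block.attn.qkv.weight',
    'joint_blocks.0.x_block.attn2.qkv.weight',   # block 0 has the extra attention
    'joint_blocks.1.x_block.attn.qkv.weight',    # block 1 does not
]
key_prefix = ''

layers = []
for key in state_dict_keys:
    if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
        layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
        layers.append(int(layer))

print(layers)   # [0]
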