mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Update SD3 code.
This commit is contained in:
parent
c320801187
commit
13b0ff8a6f
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -415,6 +415,7 @@ class DismantledBlock(nn.Module):
|
|||||||
scale_mod_only: bool = False,
|
scale_mod_only: bool = False,
|
||||||
swiglu: bool = False,
|
swiglu: bool = False,
|
||||||
qk_norm: Optional[str] = None,
|
qk_norm: Optional[str] = None,
|
||||||
|
x_block_self_attn: bool = False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -438,6 +439,24 @@ class DismantledBlock(nn.Module):
|
|||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
)
|
)
|
||||||
|
if x_block_self_attn:
|
||||||
|
assert not pre_only
|
||||||
|
assert not scale_mod_only
|
||||||
|
self.x_block_self_attn = True
|
||||||
|
self.attn2 = SelfAttention(
|
||||||
|
dim=hidden_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
pre_only=False,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
rmsnorm=rmsnorm,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.x_block_self_attn = False
|
||||||
if not pre_only:
|
if not pre_only:
|
||||||
if not rmsnorm:
|
if not rmsnorm:
|
||||||
self.norm2 = operations.LayerNorm(
|
self.norm2 = operations.LayerNorm(
|
||||||
@ -464,7 +483,11 @@ class DismantledBlock(nn.Module):
|
|||||||
multiple_of=256,
|
multiple_of=256,
|
||||||
)
|
)
|
||||||
self.scale_mod_only = scale_mod_only
|
self.scale_mod_only = scale_mod_only
|
||||||
if not scale_mod_only:
|
if x_block_self_attn:
|
||||||
|
assert not pre_only
|
||||||
|
assert not scale_mod_only
|
||||||
|
n_mods = 9
|
||||||
|
elif not scale_mod_only:
|
||||||
n_mods = 6 if not pre_only else 2
|
n_mods = 6 if not pre_only else 2
|
||||||
else:
|
else:
|
||||||
n_mods = 4 if not pre_only else 1
|
n_mods = 4 if not pre_only else 1
|
||||||
@ -525,8 +548,58 @@ class DismantledBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert self.x_block_self_attn
|
||||||
|
(
|
||||||
|
shift_msa,
|
||||||
|
scale_msa,
|
||||||
|
gate_msa,
|
||||||
|
shift_mlp,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
shift_msa2,
|
||||||
|
scale_msa2,
|
||||||
|
gate_msa2,
|
||||||
|
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||||
|
x_norm = self.norm1(x)
|
||||||
|
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||||
|
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||||
|
return qkv, qkv2, (
|
||||||
|
x,
|
||||||
|
gate_msa,
|
||||||
|
shift_mlp,
|
||||||
|
scale_mlp,
|
||||||
|
gate_mlp,
|
||||||
|
gate_msa2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2):
|
||||||
|
assert not self.pre_only
|
||||||
|
attn1 = self.attn.post_attention(attn)
|
||||||
|
attn2 = self.attn2.post_attention(attn2)
|
||||||
|
out1 = gate_msa.unsqueeze(1) * attn1
|
||||||
|
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||||
|
x = x + out1
|
||||||
|
x = x + out2
|
||||||
|
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||||
|
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
assert not self.pre_only
|
assert not self.pre_only
|
||||||
|
if self.x_block_self_attn:
|
||||||
|
qkv, qkv2, intermediates = self.pre_attention_x(x, c)
|
||||||
|
attn, _ = optimized_attention(
|
||||||
|
qkv[0], qkv[1], qkv[2],
|
||||||
|
num_heads=self.attn.num_heads,
|
||||||
|
)
|
||||||
|
attn2, _ = optimized_attention(
|
||||||
|
qkv2[0], qkv2[1], qkv2[2],
|
||||||
|
num_heads=self.attn2.num_heads,
|
||||||
|
)
|
||||||
|
return self.post_attention_x(attn, attn2, *intermediates)
|
||||||
|
else:
|
||||||
qkv, intermediates = self.pre_attention(x, c)
|
qkv, intermediates = self.pre_attention(x, c)
|
||||||
attn = optimized_attention(
|
attn = optimized_attention(
|
||||||
qkv[0], qkv[1], qkv[2],
|
qkv[0], qkv[1], qkv[2],
|
||||||
@ -547,6 +620,9 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
|
|||||||
def _block_mixing(context, x, context_block, x_block, c):
|
def _block_mixing(context, x, context_block, x_block, c):
|
||||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||||
|
|
||||||
|
if x_block.x_block_self_attn:
|
||||||
|
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||||
|
else:
|
||||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||||
|
|
||||||
o = []
|
o = []
|
||||||
@ -568,6 +644,13 @@ def _block_mixing(context, x, context_block, x_block, c):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
context = None
|
context = None
|
||||||
|
if x_block.x_block_self_attn:
|
||||||
|
attn2 = optimized_attention(
|
||||||
|
x_qkv2[0], x_qkv2[1], x_qkv2[2],
|
||||||
|
heads=x_block.attn2.num_heads,
|
||||||
|
)
|
||||||
|
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||||
|
else:
|
||||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||||
return context, x
|
return context, x
|
||||||
|
|
||||||
@ -583,8 +666,13 @@ class JointBlock(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
pre_only = kwargs.pop("pre_only")
|
pre_only = kwargs.pop("pre_only")
|
||||||
qk_norm = kwargs.pop("qk_norm", None)
|
qk_norm = kwargs.pop("qk_norm", None)
|
||||||
|
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||||
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
|
self.x_block = DismantledBlock(*args,
|
||||||
|
pre_only=False,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
x_block_self_attn=x_block_self_attn,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return block_mixing(
|
return block_mixing(
|
||||||
@ -699,9 +787,12 @@ class MMDiT(nn.Module):
|
|||||||
qk_norm: Optional[str] = None,
|
qk_norm: Optional[str] = None,
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
context_processor_layers = None,
|
context_processor_layers = None,
|
||||||
|
x_block_self_attn: bool = False,
|
||||||
|
x_block_self_attn_layers: Optional[List[int]] = [],
|
||||||
context_size = 4096,
|
context_size = 4096,
|
||||||
num_blocks = None,
|
num_blocks = None,
|
||||||
final_layer = True,
|
final_layer = True,
|
||||||
|
skip_blocks = False,
|
||||||
dtype = None, #TODO
|
dtype = None, #TODO
|
||||||
device = None,
|
device = None,
|
||||||
operations = None,
|
operations = None,
|
||||||
@ -716,6 +807,7 @@ class MMDiT(nn.Module):
|
|||||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||||
self.pos_embed_offset = pos_embed_offset
|
self.pos_embed_offset = pos_embed_offset
|
||||||
self.pos_embed_max_size = pos_embed_max_size
|
self.pos_embed_max_size = pos_embed_max_size
|
||||||
|
self.x_block_self_attn_layers = x_block_self_attn_layers
|
||||||
|
|
||||||
# hidden_size = default(hidden_size, 64 * depth)
|
# hidden_size = default(hidden_size, 64 * depth)
|
||||||
# num_heads = default(num_heads, hidden_size // 64)
|
# num_heads = default(num_heads, hidden_size // 64)
|
||||||
@ -773,6 +865,7 @@ class MMDiT(nn.Module):
|
|||||||
self.pos_embed = None
|
self.pos_embed = None
|
||||||
|
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
|
if not skip_blocks:
|
||||||
self.joint_blocks = nn.ModuleList(
|
self.joint_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
JointBlock(
|
JointBlock(
|
||||||
@ -786,9 +879,10 @@ class MMDiT(nn.Module):
|
|||||||
scale_mod_only=scale_mod_only,
|
scale_mod_only=scale_mod_only,
|
||||||
swiglu=swiglu,
|
swiglu=swiglu,
|
||||||
qk_norm=qk_norm,
|
qk_norm=qk_norm,
|
||||||
|
x_block_self_attn=(i in self.x_block_self_attn_layers) or x_block_self_attn,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations,
|
||||||
)
|
)
|
||||||
for i in range(num_blocks)
|
for i in range(num_blocks)
|
||||||
]
|
]
|
||||||
|
@ -70,6 +70,11 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
|
||||||
if context_processor in state_dict_keys:
|
if context_processor in state_dict_keys:
|
||||||
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
|
||||||
|
unet_config["x_block_self_attn_layers"] = []
|
||||||
|
for key in state_dict_keys:
|
||||||
|
if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
|
||||||
|
layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
|
||||||
|
unet_config["x_block_self_attn_layers"].append(int(layer))
|
||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||||
|
Loading…
Reference in New Issue
Block a user