diff --git a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py new file mode 100644 index 00000000..7ad2ec29 --- /dev/null +++ b/comfy/ldm/pixart/blocks.py @@ -0,0 +1,382 @@ +# Based on: +# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license] +# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license] +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from comfy import model_management +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding +from comfy.ldm.modules.attention import optimized_attention + +if model_management.xformers_enabled(): + import xformers.ops + if int((xformers.__version__).split(".")[2]) >= 28: + block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens + else: + block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + +def t2i_modulate(x, shift, scale): + return x * (1 + scale) + shift + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs): + super(MultiHeadCrossAttention, self).__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + + self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + B, N, C = x.shape + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + # TODO: xformers needs separate mask logic here + if model_management.xformers_enabled(): + attn_bias = None + if mask is not None: + attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask) + x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias) + else: + q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),) + attn_mask = None + if mask is not None and len(mask) > 1: + # Create equivalent of xformer diagonal block mask, still only correct for square masks + # But depth doesn't matter as tensors can expand in that dimension + attn_mask_template = torch.ones( + [q.shape[2] // B, mask[0]], + dtype=torch.bool, + device=q.device + ) + attn_mask = torch.block_diag(attn_mask_template) + + # create a mask on the diagonal for each mask in the batch + for _ in range(B - 1): + attn_mask = torch.block_diag(attn_mask, attn_mask_template) + + x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True) + + x = x.view(B, -1, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class AttentionKVCompress(nn.Module): + """Multi-head Attention block with KV token compression and qk norm.""" + def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + """ + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + + self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every'] + self.sr_ratio = sr_ratio + if sr_ratio > 1 and sampling == 'conv': + # Avg Conv Init. + self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device) + # self.sr.weight.data.fill_(1/sr_ratio**2) + # self.sr.bias.data.zero_() + self.norm = operations.LayerNorm(dim, dtype=dtype, device=device) + if qk_norm: + self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device) + self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + def downsample_2d(self, tensor, H, W, scale_factor, sampling=None): + if sampling is None or scale_factor == 1: + return tensor + B, N, C = tensor.shape + + if sampling == 'uniform_every': + return tensor[:, ::scale_factor], int(N // scale_factor) + + tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2) + new_H, new_W = int(H / scale_factor), int(W / scale_factor) + new_N = new_H * new_W + + if sampling == 'ave': + tensor = F.interpolate( + tensor, scale_factor=1 / scale_factor, mode='nearest' + ).permute(0, 2, 3, 1) + elif sampling == 'uniform': + tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1) + elif sampling == 'conv': + tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1) + tensor = self.norm(tensor) + else: + raise ValueError + + return tensor.reshape(B, new_N, C).contiguous(), new_N + + def forward(self, x, mask=None, HW=None, block_id=None): + B, N, C = x.shape # 2 4096 1152 + new_N = N + if HW is None: + H = W = int(N ** 0.5) + else: + H, W = HW + qkv = self.qkv(x).reshape(B, N, 3, C) + + q, k, v = qkv.unbind(2) + dtype = q.dtype + q = self.q_norm(q) + k = self.k_norm(k) + + # KV compression + if self.sr_ratio > 1: + k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling) + v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling) + + q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype) + k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype) + + if mask is not None: + raise NotImplementedError("Attn mask logic not added for self attention") + + # This is never called at the moment + # attn_bias = None + # if mask is not None: + # attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) + # attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf')) + + # attention 2 + q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),) + x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True) + + x = x.view(B, N, C) + x = self.proj(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class T2IFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5) + self.out_channels = out_channels + + def forward(self, x, t): + dtype = x.dtype + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = t2i_modulate(self.norm_final(x), shift, scale) + x = self.linear(x.to(dtype)) + return x + + +class MaskFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device) + ) + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DecoderLayer(nn.Module): + """ + The final layer of PixArt. + """ + def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device) + ) + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_decoder(x), shift, scale) + x = self.linear(x) + return x + + +class SizeEmbedder(TimestepEmbedder): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations) + self.mlp = nn.Sequential( + operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + self.outdim = hidden_size + + def forward(self, s, bs): + if s.ndim == 1: + s = s[:, None] + assert s.ndim == 2 + if s.shape[0] != bs: + s = s.repeat(bs//s.shape[0], 1) + assert s.shape[0] == bs + b, dims = s.shape[0], s.shape[1] + s = rearrange(s, "b d -> (b d)") + s_freq = timestep_embedding(s, self.frequency_embedding_size) + s_emb = self.mlp(s_freq.to(s.dtype)) + s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return s_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device), + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None): + super().__init__() + self.y_proj = Mlp( + in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, + dtype=dtype, device=device, operations=operations, + ) + self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5)) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + if train: + assert caption.shape[2:] == self.y_embedding.shape + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + caption = self.y_proj(caption) + return caption + + +class CaptionEmbedderDoubleBr(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None): + super().__init__() + self.proj = Mlp( + in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, + dtype=dtype, device=device, operations=operations, + ) + self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5) + self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5) + self.uncond_prob = uncond_prob + + def token_drop(self, global_caption, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption) + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return global_caption, caption + + def forward(self, caption, train, force_drop_ids=None): + assert caption.shape[2: ] == self.y_embedding.shape + global_caption = caption.mean(dim=2).squeeze() + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids) + y_embed = self.proj(global_caption) + return y_embed, caption diff --git a/comfy/ldm/pixart/pixart.py b/comfy/ldm/pixart/pixart.py new file mode 100644 index 00000000..cd572efc --- /dev/null +++ b/comfy/ldm/pixart/pixart.py @@ -0,0 +1,201 @@ +# Based on: +# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license] +# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license] +import torch +import torch.nn as nn + +from .blocks import ( + t2i_modulate, + CaptionEmbedder, + AttentionKVCompress, + MultiHeadCrossAttention, + T2IFinalLayer, +) +from comfy.ldm.modules.diffusionmodules.mmdit import PatchEmbed, TimestepEmbedder, Mlp, get_1d_sincos_pos_embed_from_grid_torch + + +class PixArtBlock(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionKVCompress( + hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, + qk_norm=qk_norm, **block_kwargs + ) + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) + self.drop_path = nn.Identity() #DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + self.sampling = sampling + self.sr_ratio = sr_ratio + + def forward(self, x, y, t, mask=None, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +### Core PixArt Model ### +class PixArt(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path: float = 0., + caption_channels=4096, + pe_interpolation=1.0, + pe_precision=None, + config=None, + model_max_length=120, + qk_norm=False, + kv_compress_config=None, + **kwargs, + ): + super().__init__() + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.pe_interpolation = pe_interpolation + self.pe_precision = pe_precision + self.depth = depth + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + num_patches = self.x_embedder.num_patches + self.base_size = input_size // self.patch_size + # Will use fixed sin-cos embedding: + self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) + + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, + act_layer=approx_gelu, token_num=model_max_length + ) + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + self.kv_compress_config = kv_compress_config + if kv_compress_config is None: + self.kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], + } + self.blocks = nn.ModuleList([ + PixArtBlock( + hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], + input_size=(input_size // patch_size, input_size // patch_size), + sampling=self.kv_compress_config['sampling'], + sr_ratio=int( + self.kv_compress_config['scale_factor'] + ) if i in self.kv_compress_config['kv_compress_layer'] else 1, + qk_norm=qk_norm, + ) + for i in range(depth) + ]) + self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) + + def forward_raw(self, x, t, y, mask=None, data_info=None): + """ + Original forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = t.to(self.dtype) + y = y.to(self.dtype) + pos_embed = self.pos_embed.to(self.dtype) + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep.to(x.dtype)) # (N, D) + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = block(x, y, t0, y_lens) # (N, T, D) + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward(self, x, timesteps, context, y=None, **kwargs): + """ + Forward pass that adapts comfy input to original forward function + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + timesteps: (N,) tensor of diffusion timesteps + context: (N, 1, 120, C) conditioning + y: extra conditioning. + """ + ## Still accepts the input w/o that dim but returns garbage + if len(context.shape) == 3: + context = context.unsqueeze(1) + + ## run original forward pass + out = self.forward_raw( + x = x.to(self.dtype), + t = timesteps.to(self.dtype), + y = context.to(self.dtype), + ) + + ## only return EPS + out = out.to(torch.float) + eps, _ = out[:, :self.in_channels], out[:, self.in_channels:] + return eps + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + +def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32): + grid_h, grid_w = torch.meshgrid( + torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation, + torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation, + indexing='ij' + ) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) + emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) + return emb diff --git a/comfy/ldm/pixart/pixartms.py b/comfy/ldm/pixart/pixartms.py new file mode 100644 index 00000000..195063b0 --- /dev/null +++ b/comfy/ldm/pixart/pixartms.py @@ -0,0 +1,246 @@ +# Based on: +# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license] +# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license] +import torch +import torch.nn as nn + +from .blocks import ( + t2i_modulate, + CaptionEmbedder, + AttentionKVCompress, + MultiHeadCrossAttention, + T2IFinalLayer, + SizeEmbedder, +) +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp +from .pixart import PixArt, get_2d_sincos_pos_embed_torch + + +class PixArtMSBlock(nn.Module): + """ + A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None, + sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs): + super().__init__() + self.hidden_size = hidden_size + self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.attn = AttentionKVCompress( + hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, + qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs + ) + self.cross_attn = MultiHeadCrossAttention( + hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs + ) + self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, + dtype=dtype, device=device, operations=operations + ) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + + def forward(self, x, y, t, mask=None, HW=None, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(x.dtype) + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW)) + x = x + self.cross_attn(x, y, mask) + x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +### Core PixArt Model ### +class PixArtMS(PixArt): + """ + Diffusion model with a Transformer backbone. + """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + learn_sigma=True, + pred_sigma=True, + drop_path: float = 0., + caption_channels=4096, + pe_interpolation=None, + pe_precision=None, + config=None, + model_max_length=120, + micro_condition=True, + qk_norm=False, + kv_compress_config=None, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + nn.Module.__init__(self) + self.dtype = dtype + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.pe_interpolation = pe_interpolation + self.pe_precision = pe_precision + self.hidden_size = hidden_size + self.depth = depth + + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device) + ) + self.x_embedder = PatchEmbed( + patch_size=patch_size, + in_chans=in_channels, + embed_dim=hidden_size, + bias=True, + dtype=dtype, + device=device, + operations=operations + ) + self.t_embedder = TimestepEmbedder( + hidden_size, dtype=dtype, device=device, operations=operations, + ) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, + act_layer=approx_gelu, token_num=model_max_length, + dtype=dtype, device=device, operations=operations, + ) + + self.micro_conditioning = micro_condition + if self.micro_conditioning: + self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations) + self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations) + + # For fixed sin-cos embedding: + # num_patches = (input_size // patch_size) * (input_size // patch_size) + # self.base_size = input_size // self.patch_size + # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) + + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + if kv_compress_config is None: + kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], + } + self.blocks = nn.ModuleList([ + PixArtMSBlock( + hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], + sampling=kv_compress_config['sampling'], + sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1, + qk_norm=qk_norm, + dtype=dtype, + device=device, + operations=operations, + ) + for i in range(depth) + ]) + self.final_layer = T2IFinalLayer( + hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations + ) + + def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs): + """ + Original forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) conditioning + ar: (N, 1): aspect ratio + cs: (N ,2) size conditioning for height/width + """ + B, C, H, W = x.shape + c_res = (H + W) // 2 + pe_interpolation = self.pe_interpolation + if pe_interpolation is None or self.pe_precision is not None: + # calculate pe_interpolation on-the-fly + pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0) + + pos_embed = get_2d_sincos_pos_embed_torch( + self.hidden_size, + h=(H // self.patch_size), + w=(W // self.patch_size), + pe_interpolation=pe_interpolation, + base_size=((round(c_res / 64) * 64) // self.patch_size), + device=x.device, + dtype=x.dtype, + ).unsqueeze(0) + + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep, x.dtype) # (N, D) + + if self.micro_conditioning and (c_size is not None and c_ar is not None): + bs = x.shape[0] + c_size = self.csize_embedder(c_size, bs) # (N, D) + c_ar = self.ar_embedder(c_ar, bs) # (N, D) + t = t + torch.cat([c_size, c_ar], dim=1) + + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, D) + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D) + + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x, H, W) # (N, out_channels, H, W) + + return x + + def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs): + B, C, H, W = x.shape + + # Fallback for missing microconds + if self.micro_conditioning: + if c_size is None: + c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1) + + if c_ar is None: + c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1) + + ## Still accepts the input w/o that dim but returns garbage + if len(context.shape) == 3: + context = context.unsqueeze(1) + + ## run original forward pass + out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar) + + ## only return EPS + if self.pred_sigma: + return out[:, :self.in_channels] + return out + + def unpatchify(self, x, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = h // self.patch_size + w = w // self.patch_size + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs diff --git a/comfy/lora.py b/comfy/lora.py index c43a6fe4..ec3da6f4 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -344,7 +344,6 @@ def model_lora_keys_unet(model, key_map={}): key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format key_map[key_lora] = to - if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: @@ -353,6 +352,20 @@ def model_lora_keys_unet(model, key_map={}): key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format key_map[key_lora] = to + if isinstance(model, comfy.model_base.PixArt): + diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format + key_map[key_lora] = to + + key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script + key_map[key_lora] = to + + key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script + key_map[key_lora] = to + if isinstance(model, comfy.model_base.HunyuanDiT): for k in sdk: if k.startswith("diffusion_model.") and k.endswith(".weight"): diff --git a/comfy/model_base.py b/comfy/model_base.py index 18c6f224..99c53e57 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -26,6 +26,7 @@ from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAug from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper import comfy.ldm.genmo.joint_model.asymm_models_joint import comfy.ldm.aura.mmdit +import comfy.ldm.pixart.pixartms import comfy.ldm.hydit.models import comfy.ldm.audio.dit import comfy.ldm.audio.embedders @@ -718,6 +719,21 @@ class HunyuanDiT(BaseModel): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out +class PixArt(BaseModel): + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + width = kwargs.get("width", None) + height = kwargs.get("height", None) + if width is not None and height is not None: + out["c_size"] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width]])) + out["c_ar"] = comfy.conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height/width)]])) + + return out + class Flux(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index d6b59fd7..c53bef5b 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -203,11 +203,42 @@ def detect_unet_config(state_dict, key_prefix): dit_config["rope_theta"] = 10000.0 return dit_config + if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys: + # PixArt diffusers + return None + if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} dit_config["image_model"] = "ltxv" return dit_config + if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt + patch_size = 2 + dit_config = {} + dit_config["num_heads"] = 16 + dit_config["patch_size"] = patch_size + dit_config["hidden_size"] = 1152 + dit_config["in_channels"] = 4 + dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') + + y_key = "{}y_embedder.y_embedding".format(key_prefix) + if y_key in state_dict_keys: + dit_config["model_max_length"] = state_dict[y_key].shape[0] + + pe_key = "{}pos_embed".format(key_prefix) + if pe_key in state_dict_keys: + dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size + dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess + + ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix) + if ar_key in state_dict_keys: + dit_config["image_model"] = "pixart_alpha" + dit_config["micro_condition"] = True + else: + dit_config["image_model"] = "pixart_sigma" + dit_config["micro_condition"] = False + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -573,6 +604,9 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.') num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.') sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix) + elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt + num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') + sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix) elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') diff --git a/comfy/sd.py b/comfy/sd.py index dee8e984..fef38294 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -27,6 +27,7 @@ import comfy.text_encoders.sd2_clip import comfy.text_encoders.sd3_clip import comfy.text_encoders.sa_t5 import comfy.text_encoders.aura_t5 +import comfy.text_encoders.pixart_t5 import comfy.text_encoders.hydit import comfy.text_encoders.flux import comfy.text_encoders.long_clipl @@ -604,6 +605,8 @@ class CLIPType(Enum): MOCHI = 7 LTXV = 8 HUNYUAN_VIDEO = 9 + PIXART = 10 + def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] @@ -696,6 +699,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.LTXV: clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer + elif clip_type == CLIPType.PIXART: + clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c0fe1ba5..4845406d 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -37,8 +37,12 @@ class ClipTokenWeightEncoder: sections = len(to_encode) if has_weights or sections == 0: - to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) - + if hasattr(self, "gen_empty_tokens"): + to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len)) + else: + to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) + print(to_encode) + o = self.encode(to_encode) out, pooled = o[:2] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 68e2b13f..a5f38b5e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -8,6 +8,7 @@ import comfy.text_encoders.sd2_clip import comfy.text_encoders.sd3_clip import comfy.text_encoders.sa_t5 import comfy.text_encoders.aura_t5 +import comfy.text_encoders.pixart_t5 import comfy.text_encoders.hydit import comfy.text_encoders.flux import comfy.text_encoders.genmo @@ -592,6 +593,37 @@ class AuraFlow(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model) +class PixArtAlpha(supported_models_base.BASE): + unet_config = { + "image_model": "pixart_alpha", + } + + sampling_settings = { + "beta_schedule" : "sqrt_linear", + "linear_start" : 0.0001, + "linear_end" : 0.02, + "timesteps" : 1000, + } + + unet_extra_config = {} + latent_format = latent_formats.SD15 + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.PixArt(self, device=device) + return out.eval() + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL) + +class PixArtSigma(PixArtAlpha): + unet_config = { + "image_model": "pixart_sigma", + } + latent_format = latent_formats.SDXL + class HunyuanDiT(supported_models_base.BASE): unet_config = { "image_model": "hydit", @@ -787,6 +819,6 @@ class HunyuanVideo(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo] +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo] models += [SVD_img2vid] diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py new file mode 100644 index 00000000..d56d57f1 --- /dev/null +++ b/comfy/text_encoders/pixart_t5.py @@ -0,0 +1,42 @@ +import os + +from comfy import sd1_clip +import comfy.text_encoders.t5 +import comfy.text_encoders.sd3_clip +from comfy.sd1_clip import gen_empty_tokens + +from transformers import T5TokenizerFast + +class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def gen_empty_tokens(self, special_tokens, *args, **kwargs): + # PixArt expects the negative to be all pad tokens + special_tokens = special_tokens.copy() + special_tokens.pop("end") + return gen_empty_tokens(special_tokens, *args, **kwargs) + +class PixArtT5XXL(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options) + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding + +class PixArtTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) + +def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None): + class PixArtTEModel_(PixArtT5XXL): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + if dtype is None: + dtype = dtype_t5 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return PixArtTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 3ddbfd90..5fb5418b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -386,6 +386,77 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""): return key_map +PIXART_MAP_BASIC = { + ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"), + ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"), + ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"), + ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"), + ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"), + ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"), + ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"), + ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"), + ("x_embedder.proj.weight", "pos_embed.proj.weight"), + ("x_embedder.proj.bias", "pos_embed.proj.bias"), + ("y_embedder.y_embedding", "caption_projection.y_embedding"), + ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"), + ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"), + ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"), + ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"), + ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"), + ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"), + ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"), + ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"), + ("t_block.1.weight", "adaln_single.linear.weight"), + ("t_block.1.bias", "adaln_single.linear.bias"), + ("final_layer.linear.weight", "proj_out.weight"), + ("final_layer.linear.bias", "proj_out.bias"), + ("final_layer.scale_shift_table", "scale_shift_table"), +} + +PIXART_MAP_BLOCK = { + ("scale_shift_table", "scale_shift_table"), + ("attn.proj.weight", "attn1.to_out.0.weight"), + ("attn.proj.bias", "attn1.to_out.0.bias"), + ("mlp.fc1.weight", "ff.net.0.proj.weight"), + ("mlp.fc1.bias", "ff.net.0.proj.bias"), + ("mlp.fc2.weight", "ff.net.2.weight"), + ("mlp.fc2.bias", "ff.net.2.bias"), + ("cross_attn.proj.weight" ,"attn2.to_out.0.weight"), + ("cross_attn.proj.bias" ,"attn2.to_out.0.bias"), +} + +def pixart_to_diffusers(mmdit_config, output_prefix=""): + key_map = {} + + depth = mmdit_config.get("depth", 0) + offset = mmdit_config.get("hidden_size", 1152) + + for i in range(depth): + block_from = "transformer_blocks.{}".format(i) + block_to = "{}blocks.{}".format(output_prefix, i) + + for end in ("weight", "bias"): + s = "{}.attn1.".format(block_from) + qkv = "{}.attn.qkv.{}".format(block_to, end) + key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset)) + key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset)) + key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset)) + + s = "{}.attn2.".format(block_from) + q = "{}.cross_attn.q_linear.{}".format(block_to, end) + kv = "{}.cross_attn.kv_linear.{}".format(block_to, end) + + key_map["{}to_q.{}".format(s, end)] = q + key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset)) + key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset)) + + for k in PIXART_MAP_BLOCK: + key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0]) + + for k in PIXART_MAP_BASIC: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map def auraflow_to_diffusers(mmdit_config, output_prefix=""): n_double_layers = mmdit_config.get("n_double_layers", 0) diff --git a/comfy_extras/nodes_pixart.py b/comfy_extras/nodes_pixart.py new file mode 100644 index 00000000..c7209c46 --- /dev/null +++ b/comfy_extras/nodes_pixart.py @@ -0,0 +1,24 @@ +from nodes import MAX_RESOLUTION + +class CLIPTextEncodePixArtAlpha: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + CATEGORY = "advanced/conditioning" + DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma." + + def encode(self, clip, width, height, text): + tokens = clip.tokenize(text) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),) + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha, +} diff --git a/nodes.py b/nodes.py index 1a90073e..bdea7564 100644 --- a/nodes.py +++ b/nodes.py @@ -898,7 +898,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -918,6 +918,8 @@ class CLIPLoader: clip_type = comfy.sd.CLIPType.MOCHI elif type == "ltxv": clip_type = comfy.sd.CLIPType.LTXV + elif type == "pixart": + clip_type = comfy.sd.CLIPType.PIXART else: clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION @@ -2164,6 +2166,7 @@ def init_builtin_extra_nodes(): "nodes_stable3d.py", "nodes_sdupscale.py", "nodes_photomaker.py", + "nodes_pixart.py", "nodes_cond.py", "nodes_morphology.py", "nodes_stable_cascade.py",