From e946667216b524ae087863fdea23936a3e9394f6 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 20 Dec 2024 17:10:52 -0500
Subject: [PATCH] Some fixes/cleanups to pixart code.

Commented out the masking related code because it is never used in this
implementation.
---
 comfy/ldm/pixart/blocks.py   | 61 ++++++++++++++++++------------------
 comfy/ldm/pixart/pixart.py   | 15 +++------
 comfy/ldm/pixart/pixartms.py |  4 +--
 comfy/model_base.py          |  4 +++
 4 files changed, 42 insertions(+), 42 deletions(-)

diff --git a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py
index 7ad2ec29..f60bfd79 100644
--- a/comfy/ldm/pixart/blocks.py
+++ b/comfy/ldm/pixart/blocks.py
@@ -46,32 +46,33 @@ class MultiHeadCrossAttention(nn.Module):
         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
         k, v = kv.unbind(2)
 
-        # TODO: xformers needs separate mask logic here
-        if model_management.xformers_enabled():
-            attn_bias = None
-            if mask is not None:
-                attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
-            x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
-        else:
-            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
-            attn_mask = None
-            if mask is not None and len(mask) > 1:
-                # Create equivalent of xformer diagonal block mask, still only correct for square masks
-                # But depth doesn't matter as tensors can expand in that dimension
-                attn_mask_template = torch.ones(
-                    [q.shape[2] // B, mask[0]],
-                    dtype=torch.bool,
-                    device=q.device
-                )
-                attn_mask = torch.block_diag(attn_mask_template)
+        assert mask is None # TODO?
+        # # TODO: xformers needs separate mask logic here
+        # if model_management.xformers_enabled():
+        #     attn_bias = None
+        #     if mask is not None:
+        #         attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
+        #     x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        # else:
+        #     q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
+        #     attn_mask = None
+        #     mask = torch.ones(())
+        #     if mask is not None and len(mask) > 1:
+        #         # Create equivalent of xformer diagonal block mask, still only correct for square masks
+        #         # But depth doesn't matter as tensors can expand in that dimension
+        #         attn_mask_template = torch.ones(
+        #             [q.shape[2] // B, mask[0]],
+        #             dtype=torch.bool,
+        #             device=q.device
+        #         )
+        #         attn_mask = torch.block_diag(attn_mask_template)
+        #
+        #         # create a mask on the diagonal for each mask in the batch
+        #         for _ in range(B - 1):
+        #             attn_mask = torch.block_diag(attn_mask, attn_mask_template)
+        #     x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
 
-                # create a mask on the diagonal for each mask in the batch
-                for _ in range(B - 1):
-                    attn_mask = torch.block_diag(attn_mask, attn_mask_template)
-
-            x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
-
-        x = x.view(B, -1, C)
+        x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -155,9 +156,9 @@ class AttentionKVCompress(nn.Module):
             k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
            v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
 
-        q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
-        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
-        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
+        q = q.reshape(B, N, self.num_heads, C // self.num_heads)
+        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
+        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
 
         if mask is not None:
             raise NotImplementedError("Attn mask logic not added for self attention")
@@ -209,9 +210,9 @@ class T2IFinalLayer(nn.Module):
 
     def forward(self, x, t):
         dtype = x.dtype
-        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
         x = t2i_modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x.to(dtype))
+        x = self.linear(x)
         return x
 
 
diff --git a/comfy/ldm/pixart/pixart.py b/comfy/ldm/pixart/pixart.py
index cd572efc..e1e61faf 100644
--- a/comfy/ldm/pixart/pixart.py
+++ b/comfy/ldm/pixart/pixart.py
@@ -127,12 +127,8 @@ class PixArt(nn.Module):
         t: (N,) tensor of diffusion timesteps
         y: (N, 1, 120, C) tensor of class labels
         """
-        x = x.to(self.dtype)
-        timestep = t.to(self.dtype)
-        y = y.to(self.dtype)
-        pos_embed = self.pos_embed.to(self.dtype)
         x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
-        t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
+        t = self.t_embedder(timestep) # (N, D)
         t0 = self.t_block(t)
         y = self.y_embedder(y, self.training) # (N, 1, L, D)
         if mask is not None:
@@ -142,7 +138,7 @@ class PixArt(nn.Module):
             y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
             y_lens = mask.sum(dim=1).tolist()
         else:
-            y_lens = [y.shape[2]] * y.shape[0]
+            y_lens = None
             y = y.squeeze(1).view(1, -1, x.shape[-1])
         for block in self.blocks:
             x = block(x, y, t0, y_lens) # (N, T, D)
@@ -164,13 +160,12 @@ class PixArt(nn.Module):
 
         ## run original forward pass
         out = self.forward_raw(
-            x = x.to(self.dtype),
-            t = timesteps.to(self.dtype),
-            y = context.to(self.dtype),
+            x = x,
+            t = timesteps,
+            y = context,
         )
 
         ## only return EPS
-        out = out.to(torch.float)
         eps, _ = out[:, :self.in_channels], out[:, self.in_channels:]
         return eps
 
diff --git a/comfy/ldm/pixart/pixartms.py b/comfy/ldm/pixart/pixartms.py
index 195063b0..8ff0d0a4 100644
--- a/comfy/ldm/pixart/pixartms.py
+++ b/comfy/ldm/pixart/pixartms.py
@@ -44,7 +44,7 @@ class PixArtMSBlock(nn.Module):
 
     def forward(self, x, y, t, mask=None, HW=None, **kwargs):
         B, N, C = x.shape
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(x.dtype) + t.reshape(B, 6, -1)).chunk(6, dim=1)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
         x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
         x = x + self.cross_attn(x, y, mask)
         x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
@@ -196,7 +196,7 @@ class PixArtMS(PixArt):
             y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
             y_lens = mask.sum(dim=1).tolist()
         else:
-            y_lens = [y.shape[2]] * y.shape[0]
+            y_lens = None
             y = y.squeeze(1).view(1, -1, x.shape[-1])
         for block in self.blocks:
             x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index af3f0f14..76b2289b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -726,6 +726,10 @@ class PixArt(BaseModel):
 
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
         width = kwargs.get("width", None)
         height = kwargs.get("height", None)
         if width is not None and height is not None:
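
[Editor's note, not part of the patch: with the mask branch commented out and mask asserted to be None, MultiHeadCrossAttention reduces to plain unmasked attention in which each sample attends only to its own conditioning tokens. The standalone sketch below approximates what the new optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None) call computes. The helper name unmasked_cross_attention, the toy shapes, and the use of torch.nn.functional.scaled_dot_product_attention in place of ComfyUI's optimized_attention are illustrative assumptions, not ComfyUI API; the module's output projection and dropout are omitted.]

# Illustrative sketch only (assumption, not ComfyUI code): approximates the
# unmasked cross-attention path selected by the patch (mask=None).
import torch
import torch.nn.functional as F


def unmasked_cross_attention(q, k, v, num_heads):
    # q: (B, N, C) image tokens; k, v: (1, B*L, C) conditioning tokens flattened
    # batch-major, reshaped to (B, L, C) so each sample attends only to its own
    # conditioning tokens. No attention mask is applied anywhere.
    B, N, C = q.shape
    head_dim = C // num_heads
    k = k.reshape(B, -1, C)
    v = v.reshape(B, -1, C)
    # split heads: (B, num_heads, tokens, head_dim)
    q = q.view(B, N, num_heads, head_dim).transpose(1, 2)
    k = k.view(B, -1, num_heads, head_dim).transpose(1, 2)
    v = v.view(B, -1, num_heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v)  # plain softmax attention, no mask
    return out.transpose(1, 2).reshape(B, N, C)


if __name__ == "__main__":
    B, N, L, C, heads = 2, 16, 8, 64, 4
    q = torch.randn(B, N, C)
    kv = torch.randn(1, B * L, C)
    print(unmasked_cross_attention(q, kv, kv, heads).shape)  # torch.Size([2, 16, 64])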