Mirror of https://github.com/comfyanonymous/ComfyUI.git
Some fixes/cleanups to pixart code.
Commented out the masking-related code because it is never used in this implementation.
parent d7969cb070
commit e946667216
@@ -46,32 +46,33 @@ class MultiHeadCrossAttention(nn.Module):
         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
         k, v = kv.unbind(2)
 
-        # TODO: xformers needs separate mask logic here
-        if model_management.xformers_enabled():
-            attn_bias = None
-            if mask is not None:
-                attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
-            x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
-        else:
-            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
-            attn_mask = None
-            if mask is not None and len(mask) > 1:
-                # Create equivalent of xformer diagonal block mask, still only correct for square masks
-                # But depth doesn't matter as tensors can expand in that dimension
-                attn_mask_template = torch.ones(
-                    [q.shape[2] // B, mask[0]],
-                    dtype=torch.bool,
-                    device=q.device
-                )
-                attn_mask = torch.block_diag(attn_mask_template)
-
-                # create a mask on the diagonal for each mask in the batch
-                for _ in range(B - 1):
-                    attn_mask = torch.block_diag(attn_mask, attn_mask_template)
-
-            x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
-
-        x = x.view(B, -1, C)
+        assert mask is None  # TODO?
+        # # TODO: xformers needs separate mask logic here
+        # if model_management.xformers_enabled():
+        #     attn_bias = None
+        #     if mask is not None:
+        #         attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
+        #     x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        # else:
+        #     q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
+        #     attn_mask = None
+        #     mask = torch.ones(())
+        #     if mask is not None and len(mask) > 1:
+        #         # Create equivalent of xformer diagonal block mask, still only correct for square masks
+        #         # But depth doesn't matter as tensors can expand in that dimension
+        #         attn_mask_template = torch.ones(
+        #             [q.shape[2] // B, mask[0]],
+        #             dtype=torch.bool,
+        #             device=q.device
+        #         )
+        #         attn_mask = torch.block_diag(attn_mask_template)
+        #
+        #         # create a mask on the diagonal for each mask in the batch
+        #         for _ in range(B - 1):
+        #             attn_mask = torch.block_diag(attn_mask, attn_mask_template)
+        #     x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
+
+        x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
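For reference, the branch that is now disabled built a block-diagonal attention mask so that, once every sample's tokens are flattened into one long sequence, each image's queries could only attend to its own text tokens. A minimal standalone sketch of that idea in plain PyTorch (illustrative shapes and names, not ComfyUI's helpers):

import torch

# Illustrative shapes: a batch of 2 images, each with 4 image tokens,
# conditioned on text sequences of length 3 and 5 respectively.
img_tokens_per_sample = 4
text_lens = [3, 5]

# One boolean block per sample: its image tokens may attend to its own text tokens.
blocks = [torch.ones(img_tokens_per_sample, n, dtype=torch.bool) for n in text_lens]

# torch.block_diag lays the blocks along the diagonal; everything off-diagonal is False,
# so sample 0's queries cannot attend to sample 1's text tokens after flattening.
attn_mask = torch.block_diag(*blocks)
print(attn_mask.shape)  # torch.Size([8, 8]) -> 2*4 query tokens by 3+5 key tokens

The new assert mask is None plus the single optimized_attention(..., mask=None) call reflect that this implementation never actually passes such a mask, which is the stated reason for commenting the branch out.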
@@ -155,9 +156,9 @@ class AttentionKVCompress(nn.Module):
             k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
             v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
 
-        q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
-        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
-        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
+        q = q.reshape(B, N, self.num_heads, C // self.num_heads)
+        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
+        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
 
         if mask is not None:
             raise NotImplementedError("Attn mask logic not added for self attention")
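The change here only drops the explicit .to(dtype) casts, so q/k/v keep the dtype produced by the preceding projections; the KV compression itself is untouched. For context, a rough sketch of what spatially downsampling the key/value tokens amounts to (downsample_tokens below is a hypothetical stand-in for downsample_2d, whose actual behaviour depends on self.sampling):

import torch
import torch.nn.functional as F

def downsample_tokens(tokens, H, W, sr_ratio):
    # Treat the (B, N, C) token sequence as a (B, C, H, W) feature map,
    # shrink it spatially by sr_ratio, then flatten back to tokens.
    B, N, C = tokens.shape
    grid = tokens.transpose(1, 2).reshape(B, C, H, W)
    grid = F.interpolate(grid, scale_factor=1 / sr_ratio, mode="nearest")
    new_N = grid.shape[-2] * grid.shape[-1]
    return grid.reshape(B, C, new_N).transpose(1, 2), new_N

k = torch.randn(1, 64, 32)                       # an 8x8 grid of tokens, 32 channels
k_small, new_N = downsample_tokens(k, 8, 8, sr_ratio=2)
print(k_small.shape, new_N)                      # torch.Size([1, 16, 32]) 16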
@@ -209,9 +210,9 @@ class T2IFinalLayer(nn.Module):
     def forward(self, x, t):
         dtype = x.dtype
-        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
         x = t2i_modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x.to(dtype))
+        x = self.linear(x)
         return x
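scale_shift_table may live in a different dtype (or on a different device) than the activations, so it is now cast to x's dtype and device before being combined with the timestep embedding, which makes the old x.to(dtype) before the final linear layer unnecessary. A small self-contained sketch of the idea, assuming the usual x * (1 + scale) + shift form of t2i_modulate and made-up shapes:

import torch

def t2i_modulate(x, shift, scale):
    # Scale-and-shift modulation around the identity, as used by PixArt.
    return x * (1 + scale) + shift

N, T, D = 2, 10, 16
x = torch.randn(N, T, D, dtype=torch.float16)   # half-precision activations
t = torch.randn(N, D, dtype=torch.float16)      # timestep embedding in the same dtype as x
scale_shift_table = torch.randn(2, D)           # fp32 parameter, e.g. as loaded from the checkpoint

# Casting the table to x's dtype/device keeps shift/scale (and therefore the
# modulated activations) in fp16, so no extra cast is needed before the linear layer.
shift, scale = (scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
out = t2i_modulate(x, shift, scale)
print(out.dtype)  # torch.float16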
@@ -127,12 +127,8 @@ class PixArt(nn.Module):
         t: (N,) tensor of diffusion timesteps
         y: (N, 1, 120, C) tensor of class labels
         """
-        x = x.to(self.dtype)
-        timestep = t.to(self.dtype)
-        y = y.to(self.dtype)
-        pos_embed = self.pos_embed.to(self.dtype)
         x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
-        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
+        t = self.t_embedder(timestep)  # (N, D)
         t0 = self.t_block(t)
         y = self.y_embedder(y, self.training)  # (N, 1, L, D)
         if mask is not None:
@@ -142,7 +138,7 @@ class PixArt(nn.Module):
             y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
             y_lens = mask.sum(dim=1).tolist()
         else:
-            y_lens = [y.shape[2]] * y.shape[0]
+            y_lens = None
             y = y.squeeze(1).view(1, -1, x.shape[-1])
         for block in self.blocks:
             x = block(x, y, t0, y_lens)  # (N, T, D)
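When a padding mask is supplied, the batch's text embeddings are packed into one sequence and y_lens records each sample's real token count; without a mask, y_lens is now simply None instead of a list of full lengths, which is consistent with the new assert mask is None in MultiHeadCrossAttention. A standalone sketch of the packing step with illustrative shapes:

import torch

B, L, D = 2, 6, 8                      # batch, padded text length, embedding dim
y = torch.randn(B, 1, L, D)            # (N, 1, L, D) text embeddings, padded to L tokens
mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0]])   # 1 = real token, 0 = padding

# Keep only the real tokens of every sample and pack them into a single sequence.
y_flat = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, D)
y_lens = mask.sum(dim=1).tolist()

print(y_flat.shape)  # torch.Size([1, 8, 8]) -> 3 + 5 real tokens packed together
print(y_lens)        # [3, 5]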
@@ -164,13 +160,12 @@ class PixArt(nn.Module):
         ## run original forward pass
         out = self.forward_raw(
-            x = x.to(self.dtype),
-            t = timesteps.to(self.dtype),
-            y = context.to(self.dtype),
+            x = x,
+            t = timesteps,
+            y = context,
         )
 
         ## only return EPS
-        out = out.to(torch.float)
         eps, _ = out[:, :self.in_channels], out[:, self.in_channels:]
         return eps
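forward_raw no longer receives explicitly cast inputs, and its output is no longer forced to float32; the wrapper still returns only the first in_channels channels as the noise prediction. A trivial sketch of that channel split (the assumption that the remaining channels hold a DiT-style learned-variance output is mine, not stated in the diff; shapes are made up):

import torch

in_channels = 4
out = torch.randn(2, 2 * in_channels, 32, 32)   # model output with extra channels stacked after the latent channels

# Keep only the first in_channels channels as the noise prediction; the rest is discarded here.
eps, rest = out[:, :in_channels], out[:, in_channels:]
print(eps.shape, rest.shape)  # torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 32, 32])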
@@ -44,7 +44,7 @@ class PixArtMSBlock(nn.Module):
     def forward(self, x, y, t, mask=None, HW=None, **kwargs):
         B, N, C = x.shape
 
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(x.dtype) + t.reshape(B, 6, -1)).chunk(6, dim=1)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
         x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
         x = x + self.cross_attn(x, y, mask)
         x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
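This is the same dtype/device cast as in T2IFinalLayer, applied to the per-block table that is split into six adaLN-style signals. For orientation, a structural sketch of how those six chunks gate the block, with attention, cross-attention and the MLP replaced by trivial stand-ins (an illustration of the layout, not the real modules; everything kept in fp32 for simplicity):

import torch
import torch.nn as nn

def t2i_modulate(x, shift, scale):
    return x * (1 + scale) + shift

B, N, D = 2, 10, 16
x = torch.randn(B, N, D)
y = torch.randn(1, 7, D)                      # packed text tokens (illustrative)
t = torch.randn(B, 6 * D)                     # conditioning from the timestep MLP
scale_shift_table = torch.randn(6, D)

norm1 = nn.LayerNorm(D, elementwise_affine=False)
norm2 = nn.LayerNorm(D, elementwise_affine=False)
attn = nn.Identity()                          # stand-in for self-attention
cross_attn = lambda x, y: torch.zeros_like(x) # stand-in for cross-attention
mlp = nn.Identity()                           # stand-in for the feed-forward block

mods = (scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods

# Same residual/gating layout as PixArtMSBlock.forward:
x = x + gate_msa * attn(t2i_modulate(norm1(x), shift_msa, scale_msa))
x = x + cross_attn(x, y)
x = x + gate_mlp * mlp(t2i_modulate(norm2(x), shift_mlp, scale_mlp))
print(x.shape)  # torch.Size([2, 10, 16])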
@@ -196,7 +196,7 @@ class PixArtMS(PixArt):
             y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
             y_lens = mask.sum(dim=1).tolist()
         else:
-            y_lens = [y.shape[2]] * y.shape[0]
+            y_lens = None
             y = y.squeeze(1).view(1, -1, x.shape[-1])
         for block in self.blocks:
             x = block(x, y, t0, y_lens, (H, W), **kwargs)  # (N, T, D)
@@ -726,6 +726,10 @@ class PixArt(BaseModel):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
 
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
         width = kwargs.get("width", None)
         height = kwargs.get("height", None)
         if width is not None and height is not None:
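The added lines wire the text-encoder output into the conditioning dict under 'c_crossattn'. A minimal sketch of the pattern; the CONDRegular class below is a simplified placeholder for comfy.conds.CONDRegular, shown only to illustrate the shape of the call:

# Simplified placeholder, not ComfyUI's actual class.
class CONDRegular:
    def __init__(self, cond):
        self.cond = cond

def extra_conds(**kwargs):
    out = {}

    cross_attn = kwargs.get("cross_attn", None)
    if cross_attn is not None:
        # Wrap the text-encoder output so the sampler can hand it to the model
        # as the cross-attention conditioning input.
        out['c_crossattn'] = CONDRegular(cross_attn)
    return out

conds = extra_conds(cross_attn="dummy text embeddings", width=1024, height=1024)
print(sorted(conds))  # ['c_crossattn']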