Upcasting rope to fp32 seems to make no difference in this model.

comfyanonymous 2025-02-05 04:32:47 -05:00
parent 60653004e5
commit 94f21f9301


@@ -93,9 +93,9 @@ class JointAttention(nn.Module):
             and key tensor with rotary embeddings.
         """
-        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float()
+        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
         t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
-        return t_out.reshape(*x_in.shape).type_as(x_in)
+        return t_out.reshape(*x_in.shape)
 
     def forward(
         self,
@@ -552,7 +552,7 @@ class NextDiT(nn.Module):
             position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
             position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
 
-        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
+        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
 
         # build freqs_cis for cap and image individually
         cap_freqs_cis_shape = list(freqs_cis.shape)
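For reference, a minimal sketch of what the change amounts to: the rotary multiply now stays in the model's compute dtype, and the fp32 rope table is cast by the caller (the added .to(dtype)) instead of upcasting the query/key tensor. The helper name rotary, the shapes, and the dtypes below are illustrative assumptions for a standalone comparison, not values taken from the model.

import torch

def rotary(x_in, freqs_cis):
    # Same math as apply_rotary_emb in the hunk above: view the last dim as
    # (dim/2, 1, 2) pairs and do the real-valued rotary multiply.
    t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.reshape(*x_in.shape)

# Assumed shapes: x is (batch, seq, heads, head_dim), freqs is a broadcastable
# (batch, seq, 1, head_dim/2, 2, 2) rope table kept in fp32.
x = torch.randn(1, 64, 16, 128, dtype=torch.bfloat16)
freqs = torch.randn(1, 64, 1, 64, 2, 2, dtype=torch.float32)

old = rotary(x.float(), freqs).to(torch.bfloat16)  # previous path: upcast x, downcast the result
new = rotary(x, freqs.to(torch.bfloat16))          # this commit: cast the rope table to the model dtype
print((old - new).abs().max())                     # size of the precision gap between the two paths

The printed gap is the rounding difference the commit message is referring to; per the message, it is not visible in this model's outputs.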