From 94f21f93012173d8f1027bc5f59361cf200b8b37 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 5 Feb 2025 04:32:47 -0500
Subject: [PATCH] Upcasting rope to fp32 seems to make no difference in this
 model.

---
 comfy/ldm/lumina/model.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index 442a814c3..ec4119722 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -93,9 +93,9 @@ class JointAttention(nn.Module):
             and key tensor with rotary embeddings.
 
         """
-        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float()
+        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
         t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
-        return t_out.reshape(*x_in.shape).type_as(x_in)
+        return t_out.reshape(*x_in.shape)
 
     def forward(
         self,
@@ -552,7 +552,7 @@ class NextDiT(nn.Module):
                 position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
                 position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
 
-        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
+        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
 
         # build freqs_cis for cap and image individually
         cap_freqs_cis_shape = list(freqs_cis.shape)
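
Note: the following standalone sketch (not part of the patch) illustrates the
numeric question behind the change: applying the 2x2 rotation form of RoPE
directly in bf16 versus the removed fp32-upcast path. The make_freqs_cis
helper and the tensor shapes below are illustrative assumptions, not the
model's actual rope_embedder.

# Sketch: compare the old fp32-upcast RoPE path with the new in-dtype path.
import torch


def make_freqs_cis(seq_len: int, head_dim: int) -> torch.Tensor:
    # Assumed layout: [seq, head_dim // 2, 2, 2] rotation matrices
    # [[cos, -sin], [sin, cos]], so that apply_rotary_emb can index
    # freqs_cis[..., 0] and freqs_cis[..., 1] as the two matrix columns.
    theta = 10000.0 ** (-torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim)
    angles = torch.arange(seq_len, dtype=torch.float64)[:, None] * theta[None, :]
    cos, sin = angles.cos(), angles.sin()
    row0 = torch.stack([cos, -sin], dim=-1)   # [seq, head_dim // 2, 2]
    row1 = torch.stack([sin, cos], dim=-1)    # [seq, head_dim // 2, 2]
    return torch.stack([row0, row1], dim=-2)  # [seq, head_dim // 2, 2, 2]


def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as the patched method: view the head dim as (x0, x1)
    # pairs and apply the per-position rotation matrix, with no dtype change.
    t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
    return t_out.reshape(*x_in.shape)


if __name__ == "__main__":
    torch.manual_seed(0)
    seq, heads, head_dim = 256, 8, 64
    x = torch.randn(1, seq, heads, head_dim, dtype=torch.bfloat16)
    freqs = make_freqs_cis(seq, head_dim)[None, :, None]  # broadcast over batch, heads

    # Old path: upcast to fp32, rotate, cast back (the removed .float()/.type_as()).
    ref = apply_rotary_emb(x.float(), freqs.float()).to(x.dtype)
    # New path: rotate directly in the activation dtype.
    out = apply_rotary_emb(x, freqs.to(x.dtype))

    # If the commit's observation holds, any mismatch stays at the level of
    # bf16 rounding noise.
    print((ref.float() - out.float()).abs().max().item())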