From 31831e6ef13474b975eee1a94f39078e00b00156 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Jan 2025 07:23:54 -0500 Subject: [PATCH] Code refactor. --- comfy/ldm/flux/layers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 8e055151f..59a62e0df 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module): def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor: mod, _ = self.modulation(vec) - x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift - qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k = self.norm(q, k, v)