diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 8e055151..59a62e0d 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module): def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor: mod, _ = self.modulation(vec) - x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift - qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k = self.norm(q, k, v)