diff --git a/comfy/ldm/flux/controlnet_xlabs.py b/comfy/ldm/flux/controlnet_xlabs.py index 3f40021b..e5743321 100644 --- a/comfy/ldm/flux/controlnet_xlabs.py +++ b/comfy/ldm/flux/controlnet_xlabs.py @@ -64,8 +64,8 @@ class ControlNetFlux(Flux): img = img + controlnet_cond vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: - vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) - vec = vec + self.vector_in(y) + vec.add_(self.guidance_in(timestep_embedding(guidance, 256))) + vec.add_(self.vector_in(y)) txt = self.txt_in(txt) ids = torch.cat((txt_ids, img_ids), dim=1) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index da0cf61b..0cd2200c 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -150,14 +150,16 @@ class DoubleStreamBlock(nn.Module): # prepare image for attention img_modulated = self.img_norm1(img) - img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_mod1.scale += 1 + img_modulated = img_mod1.scale * img_modulated + img_mod1.shift img_qkv = self.img_attn.qkv(img_modulated) img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention txt_modulated = self.txt_norm1(txt) - txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_mod1.scale += 1 + txt_modulated = txt_mod1.scale * txt_modulated + txt_mod1.shift txt_qkv = self.txt_attn.qkv(txt_modulated) txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) @@ -170,12 +172,12 @@ class DoubleStreamBlock(nn.Module): txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img bloks - img = img + img_mod1.gate * self.img_attn.proj(img_attn) - img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + img = img.addcmul(img_mod1.gate, self.img_attn.proj(img_attn)) + img.addcmul_(img_mod2.gate, self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)) # calculate the txt bloks - txt += txt_mod1.gate * self.txt_attn.proj(txt_attn) - txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn)) + txt.addcmul_(txt_mod2.gate, self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)) if txt.dtype == torch.float16: txt = txt.clip(-65504, 65504) @@ -221,8 +223,8 @@ class SingleStreamBlock(nn.Module): def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: mod, _ = self.modulation(vec) - x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift - qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + mod.scale += 1 + qkv, mlp = torch.split(self.linear1(mod.scale * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k = self.norm(q, k, v) @@ -230,8 +232,7 @@ class SingleStreamBlock(nn.Module): # compute attention attn = attention(q, k, v, pe=pe) # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - x += mod.gate * output + x.addcmul_(mod.gate, self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))) if x.dtype == torch.float16: x = x.clip(-65504, 65504) return x diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b5373540..f052bd0e 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -106,9 +106,9 @@ class Flux(nn.Module): if self.params.guidance_embed: if guidance is None: raise ValueError("Didn't get guidance strength for guidance distilled model.") - vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) + vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))) - vec = vec + self.vector_in(y) + vec.add_(self.vector_in(y)) txt = self.txt_in(txt) ids = torch.cat((txt_ids, img_ids), dim=1)