mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Not sure if this actually changes anything but it can't hurt.
This commit is contained in:
parent
39fb74c5bd
commit
34608de2e9
@ -64,8 +64,8 @@ class ControlNetFlux(Flux):
|
|||||||
img = img + controlnet_cond
|
img = img + controlnet_cond
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||||
if self.params.guidance_embed:
|
if self.params.guidance_embed:
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
vec.add_(self.guidance_in(timestep_embedding(guidance, 256)))
|
||||||
vec = vec + self.vector_in(y)
|
vec.add_(self.vector_in(y))
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
@ -150,14 +150,16 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
|
|
||||||
# prepare image for attention
|
# prepare image for attention
|
||||||
img_modulated = self.img_norm1(img)
|
img_modulated = self.img_norm1(img)
|
||||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
img_mod1.scale += 1
|
||||||
|
img_modulated = img_mod1.scale * img_modulated + img_mod1.shift
|
||||||
img_qkv = self.img_attn.qkv(img_modulated)
|
img_qkv = self.img_attn.qkv(img_modulated)
|
||||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||||
|
|
||||||
# prepare txt for attention
|
# prepare txt for attention
|
||||||
txt_modulated = self.txt_norm1(txt)
|
txt_modulated = self.txt_norm1(txt)
|
||||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
txt_mod1.scale += 1
|
||||||
|
txt_modulated = txt_mod1.scale * txt_modulated + txt_mod1.shift
|
||||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
@ -170,12 +172,12 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
img = img.addcmul(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
img.addcmul_(img_mod2.gate, self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift))
|
||||||
|
|
||||||
# calculate the txt bloks
|
# calculate the txt bloks
|
||||||
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
txt.addcmul_(txt_mod2.gate, self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift))
|
||||||
|
|
||||||
if txt.dtype == torch.float16:
|
if txt.dtype == torch.float16:
|
||||||
txt = txt.clip(-65504, 65504)
|
txt = txt.clip(-65504, 65504)
|
||||||
@ -221,8 +223,8 @@ class SingleStreamBlock(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||||
mod, _ = self.modulation(vec)
|
mod, _ = self.modulation(vec)
|
||||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
mod.scale += 1
|
||||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
qkv, mlp = torch.split(self.linear1(mod.scale * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
|
|
||||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
q, k = self.norm(q, k, v)
|
q, k = self.norm(q, k, v)
|
||||||
@ -230,8 +232,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
# compute attention
|
# compute attention
|
||||||
attn = attention(q, k, v, pe=pe)
|
attn = attention(q, k, v, pe=pe)
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
x.addcmul_(mod.gate, self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)))
|
||||||
x += mod.gate * output
|
|
||||||
if x.dtype == torch.float16:
|
if x.dtype == torch.float16:
|
||||||
x = x.clip(-65504, 65504)
|
x = x.clip(-65504, 65504)
|
||||||
return x
|
return x
|
||||||
|
@ -106,9 +106,9 @@ class Flux(nn.Module):
|
|||||||
if self.params.guidance_embed:
|
if self.params.guidance_embed:
|
||||||
if guidance is None:
|
if guidance is None:
|
||||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)))
|
||||||
|
|
||||||
vec = vec + self.vector_in(y)
|
vec.add_(self.vector_in(y))
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
Loading…
Reference in New Issue
Block a user