mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Code refactor.
This commit is contained in:
parent
88ceb28e20
commit
31831e6ef1
@ -230,8 +230,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
|
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
|
||||||
mod, _ = self.modulation(vec)
|
mod, _ = self.modulation(vec)
|
||||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
|
||||||
|
|
||||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
q, k = self.norm(q, k, v)
|
q, k = self.norm(q, k, v)
|
||||||
|
Loading…
Reference in New Issue
Block a user