mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-06 23:37:10 +08:00
add transf options argument to stream blocks
This commit is contained in:
parent
855ea059a4
commit
38e97ef2fd
@ -159,7 +159,8 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
self.flipped_img_txt = flipped_img_txt
|
self.flipped_img_txt = flipped_img_txt
|
||||||
|
|
||||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||||
|
|
||||||
img_mod1, img_mod2 = self.img_mod(vec)
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||||
|
|
||||||
@ -244,7 +245,7 @@ class SingleStreamBlock(nn.Module):
|
|||||||
self.mlp_act = nn.GELU(approximate="tanh")
|
self.mlp_act = nn.GELU(approximate="tanh")
|
||||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
|
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
|
||||||
mod, _ = self.modulation(vec)
|
mod, _ = self.modulation(vec)
|
||||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user