mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-14 23:53:30 +00:00
Allow disabling pe in flux code for some other models.
This commit is contained in:
parent
50614f1b79
commit
3b19fc76e3
@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
|||||||
q_shape = q.shape
|
q_shape = q.shape
|
||||||
k_shape = k.shape
|
k_shape = k.shape
|
||||||
|
|
||||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
if pe is not None:
|
||||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||||
|
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||||
|
|
||||||
heads = q.shape[1]
|
heads = q.shape[1]
|
||||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||||
|
@ -115,8 +115,11 @@ class Flux(nn.Module):
|
|||||||
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
if img_ids is not None:
|
||||||
pe = self.pe_embedder(ids)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
pe = self.pe_embedder(ids)
|
||||||
|
else:
|
||||||
|
pe = None
|
||||||
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
for i, block in enumerate(self.double_blocks):
|
for i, block in enumerate(self.double_blocks):
|
||||||
|
Loading…
Reference in New Issue
Block a user