mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add transformer options blocks replace patch to mochi.
This commit is contained in:
parent
22a1d7ce78
commit
41886af138
@ -494,8 +494,9 @@ class AsymmDiTJoint(nn.Module):
|
|||||||
packed_indices: Dict[str, torch.Tensor] = None,
|
packed_indices: Dict[str, torch.Tensor] = None,
|
||||||
rope_cos: torch.Tensor = None,
|
rope_cos: torch.Tensor = None,
|
||||||
rope_sin: torch.Tensor = None,
|
rope_sin: torch.Tensor = None,
|
||||||
control=None, **kwargs
|
control=None, transformer_options={}, **kwargs
|
||||||
):
|
):
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
y_feat = context
|
y_feat = context
|
||||||
y_mask = attention_mask
|
y_mask = attention_mask
|
||||||
sigma = timestep
|
sigma = timestep
|
||||||
@ -515,15 +516,32 @@ class AsymmDiTJoint(nn.Module):
|
|||||||
)
|
)
|
||||||
del y_mask
|
del y_mask
|
||||||
|
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
x, y_feat = block(
|
if ("double_block", i) in blocks_replace:
|
||||||
x,
|
def block_wrap(args):
|
||||||
c,
|
out = {}
|
||||||
y_feat,
|
out["img"], out["txt"] = block(
|
||||||
rope_cos=rope_cos,
|
args["img"],
|
||||||
rope_sin=rope_sin,
|
args["vec"],
|
||||||
crop_y=num_tokens,
|
args["txt"],
|
||||||
) # (B, M, D), (B, L, D)
|
rope_cos=args["rope_cos"],
|
||||||
|
rope_sin=args["rope_sin"],
|
||||||
|
crop_y=args["num_tokens"]
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
|
||||||
|
y_feat = out["txt"]
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x, y_feat = block(
|
||||||
|
x,
|
||||||
|
c,
|
||||||
|
y_feat,
|
||||||
|
rope_cos=rope_cos,
|
||||||
|
rope_sin=rope_sin,
|
||||||
|
crop_y=num_tokens,
|
||||||
|
) # (B, M, D), (B, L, D)
|
||||||
del y_feat # Final layers don't use dense text features.
|
del y_feat # Final layers don't use dense text features.
|
||||||
|
|
||||||
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
|
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
|
||||||
|
Loading…
Reference in New Issue
Block a user