Add transformer_options block-replacement (patches_replace) support to Mochi.

This commit is contained in:
comfyanonymous 2024-11-16 14:47:06 -05:00
parent 22a1d7ce78
commit 41886af138

View File

@@ -494,8 +494,9 @@ class AsymmDiTJoint(nn.Module):
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
control=None, **kwargs
control=None, transformer_options={}, **kwargs
):
patches_replace = transformer_options.get("patches_replace", {})
y_feat = context
y_mask = attention_mask
sigma = timestep
@@ -515,15 +516,32 @@ class AsymmDiTJoint(nn.Module):
)
del y_mask
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
) # (B, M, D), (B, L, D)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
    """Run the current ``block`` on the values carried in ``args``.

    Reads the "img", "vec", "txt", "rope_cos", "rope_sin" and
    "num_tokens" entries of ``args``, forwards them to ``block`` (taken
    from the enclosing scope), and returns the two results in a dict
    under the keys "img" and "txt".
    """
    img_out, txt_out = block(
        args["img"],
        args["vec"],
        args["txt"],
        rope_cos=args["rope_cos"],
        rope_sin=args["rope_sin"],
        crop_y=args["num_tokens"],
    )
    # Same key order as building the dict incrementally: "img" then "txt".
    return {"img": img_out, "txt": txt_out}
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)