diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py index c36a0006..45c93896 100644 --- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py +++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py @@ -494,8 +494,9 @@ class AsymmDiTJoint(nn.Module): packed_indices: Dict[str, torch.Tensor] = None, rope_cos: torch.Tensor = None, rope_sin: torch.Tensor = None, - control=None, **kwargs + control=None, transformer_options={}, **kwargs ): + patches_replace = transformer_options.get("patches_replace", {}) y_feat = context y_mask = attention_mask sigma = timestep @@ -515,15 +516,32 @@ class AsymmDiTJoint(nn.Module): ) del y_mask + blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.blocks): - x, y_feat = block( - x, - c, - y_feat, - rope_cos=rope_cos, - rope_sin=rope_sin, - crop_y=num_tokens, - ) # (B, M, D), (B, L, D) + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block( + args["img"], + args["vec"], + args["txt"], + rope_cos=args["rope_cos"], + rope_sin=args["rope_sin"], + crop_y=args["num_tokens"] + ) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap}) + y_feat = out["txt"] + x = out["img"] + else: + x, y_feat = block( + x, + c, + y_feat, + rope_cos=rope_cos, + rope_sin=rope_sin, + crop_y=num_tokens, + ) # (B, M, D), (B, L, D) del y_feat # Final layers don't use dense text features. x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)