diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py index 44e806cb..88459457 100644 --- a/comfy/ldm/hydit/models.py +++ b/comfy/ldm/hydit/models.py @@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module): style=None, return_dict=False, control=None, - transformer_options=None, + transformer_options={}, ): """ Forward pass of the encoder. @@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module): return_dict: bool Whether to return a dictionary. """ - #import pdb - #pdb.set_trace() + patches_replace = transformer_options.get("patches_replace", {}) encoder_hidden_states = context text_states = encoder_hidden_states # 2,77,1024 text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 @@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module): # Concatenate all extra vectors c = t + self.extra_embedder(extra_vec) # [B, D] + blocks_replace = patches_replace.get("dit", {}) + controls = None if control: controls = control.get("output", None) @@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module): skip = skips.pop() + controls.pop().to(dtype=x.dtype) else: skip = skips.pop() - x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D) else: - x = block(x, c, text_states, freqs_cis_img) # (N, L, D) + skip = None + + if ("double_block", layer) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"]) + return out + + out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D) + if layer < (self.depth // 2 - 1): skips.append(x)