Skip layer guidance now works on hydit model.

This commit is contained in:
comfyanonymous 2024-11-24 05:54:30 -05:00
parent 3d802710e7
commit b4526d3fc3

View File

@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
style=None, style=None,
return_dict=False, return_dict=False,
control=None, control=None,
transformer_options=None, transformer_options={},
): ):
""" """
Forward pass of the encoder. Forward pass of the encoder.
@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
return_dict: bool return_dict: bool
Whether to return a dictionary. Whether to return a dictionary.
""" """
#import pdb patches_replace = transformer_options.get("patches_replace", {})
#pdb.set_trace()
encoder_hidden_states = context encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024 text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
# Concatenate all extra vectors # Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D] c = t + self.extra_embedder(extra_vec) # [B, D]
blocks_replace = patches_replace.get("dit", {})
controls = None controls = None
if control: if control:
controls = control.get("output", None) controls = control.get("output", None)
@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module):
skip = skips.pop() + controls.pop().to(dtype=x.dtype) skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else: else:
skip = skips.pop() skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else: else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D) skip = None
if ("double_block", layer) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
return out
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
if layer < (self.depth // 2 - 1): if layer < (self.depth // 2 - 1):
skips.append(x) skips.append(x)