mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Skip layer guidance now works on hydit model.
This commit is contained in:
parent
3d802710e7
commit
b4526d3fc3
@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
|
|||||||
style=None,
|
style=None,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
control=None,
|
control=None,
|
||||||
transformer_options=None,
|
transformer_options={},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Forward pass of the encoder.
|
Forward pass of the encoder.
|
||||||
@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
|
|||||||
return_dict: bool
|
return_dict: bool
|
||||||
Whether to return a dictionary.
|
Whether to return a dictionary.
|
||||||
"""
|
"""
|
||||||
#import pdb
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
#pdb.set_trace()
|
|
||||||
encoder_hidden_states = context
|
encoder_hidden_states = context
|
||||||
text_states = encoder_hidden_states # 2,77,1024
|
text_states = encoder_hidden_states # 2,77,1024
|
||||||
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
|
||||||
@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
|
|||||||
# Concatenate all extra vectors
|
# Concatenate all extra vectors
|
||||||
c = t + self.extra_embedder(extra_vec) # [B, D]
|
c = t + self.extra_embedder(extra_vec) # [B, D]
|
||||||
|
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
controls = None
|
controls = None
|
||||||
if control:
|
if control:
|
||||||
controls = control.get("output", None)
|
controls = control.get("output", None)
|
||||||
@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module):
|
|||||||
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||||
else:
|
else:
|
||||||
skip = skips.pop()
|
skip = skips.pop()
|
||||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
|
||||||
else:
|
else:
|
||||||
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
|
skip = None
|
||||||
|
|
||||||
|
if ("double_block", layer) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||||
|
|
||||||
|
|
||||||
if layer < (self.depth // 2 - 1):
|
if layer < (self.depth // 2 - 1):
|
||||||
skips.append(x)
|
skips.append(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user