Skip layer guidance now works on stable audio model.

This commit is contained in:
comfyanonymous 2024-11-20 07:33:06 -05:00
parent 898615122f
commit 22535d0589

View File

@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
return_info = False, return_info = False,
**kwargs **kwargs
): ):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
info = { info = {
"hidden_states": [], "hidden_states": [],
@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
if self.use_sinusoidal_emb or self.use_abs_pos_emb: if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x) x = x + self.pos_emb(x)
blocks_replace = patches_replace.get("dit", {})
# Iterate over the transformer layers # Iterate over the transformer layers
for layer in self.layers: for i, layer in enumerate(self.layers):
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info: if return_info:
@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
mask=None, mask=None,
return_info=False, return_info=False,
control=None, control=None,
transformer_options={},
**kwargs): **kwargs):
return self._forward( return self._forward(
x, x,