Experimental lyrics strength for ACE. (#7984)

This commit is contained in:
comfyanonymous 2025-05-07 16:22:07 -07:00 committed by GitHub
parent b9980592c4
commit cc33cd3422
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 4 deletions

View File

@ -273,6 +273,7 @@ class ACEStepTransformer2DModel(nn.Module):
speaker_embeds: Optional[torch.FloatTensor] = None,
lyric_token_idx: Optional[torch.LongTensor] = None,
lyric_mask: Optional[torch.LongTensor] = None,
lyrics_strength=1.0,
):
bs = encoder_text_hidden_states.shape[0]
@ -291,6 +292,8 @@ class ACEStepTransformer2DModel(nn.Module):
out_dtype=encoder_text_hidden_states.dtype,
)
encoder_lyric_hidden_states *= lyrics_strength
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
encoder_hidden_mask = None
@ -310,7 +313,6 @@ class ACEStepTransformer2DModel(nn.Module):
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
return_dict: bool = True,
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@ -353,6 +355,7 @@ class ACEStepTransformer2DModel(nn.Module):
lyric_mask: Optional[torch.LongTensor] = None,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
lyrics_strength=1.0,
**kwargs
):
hidden_states = x
@ -363,6 +366,7 @@ class ACEStepTransformer2DModel(nn.Module):
speaker_embeds=speaker_embeds,
lyric_token_idx=lyric_token_idx,
lyric_mask=lyric_mask,
lyrics_strength=lyrics_strength,
)
output_length = hidden_states.shape[-1]

View File

@ -1139,4 +1139,5 @@ class ACEStep(BaseModel):
if cross_attn is not None:
out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out

View File

@ -1,6 +1,6 @@
import torch
import comfy.model_management
import node_helpers
class TextEncodeAceStepAudio:
@classmethod
@ -9,15 +9,18 @@ class TextEncodeAceStepAudio:
"clip": ("CLIP", ),
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, clip, tags, lyrics):
def encode(self, clip, tags, lyrics, lyrics_strength):
tokens = clip.tokenize(tags, lyrics=lyrics)
return (clip.encode_from_tokens_scheduled(tokens), )
conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return (conditioning, )
class EmptyAceStepLatentAudio: