Add an image_interleave option to the Hunyuan image to video encode node.

See the tooltip for what it does.
comfyanonymous 2025-03-07 19:56:11 -05:00
parent c3d9cc4592
commit be4e760648
2 changed files with 20 additions and 13 deletions
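
In short, based on the changes below: the projected image embedding tokens are taken with a stride of image_interleave before being concatenated in front of the text conditioning, so a larger value keeps fewer image tokens and gives the text prompt relatively more weight. A minimal sketch of that idea, with made-up tensor shapes (the real image tokens come from clip_vision_output.mm_projected):

    import torch

    # Illustrative shapes only: 33 projected image-embedding tokens, hidden size 4096.
    image_tokens = torch.randn(1, 33, 4096)
    text_tokens = torch.randn(1, 50, 4096)

    def combine(image_tokens, text_tokens, image_interleave=2):
        # Keep every image_interleave-th image token, then prepend the result to the
        # text conditioning, mirroring the strided slice this commit introduces.
        return torch.cat([image_tokens[:, ::image_interleave], text_tokens], dim=1)

    print(combine(image_tokens, text_tokens, image_interleave=1).shape)  # (1, 83, 4096): strongest image influence
    print(combine(image_tokens, text_tokens, image_interleave=4).shape)  # (1, 59, 4096): prompt weighted more heavily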


@@ -42,7 +42,7 @@ class HunyuanVideoTokenizer:
         self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
         self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
 
-    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs):
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
         out = {}
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
@@ -56,7 +56,7 @@ class HunyuanVideoTokenizer:
         for i in range(len(r)):
             if r[i][0] == 128257:
                 if image_embeds is not None and embed_count < image_embeds.shape[0]:
-                    r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
+                    r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
                     embed_count += 1
         out["llama"] = llama_text_tokens
         return out
@@ -92,10 +92,10 @@ class HunyuanVideoClipModel(torch.nn.Module):
         llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
 
         template_end = 0
-        image_start = None
-        image_end = None
+        extra_template_end = 0
         extra_sizes = 0
         user_end = 9999999999999
+        images = []
 
         tok_pairs = token_weight_pairs_llama[0]
         for i, v in enumerate(tok_pairs):
@@ -112,22 +112,28 @@ class HunyuanVideoClipModel(torch.nn.Module):
                 else:
                     if elem.get("original_type") == "image":
                         elem_size = elem.get("data").shape[0]
-                        if image_start is None:
+                        if template_end > 0:
+                            if user_end == -1:
+                                extra_template_end += elem_size - 1
+                        else:
                             image_start = i + extra_sizes
                             image_end = i + elem_size + extra_sizes
-                        extra_sizes += elem_size - 1
+                            images.append((image_start, image_end, elem.get("image_interleave", 1)))
+                            extra_sizes += elem_size - 1
 
         if llama_out.shape[1] > (template_end + 2):
             if tok_pairs[template_end + 1][0] == 271:
                 template_end += 2
-        llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes]
-        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes]
+        llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
+        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
         if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
             llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
 
-        if image_start is not None:
-            image_output = llama_out[:, image_start: image_end]
-            llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1)
+        if len(images) > 0:
+            out = []
+            for i in images:
+                out.append(llama_out[:, i[0]: i[1]: i[2]])
+            llama_output = torch.cat(out + [llama_output], dim=1)
 
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         return llama_output, l_pooled, llama_extra_out
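
In the new code path above, each image segment is recorded as a (start, end, interleave) tuple and sliced with the interleave value as the stride before being concatenated ahead of the text portion; the previous behavior (a single segment with a hardcoded stride of 2) corresponds to image_interleave=2. A small standalone check of that slicing, using made-up indices rather than real token positions:

    import torch

    # Stand-in for llama_out; batch 1, 40 hidden states, feature dim 8 (shapes illustrative).
    llama_out = torch.randn(1, 40, 8)

    # Hypothetical segments in the (start, end, interleave) form the new code collects.
    images = [(3, 11, 2), (15, 21, 1)]

    out = [llama_out[:, start:end:step] for start, end, step in images]
    llama_output = llama_out[:, 30:]  # pretend this is the text part of the sequence
    llama_output = torch.cat(out + [llama_output], dim=1)
    print(llama_output.shape)  # torch.Size([1, 20, 8]) -> 4 + 6 + 10 positions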


@@ -57,14 +57,15 @@ class TextEncodeHunyuanVideo_ImageToVideo:
             "clip": ("CLIP", ),
             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
             "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
             }}
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "encode"
 
     CATEGORY = "advanced/conditioning"
 
-    def encode(self, clip, clip_vision_output, prompt):
-        tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected)
+    def encode(self, clip, clip_vision_output, prompt, image_interleave):
+        tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
         return (clip.encode_from_tokens_scheduled(tokens), )
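
For reference, a rough usage sketch of the changed node code path called directly from Python; clip, clip_vision_output, and PROMPT_TEMPLATE_ENCODE_VIDEO_I2V are assumed to already exist in a ComfyUI session (produced by the usual loader and CLIP Vision encode nodes) rather than being defined here:

    # Hedged sketch; `clip` and `clip_vision_output` come from elsewhere in the workflow.
    tokens = clip.tokenize(
        "a cat slowly turns its head toward the camera",
        llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
        image_embeds=clip_vision_output.mm_projected,
        image_interleave=2,  # node default; raise it to weight the text prompt more heavily
    )
    conditioning = clip.encode_from_tokens_scheduled(tokens)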