From 29a70ca1010c1482a96467a729f172e39382d631 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 6 Mar 2025 03:07:15 -0500 Subject: [PATCH] Support HunyuanVideo image to video model. --- comfy/model_base.py | 7 +++ comfy/supported_models.py | 12 ++++- comfy/text_encoders/hunyuan_video.py | 60 +++++++++++++++++++------ comfy_extras/nodes_hunyuan.py | 67 ++++++++++++++++++++++++++++ 4 files changed, 132 insertions(+), 14 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 07fd2db43..a304c58bd 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -900,6 +900,13 @@ class HunyuanVideo(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HunyuanVideoI2V(HunyuanVideo): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + self.concat_keys = ("concat_image", "mask_inverted") + + class HunyuanVideoSkyreelsI2V(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7e37a17b1..7157a15f2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -826,6 +826,16 @@ class HunyuanVideo(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) +class HunyuanVideoI2V(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "in_channels": 33, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideoI2V(self, device=device) + return out + class HunyuanVideoSkyreelsI2V(HunyuanVideo): unet_config = { "image_model": "hunyuan_video", @@ -949,6 +959,6 @@ class WAN21_I2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=True, device=device) return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] models += [SVD_img2vid] diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index bdee0b3df..1d814aadd 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -4,6 +4,7 @@ import comfy.text_encoders.llama from transformers import LlamaTokenizerFast import torch import os +import numbers def llama_detect(state_dict, prefix=""): @@ -22,7 +23,7 @@ def llama_detect(state_dict, prefix=""): class LLAMA3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length) class LLAMAModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): @@ -38,18 +39,26 @@ class HunyuanVideoTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens + self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) - def tokenize_with_weights(self, text:str, return_word_ids=False, llama_template=None, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs): out = {} out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) if llama_template is None: - llama_text = "{}{}".format(self.llama_template, text) + llama_text = self.llama_template.format(text) else: - llama_text = "{}{}".format(llama_template, text) - out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids) + llama_text = llama_template.format(text) + llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids) + embed_count = 0 + for r in llama_text_tokens: + for i in range(len(r)): + if r[i][0] == 128257: + if image_embeds is not None and embed_count < image_embeds.shape[0]: + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + out["llama"] = llama_text_tokens return out def untokenize(self, token_weight_pair): @@ -83,20 +92,45 @@ class HunyuanVideoClipModel(torch.nn.Module): llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama) template_end = 0 - for i, v in enumerate(token_weight_pairs_llama[0]): - if v[0] == 128007: # <|end_header_id|> - template_end = i + image_start = None + image_end = None + extra_sizes = 0 + user_end = 9999999999999 + + tok_pairs = token_weight_pairs_llama[0] + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 128006: + if tok_pairs[i + 1][0] == 882: + if tok_pairs[i + 2][0] == 128007: + template_end = i + 2 + user_end = -1 + if elem == 128009 and user_end == -1: + user_end = i + 1 + else: + if elem.get("original_type") == "image": + elem_size = elem.get("data").shape[0] + if image_start is None: + image_start = i + extra_sizes + image_end = i + elem_size + extra_sizes + extra_sizes += elem_size - 1 if llama_out.shape[1] > (template_end + 2): - if token_weight_pairs_llama[0][template_end + 1][0] == 271: + if tok_pairs[template_end + 1][0] == 271: template_end += 2 - llama_out = llama_out[:, template_end:] - llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:] + llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes] + llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes] if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]): llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements + if image_start is not None: + image_output = llama_out[:, image_start: image_end] + llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1) + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) - return llama_out, l_pooled, llama_extra_out + return llama_output, l_pooled, llama_extra_out def load_sd(self, sd): if "text_model.encoder.layers.1.mlp.fc1.weight" in sd: diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index d6408269f..4f700bbe6 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -1,4 +1,5 @@ import nodes +import node_helpers import torch import comfy.model_management @@ -38,7 +39,73 @@ class EmptyHunyuanLatentVideo: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) return ({"samples":latent}, ) +PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + +class TextEncodeHunyuanVideo_ImageToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, clip_vision_output, prompt): + tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected) + return (clip.encode_from_tokens_scheduled(tokens), ) + + +class HunyuanImageToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"start_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, vae, width, height, length, batch_size, start_image=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + concat_latent_image = vae.encode(start_image) + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, out_latent) + + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, + "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, + "HunyuanImageToVideo": HunyuanImageToVideo, }