From d6656b0c0c849895109d8dd83ba4ac6282b13957 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 17 Dec 2024 04:19:22 -0500
Subject: [PATCH] Support llama hunyuan video text encoder in scaled fp8 format.

---
 comfy/sd.py                          | 10 +++++++++-
 comfy/supported_models.py            |  6 +++---
 comfy/text_encoders/hunyuan_video.py | 13 +++++++++++++
 3 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index 16b0acc8..3b29eecb 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -603,6 +603,14 @@ def t5xxl_detect(clip_data):
 
     return {}
 
+def llama_detect(clip_data):
+    weight_name = "model.layers.0.self_attn.k_proj.weight"
+
+    for sd in clip_data:
+        if weight_name in sd:
+            return comfy.text_encoders.hunyuan_video.llama_detect(sd)
+
+    return {}
 
 def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = state_dicts
@@ -669,7 +677,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
         elif clip_type == CLIPType.HUNYUAN_VIDEO:
-            clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip() #TODO
+            clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index ed3af9d1..68e2b13f 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -783,9 +783,9 @@ class HunyuanVideo(supported_models_base.BASE):
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def clip_target(self, state_dict={}):
-        # pref = self.text_encoder_key_prefix[0]
-        # t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
-        return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip()) #TODO
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
 
 
 models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py
index 2bbf9f6c..3b68f5ed 100644
--- a/comfy/text_encoders/hunyuan_video.py
+++ b/comfy/text_encoders/hunyuan_video.py
@@ -6,6 +6,19 @@ import torch
 import os
 
 
+def llama_detect(state_dict, prefix=""):
+    out = {}
+    t5_key = "{}model.norm.weight".format(prefix)
+    if t5_key in state_dict:
+        out["dtype_llama"] = state_dict[t5_key].dtype
+
+    scaled_fp8_key = "{}scaled_fp8".format(prefix)
+    if scaled_fp8_key in state_dict:
+        out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
+
+    return out
+
+
 class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
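
Note for reviewers: the sketch below shows the detection logic this patch adds, run end to end against a toy state dict rather than a real HunyuanVideo llama checkpoint. The toy tensor shapes and dtypes are illustrative assumptions; the returned kwargs (`dtype_llama`, `llama_scaled_fp8`) are what gets forwarded into `hunyuan_video_clip(**...)`.

```python
import torch

def llama_detect(state_dict, prefix=""):
    # Same detection logic as the function added to
    # comfy/text_encoders/hunyuan_video.py by this patch.
    out = {}
    norm_key = "{}model.norm.weight".format(prefix)
    if norm_key in state_dict:
        # The final norm weight's dtype indicates what precision the
        # llama weights were saved in.
        out["dtype_llama"] = state_dict[norm_key].dtype

    scaled_fp8_key = "{}scaled_fp8".format(prefix)
    if scaled_fp8_key in state_dict:
        # Scaled fp8 checkpoints carry an extra "scaled_fp8" marker tensor;
        # its dtype identifies the fp8 format of the weights.
        out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype

    return out

# Toy state dict standing in for a scaled fp8 llama checkpoint.
# Requires a PyTorch build with float8 dtypes (>= 2.1).
toy_sd = {
    "model.norm.weight": torch.zeros(8, dtype=torch.float16),
    "scaled_fp8": torch.zeros(1, dtype=torch.float8_e4m3fn),
}

print(llama_detect(toy_sd))
# {'dtype_llama': torch.float16, 'llama_scaled_fp8': torch.float8_e4m3fn}
```

A plain (non-scaled-fp8) checkpoint simply lacks the `scaled_fp8` key, so `llama_detect` returns only `dtype_llama` and loading behaves as before; the empty-dict fallback in `comfy/sd.py` likewise keeps `hunyuan_video_clip()` working when no llama weights are present.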