From ca457f7ba1ff33bd0d724cd7324131509efcf3ec Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 17 Dec 2024 16:18:35 -0500
Subject: [PATCH] Properly tokenize the template for hunyuan video.

---
 comfy/sd1_clip.py                    | 22 ++++++++++++----------
 comfy/text_encoders/hunyuan_video.py |  9 +++++----
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index d7a7ccc3..c0fe1ba5 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -10,6 +10,7 @@ import comfy.clip_model
 import json
 import logging
 import numbers
+import re
 
 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -429,13 +430,14 @@ class SDTokenizer:
             self.end_token = None
 
         empty = self.tokenizer('')["input_ids"]
+        self.tokenizer_adds_end_token = has_end_token
         if has_start_token:
             self.tokens_start = 1
             self.start_token = empty[0]
-            if has_end_token:
-                if end_token is not None:
-                    self.end_token = end_token
-                else:
+            if end_token is not None:
+                self.end_token = end_token
+            else:
+                if has_end_token:
                     self.end_token = empty[1]
         else:
             self.tokens_start = 0
@@ -468,7 +470,7 @@ class SDTokenizer:
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        split_embed = embedding_name.split(' ')
+        split_embed = embedding_name.split()
         embedding_name = split_embed[0]
         leftover = ' '.join(split_embed[1:])
         embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
@@ -491,18 +493,18 @@ class SDTokenizer:
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
-        #tokenize words
+        # tokenize words
         tokens = []
         for weighted_segment, weight in parsed_weights:
-            to_tokenize = unescape_important(weighted_segment).replace("\n", " ")
-            split = to_tokenize.split(' {}'.format(self.embedding_identifier))
+            to_tokenize = unescape_important(weighted_segment)
+            split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
             to_tokenize = [split[0]]
             for i in range(1, len(split)):
                 to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
 
             to_tokenize = [x for x in to_tokenize if x != ""]
             for word in to_tokenize:
-                #if we find an embedding, deal with the embedding
+                # if we find an embedding, deal with the embedding
                 if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                     embedding_name = word[len(self.embedding_identifier):].strip('\n')
                     embed, leftover = self._try_get_embedding(embedding_name)
@@ -519,7 +521,7 @@ class SDTokenizer:
                     else:
                         continue
                 end = 999999999999
-                if self.end_token is not None:
+                if self.tokenizer_adds_end_token:
                     end = -1
                 #parse word
                 tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py
index 3b68f5ed..7149d687 100644
--- a/comfy/text_encoders/hunyuan_video.py
+++ b/comfy/text_encoders/hunyuan_video.py
@@ -22,7 +22,7 @@ def llama_detect(state_dict, prefix=""):
 class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
 
 class LLAMAModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
@@ -38,9 +38,7 @@ class HunyuanVideoTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
         self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
-        self.llama_template = """<|start_header_id|>system<|end_header_id|>
-
-Describe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>""" # 93 tokens
+        self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
         self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
 
     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -86,6 +84,9 @@ class HunyuanVideoClipModel(torch.nn.Module):
             if v[0] == 128007: # <|end_header_id|>
                 template_end = i
 
+        if llama_out.shape[1] > (template_end + 2):
+            if token_weight_pairs_llama[0][template_end + 1][0] == 271:
+                template_end += 2
         llama_out = llama_out[:, template_end:]
         llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
         if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
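
A note on the sd1_clip.py hunks: tokenize_with_weights() used to flatten every newline to a space before splitting on ' embedding:', which also rewrote the template's literal "\n\n" before it reached the tokenizer. The re.split() form keeps newlines intact while still recognizing an embedding reference after either a space or a newline. A minimal standalone sketch of the new splitting step (the prompt string and embedding names below are illustrative, not from this patch):

import re

embedding_identifier = "embedding:"  # SDTokenizer's default identifier

def split_out_embeddings(text):
    # Mirror of the patched logic: split where the identifier follows a
    # space or a newline, then re-attach the identifier to each chunk so
    # downstream code can still detect the embedding references.
    split = re.split(' {0}|\n{0}'.format(embedding_identifier), text)
    to_tokenize = [split[0]]
    for i in range(1, len(split)):
        to_tokenize.append("{}{}".format(embedding_identifier, split[i]))
    return [x for x in to_tokenize if x != ""]

print(split_out_embeddings("a photo of embedding:foo\nembedding:bar, night"))
# -> ['a photo of', 'embedding:foo', 'embedding:bar, night']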
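
The new tokenizer_adds_end_token flag untangles two things the old code conflated: whether an end token id exists at all (self.end_token, still used for padding and termination) and whether the wrapped tokenizer appends one on every call, which is what decides if the [tokens_start:end] slice must drop a trailing id. LlamaTokenizerFast does not append an end token, so under the old "if self.end_token is not None" test the LLAMA3Tokenizer (end_token=128009) silently dropped the last real token of every chunk. A rough sketch of the corrected slicing decision; the non-special token ids are made up:

def word_ids(input_ids, tokens_start, tokenizer_adds_end_token):
    # Drop the start token and, only when the wrapped tokenizer appended
    # an end token itself, drop that trailing end token too.
    end = -1 if tokenizer_adds_end_token else 999999999999
    return input_ids[tokens_start:end]

# CLIP-style tokenizer: emits both start (49406) and end (49407) tokens.
print(word_ids([49406, 320, 1125, 49407], 1, True))   # [320, 1125]

# Llama 3 after this patch: <|begin_of_text|> (128000) is prepended but no
# end token is emitted, so nothing may be cut from the tail.
print(word_ids([128000, 264, 6685], 1, False))        # [264, 6685]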
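
On the hunyuan_video.py side, the template string now keeps its trailing "\n\n" and tokenizes to 95 tokens instead of 93, and forward() compensates when stripping the template prefix: token 271, which the patch treats as the Llama 3 id for "\n\n", may now follow the final <|end_header_id|> (128007). A simplified sketch of the trimming logic, assuming token_weight_pairs is the list of (token, weight) pairs for one sequence and that its length matches llama_out.shape[1], which stands behind the patch's bounds check:

def template_cutoff(token_weight_pairs):
    # The last <|end_header_id|> (128007) marks the end of the chat header.
    template_end = 0
    for i, (token, _weight) in enumerate(token_weight_pairs):
        if token == 128007:  # <|end_header_id|>
            template_end = i
    # A following "\n\n" token (271) still belongs to the template, so the
    # cutoff moves past both it and the header token.
    if len(token_weight_pairs) > template_end + 2:
        if token_weight_pairs[template_end + 1][0] == 271:
            template_end += 2
    return template_end

# <|begin_of_text|>, <|start_header_id|>, "user", <|end_header_id|>, "\n\n",
# then a first prompt token (1234 is illustrative).
pairs = [(128000, 1.0), (128006, 1.0), (882, 1.0), (128007, 1.0), (271, 1.0), (1234, 1.0)]
print(template_cutoff(pairs))  # -> 5, so llama_out[:, 5:] starts at the prompt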