From 1c8d11e48a822327e7b77ccacdb56bccf3e20072 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 10 Dec 2024 09:44:13 -0500
Subject: [PATCH] Support different types of tokenizers.

Support tokenizers without an eos token.
Pass full sentences to tokenizer for more efficient tokenizing.
---
 comfy/sd1_clip.py | 44 ++++++++++++++++++++++++++++++++------------
 1 file changed, 32 insertions(+), 12 deletions(-)

diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index a454f3bb..887874a0 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -90,8 +90,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")

-        with open(textmodel_json_config) as f:
-            config = json.load(f)
+        if isinstance(textmodel_json_config, dict):
+            config = textmodel_json_config
+        else:
+            with open(textmodel_json_config) as f:
+                config = json.load(f)

         operations = model_options.get("custom_operations", None)
         scaled_fp8 = None
@@ -411,22 +414,25 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out

 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
         self.max_length = max_length
         self.min_length = min_length
+        self.end_token = None

         empty = self.tokenizer('')["input_ids"]
         if has_start_token:
             self.tokens_start = 1
             self.start_token = empty[0]
-            self.end_token = empty[1]
+            if has_end_token:
+                self.end_token = empty[1]
         else:
             self.tokens_start = 0
             self.start_token = None
-            self.end_token = empty[0]
+            if has_end_token:
+                self.end_token = empty[0]

         if pad_token is not None:
             self.pad_token = pad_token
@@ -451,13 +457,16 @@ class SDTokenizer:
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
+        split_embed = embedding_name.split(' ')
+        embedding_name = split_embed[0]
+        leftover = ' '.join(split_embed[1:])
         embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
         if embed is None:
             stripped = embedding_name.strip(',')
             if len(stripped) < len(embedding_name):
                 embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
-                return (embed, embedding_name[len(stripped):])
-        return (embed, "")
+                return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
+        return (embed, leftover)


     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -474,7 +483,12 @@ class SDTokenizer:
         #tokenize words
         tokens = []
         for weighted_segment, weight in parsed_weights:
-            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
+            to_tokenize = unescape_important(weighted_segment).replace("\n", " ")
+            split = to_tokenize.split(' {}'.format(self.embedding_identifier))
+            to_tokenize = [split[0]]
+            for i in range(1, len(split)):
+                to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
+
             to_tokenize = [x for x in to_tokenize if x != ""]
             for word in to_tokenize:
                 #if we find an embedding, deal with the embedding
@@ -493,8 +507,11 @@ class SDTokenizer:
                         word = leftover
                     else:
                         continue
+                end = 999999999999
+                if self.end_token is not None:
+                    end = -1
                 #parse word
-                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
+                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])

         #reshape token array to CLIP input size
         batched_tokens = []
@@ -512,11 +529,13 @@ class SDTokenizer:
                     #break word in two and add end token
                     if is_large:
                         batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
-                        batch.append((self.end_token, 1.0, 0))
+                        if self.end_token is not None:
+                            batch.append((self.end_token, 1.0, 0))
                         t_group = t_group[remaining_length:]
                     #add end token and pad
                     else:
-                        batch.append((self.end_token, 1.0, 0))
+                        if self.end_token is not None:
+                            batch.append((self.end_token, 1.0, 0))
                         if self.pad_to_max_length:
                             batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
                     #start new batch
@@ -529,7 +548,8 @@ class SDTokenizer:
                     t_group = []

         #fill last batch
-        batch.append((self.end_token, 1.0, 0))
+        if self.end_token is not None:
+            batch.append((self.end_token, 1.0, 0))
         if self.pad_to_max_length:
             batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
         if self.min_length is not None and len(batch) < self.min_length:
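
The first hunk lets SDClipModel accept an already-built config dict for textmodel_json_config instead of only a path to a JSON file. The helper below is a minimal standalone sketch of that branch (the function name is illustrative, not part of the patch):

    import json

    def resolve_text_model_config(textmodel_json_config):
        # Mirrors the new branch in SDClipModel.__init__: a dict is used as-is,
        # anything else is treated as a path to a JSON config file.
        if isinstance(textmodel_json_config, dict):
            return textmodel_json_config
        with open(textmodel_json_config) as f:
            return json.load(f)

This lets callers assemble a text-model config programmatically instead of pointing at a file on disk.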
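
The new has_end_token flag leaves self.end_token as None when the backing tokenizer has no EOS token; every batch.append((self.end_token, 1.0, 0)) call is then skipped and the raw ids are sliced with [self.tokens_start:999999999999] instead of [self.tokens_start:-1], so no trailing id is dropped. A sketch of a subclass that could use this, assuming a hypothetical wrapper around a Llama-style tokenizer (the class name, tokenizer path, embedding key/size and pad id below are assumptions, not part of the patch):

    from transformers import LlamaTokenizerFast
    from comfy import sd1_clip

    class NoEOSTokenizer(sd1_clip.SDTokenizer):
        # Hypothetical wrapper for a tokenizer whose vocabulary has no EOS token.
        def __init__(self, embedding_directory=None, tokenizer_data={}):
            super().__init__(
                tokenizer_path="path/to/tokenizer",   # placeholder: local tokenizer files
                max_length=256,
                pad_with_end=False,                   # no end token to pad with
                embedding_directory=embedding_directory,
                embedding_size=4096,
                embedding_key='llama',
                tokenizer_class=LlamaTokenizerFast,
                has_start_token=True,
                has_end_token=False,                  # new flag: keep end_token as None
                pad_to_max_length=False,
                min_length=1,
                pad_token=0,                          # explicit pad id, since end_token is None
                tokenizer_data=tokenizer_data)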
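
The tokenize_with_weights change stops splitting each weighted segment on every space; a segment is now only cut where an embedding reference begins, so whole sentences reach the underlying tokenizer in one call. A standalone sketch of that splitting step, assuming "embedding:" is the default value of self.embedding_identifier:

    EMBEDDING_IDENTIFIER = "embedding:"   # assumed default of self.embedding_identifier

    def split_segment(segment: str):
        # Mirrors the new splitting logic: keep sentences intact and only cut
        # where an " embedding:" reference starts.
        segment = segment.replace("\n", " ")
        parts = segment.split(' {}'.format(EMBEDDING_IDENTIFIER))
        out = [parts[0]]
        for p in parts[1:]:
            out.append("{}{}".format(EMBEDDING_IDENTIFIER, p))
        return [x for x in out if x != ""]

    print(split_segment("a photo of a cat embedding:mystyle sitting on a mat"))
    # ['a photo of a cat', 'embedding:mystyle sitting on a mat']

Any words trailing an embedding name are then peeled off by the new split(' ') in _try_get_embedding and handed back as leftover, so they are tokenized as ordinary text.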