diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 3dd8262a..45bc9526 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -226,12 +226,11 @@ class SD1Tokenizer: self.max_word_length = 8 self.embedding_identifier = "embedding:" - def _try_get_embedding(self, name:str): + def _try_get_embedding(self, embedding_name:str): ''' Takes a potential embedding name and tries to retrieve it. Returns a Tuple consisting of the embedding and any leftover string, embedding can be None. ''' - embedding_name = name[len(self.embedding_identifier):].strip('\n') embed = load_embed(embedding_name, self.embedding_directory) if embed is None: stripped = embedding_name.strip(',') @@ -259,9 +258,10 @@ class SD1Tokenizer: for word in to_tokenize: #if we find an embedding, deal with the embedding if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: - embed, leftover = self._try_get_embedding(word) + embedding_name = word[len(self.embedding_identifier):].strip('\n') + embed, leftover = self._try_get_embedding(embedding_name) if embed is None: - print(f"warning, embedding:{word} does not exist, ignoring") + print(f"warning, embedding:{embedding_name} does not exist, ignoring") else: if len(embed.shape) == 1: tokens.append([(embed, weight)]) @@ -280,21 +280,21 @@ class SD1Tokenizer: batch = [] batched_tokens.append(batch) for i, t_group in enumerate(tokens): - #start a new batch if there is not enough room - if len(t_group) + len(batch) > self.max_tokens_per_section: - remaining_length = self.max_tokens_per_section - len(batch) - #fill remaining space depending on length of tokens - if len(t_group) > self.max_word_length: - #put part of group of tokens in the batch - batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) - t_group = t_group[remaining_length:] + #determine if we're going to try and keep the tokens in a single batch + is_large = len(t_group) >= self.max_word_length + while len(t_group) > 0: + if len(t_group) + len(batch) > self.max_tokens_per_section: + remaining_length = self.max_tokens_per_section - len(batch) + if is_large: + batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) + t_group = t_group[remaining_length:] + else: + batch.extend([(self.end_token, 1.0, 0)] * remaining_length) + batch = [] + batched_tokens.append(batch) else: - #filler tokens - batch.extend([(self.end_token, 1.0, 0)] * remaining_length) - batch = [] - batched_tokens.append(batch) - #put current group of tokens in the batch - batch.extend([(t,w,i+1) for t,w in t_group]) + batch.extend([(t,w,i+1) for t,w in t_group]) + t_group = [] #fill last batch batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))