Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-11 02:15:17 +00:00
align behavior with old tokenize function
parent 44fe868b66
commit 752f7a162b
@@ -226,12 +226,11 @@ class SD1Tokenizer:
         self.max_word_length = 8
         self.embedding_identifier = "embedding:"
 
-    def _try_get_embedding(self, name:str):
+    def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
-        embedding_name = name[len(self.embedding_identifier):].strip('\n')
         embed = load_embed(embedding_name, self.embedding_directory)
         if embed is None:
             stripped = embedding_name.strip(',')
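This hunk changes the contract of _try_get_embedding: it now receives the embedding name with the "embedding:" prefix already removed, so stripping the prefix becomes the caller's job. Below is a minimal, self-contained sketch of that contract; FAKE_EMBEDS and try_get_embedding are illustrative stand-ins for load_embed() and the class method, not the real ComfyUI code.

FAKE_EMBEDS = {"myStyle": [0.1, 0.2, 0.3]}  # hypothetical stand-in for embedding files on disk

def try_get_embedding(embedding_name: str):
    """Return (embedding, leftover); embedding is None when nothing matches."""
    embed = FAKE_EMBEDS.get(embedding_name)
    if embed is None:
        # same fallback idea as the real helper: retry with trailing commas stripped
        stripped = embedding_name.strip(',')
        if stripped != embedding_name:
            return FAKE_EMBEDS.get(stripped), embedding_name[len(stripped):]
    return embed, ""

word = "embedding:myStyle"
name = word[len("embedding:"):].strip('\n')  # the caller strips the prefix now
print(try_get_embedding(name))               # ([0.1, 0.2, 0.3], '')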
@@ -259,9 +258,10 @@ class SD1Tokenizer:
             for word in to_tokenize:
                 #if we find an embedding, deal with the embedding
                 if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
-                    embed, leftover = self._try_get_embedding(word)
+                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
+                    embed, leftover = self._try_get_embedding(embedding_name)
                     if embed is None:
-                        print(f"warning, embedding:{word} does not exist, ignoring")
+                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                     else:
                         if len(embed.shape) == 1:
                             tokens.append([(embed, weight)])
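This hunk moves the prefix stripping into the tokenize loop and makes the warning name the embedding itself rather than the full prefixed word (the old message would have repeated the prefix, as in "embedding:embedding:<name>"). The sketch below imitates that caller-side flow under the same assumptions as above; tokenize_word() and the embeds dict are hypothetical placeholders for the CLIP tokenizer call and load_embed().

EMBEDDING_IDENTIFIER = "embedding:"

def tokenize_word(word):
    # hypothetical: pretend each character is one token id
    return [ord(c) for c in word]

def tokenize(words, weight=1.0, embeds=None):
    embeds = embeds or {}
    tokens = []
    for word in words:
        if word.startswith(EMBEDDING_IDENTIFIER):
            # strip the prefix here, as the updated caller does
            embedding_name = word[len(EMBEDDING_IDENTIFIER):].strip('\n')
            embed = embeds.get(embedding_name)
            if embed is None:
                print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                continue
            tokens.append([(embed, weight)])
        else:
            tokens.append([(t, weight) for t in tokenize_word(word)])
    return tokens

print(tokenize(["a", "embedding:missing", "cat"]))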
@@ -280,21 +280,21 @@ class SD1Tokenizer:
         batch = []
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
-            #start a new batch if there is not enough room
-            if len(t_group) + len(batch) > self.max_tokens_per_section:
-                remaining_length = self.max_tokens_per_section - len(batch)
-                #fill remaining space depending on length of tokens
-                if len(t_group) > self.max_word_length:
-                    #put part of group of tokens in the batch
-                    batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
-                    t_group = t_group[remaining_length:]
-                else:
-                    #filler tokens
-                    batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
-                batch = []
-                batched_tokens.append(batch)
-            #put current group of tokens in the batch
-            batch.extend([(t,w,i+1) for t,w in t_group])
+            #determine if we're going to try and keep the tokens in a single batch
+            is_large = len(t_group) >= self.max_word_length
+            while len(t_group) > 0:
+                if len(t_group) + len(batch) > self.max_tokens_per_section:
+                    remaining_length = self.max_tokens_per_section - len(batch)
+                    if is_large:
+                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
+                        t_group = t_group[remaining_length:]
+                    else:
+                        batch.extend([(self.end_token, 1.0, 0)] * remaining_length)
+                    batch = []
+                    batched_tokens.append(batch)
+                else:
+                    batch.extend([(t,w,i+1) for t,w in t_group])
+                    t_group = []
 
         #fill last batch
         batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch)))
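The final hunk rewrites the batching loop: is_large marks groups at least max_word_length tokens long, and only those may be split across sections; shorter groups that do not fit pad the current section with end tokens and move whole into a fresh section, with the while loop repeating until the group is consumed. The standalone sketch below replays that rule with plain integers and made-up constants (MAX_TOKENS_PER_SECTION, MAX_WORD_LENGTH and END_TOKEN are illustrative, not the model's real values).

MAX_TOKENS_PER_SECTION = 5
MAX_WORD_LENGTH = 3
END_TOKEN = 0

def batch_groups(groups):
    batched, batch = [], []
    batched.append(batch)
    for group in groups:
        # a "large" group may be split across sections
        is_large = len(group) >= MAX_WORD_LENGTH
        while len(group) > 0:
            if len(group) + len(batch) > MAX_TOKENS_PER_SECTION:
                remaining = MAX_TOKENS_PER_SECTION - len(batch)
                if is_large:
                    batch.extend(group[:remaining])        # split the long group
                    group = group[remaining:]
                else:
                    batch.extend([END_TOKEN] * remaining)  # pad, retry in a new section
                batch = []
                batched.append(batch)
            else:
                batch.extend(group)
                group = []
    batch.extend([END_TOKEN] * (MAX_TOKENS_PER_SECTION - len(batch)))  # fill last section
    return batched

print(batch_groups([[1, 1], [2, 2, 2, 2], [3, 3]]))
# [[1, 1, 2, 2, 2], [2, 3, 3, 0, 0]]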