Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-25 15:55:18 +00:00
Support different types of tokenizers.
Support tokenizers without an eos token. Pass full sentences to tokenizer for more efficient tokenizing.
parent a220d11e6b
commit 1c8d11e48a
@@ -90,8 +90,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
 
-        with open(textmodel_json_config) as f:
-            config = json.load(f)
+        if isinstance(textmodel_json_config, dict):
+            config = textmodel_json_config
+        else:
+            with open(textmodel_json_config) as f:
+                config = json.load(f)
 
         operations = model_options.get("custom_operations", None)
         scaled_fp8 = None
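Illustrative sketch (not part of the diff): the constructor now accepts either a path to a JSON config or an already-loaded dict, which is what lets callers hand in a custom text-encoder config for other tokenizer/encoder types. The helper name below is hypothetical.

import json

def resolve_text_config(textmodel_json_config):
    # hypothetical standalone helper mirroring the branch added above
    if isinstance(textmodel_json_config, dict):
        return textmodel_json_config          # config passed in directly as a dict
    with open(textmodel_json_config) as f:    # previous behaviour: path to a JSON file
        return json.load(f)

config_a = resolve_text_config({"hidden_size": 768, "num_hidden_layers": 12})
config_b = resolve_text_config("sd1_clip_config.json")  # requires the file to exist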
@@ -411,22 +414,25 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
         self.max_length = max_length
         self.min_length = min_length
+        self.end_token = None
 
         empty = self.tokenizer('')["input_ids"]
         if has_start_token:
             self.tokens_start = 1
             self.start_token = empty[0]
-            self.end_token = empty[1]
+            if has_end_token:
+                self.end_token = empty[1]
         else:
             self.tokens_start = 0
             self.start_token = None
-            self.end_token = empty[0]
+            if has_end_token:
+                self.end_token = empty[0]
 
         if pad_token is not None:
             self.pad_token = pad_token
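Rough sketch of the new special-token resolution, pulled out of the class: with has_end_token=False the end token simply stays None, so tokenizers that emit no EOS id for an empty string are handled. The function below is illustrative, not ComfyUI API, and the ids are examples.

def resolve_special_tokens(empty_ids, has_start_token=True, has_end_token=True):
    # empty_ids is what tokenizer('')["input_ids"] returned
    start_token = None
    end_token = None
    tokens_start = 0
    if has_start_token:
        tokens_start = 1
        start_token = empty_ids[0]
        if has_end_token:
            end_token = empty_ids[1]
    else:
        if has_end_token:
            end_token = empty_ids[0]
    return tokens_start, start_token, end_token

print(resolve_special_tokens([49406, 49407]))             # CLIP-style BOS/EOS: (1, 49406, 49407)
print(resolve_special_tokens([1], has_end_token=False))   # BOS only, no EOS:   (1, 1, None)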
@@ -451,13 +457,16 @@ class SDTokenizer:
         Takes a potential embedding name and tries to retrieve it.
         Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
         '''
+        split_embed = embedding_name.split(' ')
+        embedding_name = split_embed[0]
+        leftover = ' '.join(split_embed[1:])
         embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
         if embed is None:
             stripped = embedding_name.strip(',')
             if len(stripped) < len(embedding_name):
                 embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
-                return (embed, embedding_name[len(stripped):])
-        return (embed, "")
+                return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
+        return (embed, leftover)
 
 
     def tokenize_with_weights(self, text:str, return_word_ids=False):
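Because whole segments now reach _try_get_embedding, the embedding name is split at the first space and the remainder is carried along as leftover text instead of being dropped. A minimal sketch of that splitting (hypothetical helper, example values):

def split_embedding_request(embedding_name):
    # mirrors the first three added lines of the diff above
    split_embed = embedding_name.split(' ')
    return split_embed[0], ' '.join(split_embed[1:])

print(split_embedding_request("my_style a photo of a cat"))
# -> ('my_style', 'a photo of a cat')  the leftover keeps being tokenized
print(split_embedding_request("my_style"))
# -> ('my_style', '')  no leftover, same result as before the change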
@@ -474,7 +483,12 @@ class SDTokenizer:
         #tokenize words
         tokens = []
         for weighted_segment, weight in parsed_weights:
-            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
+            to_tokenize = unescape_important(weighted_segment).replace("\n", " ")
+            split = to_tokenize.split(' {}'.format(self.embedding_identifier))
+            to_tokenize = [split[0]]
+            for i in range(1, len(split)):
+                to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
+
             to_tokenize = [x for x in to_tokenize if x != ""]
             for word in to_tokenize:
                 #if we find an embedding, deal with the embedding
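This hunk is the "pass full sentences to tokenizer" part of the commit: instead of splitting a weighted segment into single words, the segment is only split where the embedding identifier appears, so plain text is tokenized in one call. A hedged sketch assuming the default identifier "embedding:":

def split_on_embeddings(segment, embedding_identifier="embedding:"):
    # illustrative standalone version of the splitting added above
    split = segment.split(' {}'.format(embedding_identifier))
    to_tokenize = [split[0]]
    for i in range(1, len(split)):
        to_tokenize.append("{}{}".format(embedding_identifier, split[i]))
    return [x for x in to_tokenize if x != ""]

print(split_on_embeddings("a painting of a fox embedding:style1 at night"))
# -> ['a painting of a fox', 'embedding:style1 at night']
print(split_on_embeddings("just plain text"))
# -> ['just plain text']  the whole sentence reaches the tokenizer at once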
@@ -493,8 +507,11 @@ class SDTokenizer:
                         word = leftover
                     else:
                         continue
+                end = 999999999999
+                if self.end_token is not None:
+                    end = -1
                 #parse word
-                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
+                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
 
         #reshape token array to CLIP input size
         batched_tokens = []
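The slice end toggles between -1 (drop the trailing EOS id) and a very large number (keep everything) depending on whether the tokenizer has an end token. Illustrative ids only:

ids_with_eos = [49406, 320, 1125, 49407]   # e.g. BOS ... EOS from a CLIP tokenizer
ids_no_eos = [1, 320, 1125]                # a tokenizer that emits no EOS

end = -1             # self.end_token is not None
print(ids_with_eos[1:end])   # [320, 1125]  both BOS and EOS stripped
end = 999999999999   # self.end_token is None
print(ids_no_eos[1:end])     # [320, 1125]  only the start token stripped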
@@ -512,11 +529,13 @@ class SDTokenizer:
                     #break word in two and add end token
                     if is_large:
                         batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
-                        batch.append((self.end_token, 1.0, 0))
+                        if self.end_token is not None:
+                            batch.append((self.end_token, 1.0, 0))
                         t_group = t_group[remaining_length:]
                     #add end token and pad
                     else:
-                        batch.append((self.end_token, 1.0, 0))
+                        if self.end_token is not None:
+                            batch.append((self.end_token, 1.0, 0))
                         if self.pad_to_max_length:
                             batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
                         #start new batch
@@ -529,7 +548,8 @@ class SDTokenizer:
                     t_group = []
 
         #fill last batch
-        batch.append((self.end_token, 1.0, 0))
+        if self.end_token is not None:
+            batch.append((self.end_token, 1.0, 0))
         if self.pad_to_max_length:
             batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
         if self.min_length is not None and len(batch) < self.min_length:
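Sketch of the batch-finalisation behaviour after this change (standalone function, not the real method): the EOS append is skipped entirely when there is no end token, and padding, if enabled, fills with pad_token up to max_length. Names and ids are illustrative.

def finish_batch(batch, end_token, pad_token, max_length, pad_to_max_length=True):
    # mirrors the guarded end-token append and padding in the last two hunks
    if end_token is not None:
        batch.append((end_token, 1.0, 0))
    if pad_to_max_length:
        batch.extend([(pad_token, 1.0, 0)] * (max_length - len(batch)))
    return batch

print(finish_batch([(320, 1.0, 1)], end_token=49407, pad_token=49407, max_length=4))
# -> [(320, 1.0, 1), (49407, 1.0, 0), (49407, 1.0, 0), (49407, 1.0, 0)]
print(finish_batch([(320, 1.0, 1)], end_token=None, pad_token=0, max_length=4))
# -> [(320, 1.0, 1), (0, 1.0, 0), (0, 1.0, 0), (0, 1.0, 0)]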