Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-06-06 11:32:09 +08:00)
Add a T5TokenizerOptions node to set options for the T5 tokenizer. (#7803)
parent 78992c4b25
commit 23e39f2ba7
 comfy/sd.py | 10 ++++++++++
@@ -120,6 +120,7 @@ class CLIP:
         self.layer_idx = None
         self.use_clip_schedule = False
         logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
+        self.tokenizer_options = {}

     def clone(self):
         n = CLIP(no_init=True)
@@ -127,6 +128,7 @@ class CLIP:
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         n.layer_idx = self.layer_idx
+        n.tokenizer_options = self.tokenizer_options.copy()
         n.use_clip_schedule = self.use_clip_schedule
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
         return n
@@ -134,10 +136,18 @@ class CLIP:
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)

+    def set_tokenizer_option(self, option_name, value):
+        self.tokenizer_options[option_name] = value
+
     def clip_layer(self, layer_idx):
         self.layer_idx = layer_idx

     def tokenize(self, text, return_word_ids=False, **kwargs):
+        tokenizer_options = kwargs.get("tokenizer_options", {})
+        if len(self.tokenizer_options) > 0:
+            tokenizer_options = {**self.tokenizer_options, **tokenizer_options}
+        if len(tokenizer_options) > 0:
+            kwargs["tokenizer_options"] = tokenizer_options
         return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)

     def add_hooks_to_dict(self, pooled_dict: dict[str]):
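The three hunks above give CLIP a persistent tokenizer_options dict: it survives clone(), can be set per key with set_tokenizer_option(), and is merged into every tokenize() call. Note the merge order, `{**self.tokenizer_options, **tokenizer_options}`: options passed directly to tokenize() win over options stored on the instance. A minimal, self-contained sketch of that precedence (the values are made up):

# Instance-level options are spread first, per-call options second,
# so the per-call value wins on a key collision.
instance_options = {"t5xxl_min_length": 256, "t5xxl_min_padding": 1}
per_call_options = {"t5xxl_min_length": 512}

merged = {**instance_options, **per_call_options}
print(merged)  # {'t5xxl_min_length': 512, 't5xxl_min_padding': 1}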
@@ -457,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
     return embed_out

 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
         self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
         self.min_length = min_length
         self.end_token = None
+        self.min_padding = min_padding

         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
@@ -518,13 +519,15 @@ class SDTokenizer:
         return (embed, leftover)


-    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+    def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
         '''
         Takes a prompt and converts it to a list of (token, weight, word id) elements.
         Tokens can both be integer tokens and pre computed CLIP tensors.
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
+        min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
+        min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
+
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
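Each SDTokenizer resolves its options under its own embedding_key prefix (`t5xxl_min_length`, `clip_l_min_length`, ...), falling back to the value set in __init__. That is what lets one options dict flow through every wrapper tokenizer while only affecting the encoders it names. A small sketch of the lookup, with hypothetical keys and values:

# One shared dict can carry per-encoder settings; each tokenizer only
# reads keys prefixed with its own embedding_key.
tokenizer_options = {"t5xxl_min_length": 256, "clip_l_min_length": 77}

def resolve(embedding_key, option, fallback=None):
    return tokenizer_options.get("{}_{}".format(embedding_key, option), fallback)

print(resolve("t5xxl", "min_length"))       # 256
print(resolve("clip_l", "min_length"))      # 77
print(resolve("clip_g", "min_length", 75))  # 75 (no override, uses fallback)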
@@ -603,10 +606,12 @@ class SDTokenizer:
         #fill last batch
         if self.end_token is not None:
             batch.append((self.end_token, 1.0, 0))
-        if self.pad_to_max_length:
+        if min_padding is not None:
+            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+        if self.pad_to_max_length and len(batch) < self.max_length:
             batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
-        if self.min_length is not None and len(batch) < self.min_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
+        if min_length is not None and len(batch) < min_length:
+            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))

         if not return_word_ids:
            batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
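The batch-filling step now has three independent padding stages, applied in order: min_padding always appends that many pad tokens, pad_to_max_length tops the batch up to max_length (the new `len(batch) < self.max_length` guard makes the no-op explicit when min_padding has already pushed past it), and min_length tops it up further if still short. A standalone re-implementation of just this step, with PAD standing in for self.pad_token and toy numbers:

PAD = 0  # stands in for self.pad_token

def fill_batch(batch, min_padding=None, min_length=None,
               pad_to_max_length=True, max_length=77):
    if min_padding is not None:                              # stage 1: forced extra padding
        batch.extend([(PAD, 1.0, 0)] * min_padding)
    if pad_to_max_length and len(batch) < max_length:        # stage 2: pad up to max_length
        batch.extend([(PAD, 1.0, 0)] * (max_length - len(batch)))
    if min_length is not None and len(batch) < min_length:   # stage 3: pad up to min_length
        batch.extend([(PAD, 1.0, 0)] * (min_length - len(batch)))
    return batch

# T5-style setup: no max-length padding, but a floor of 10 tokens.
out = fill_batch([(101, 1.0, 1)] * 3, min_padding=2,
                 min_length=10, pad_to_max_length=False)
print(len(out))  # 10 (3 real tokens + 2 forced pads + 5 pads to reach min_length)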
@@ -634,7 +639,7 @@ class SD1Tokenizer:

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
+        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):
@@ -28,8 +28,8 @@ class SDXLTokenizer:

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):
@@ -19,8 +19,8 @@ class FluxTokenizer:

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):
@@ -16,11 +16,11 @@ class HiDreamTokenizer:

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
-        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
+        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):
@@ -49,13 +49,13 @@ class HunyuanVideoTokenizer:

     def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
         out = {}
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)

         if llama_template is None:
             llama_text = self.llama_template.format(text)
         else:
             llama_text = llama_template.format(text)
-        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
+        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
         embed_count = 0
         for r in llama_text_tokens:
             for i in range(len(r)):
@@ -41,8 +41,8 @@ class HyditTokenizer:

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
-        out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
+        out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):
@@ -45,9 +45,9 @@ class SD3Tokenizer:

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):
@@ -20,6 +20,29 @@ class CLIPTextEncodeControlnet:
             c.append(n)
         return (c, )

+class T5TokenizerOptions:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "clip": ("CLIP", ),
+                "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+                "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "set_options"
+
+    def set_options(self, clip, min_padding, min_length):
+        clip = clip.clone()
+        for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
+            clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
+            clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
+
+        return (clip, )
+
 NODE_CLASS_MAPPINGS = {
-    "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
+    "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
+    "T5TokenizerOptions": T5TokenizerOptions,
 }
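For reference, a sketch of what the new node does to a CLIP object when invoked; `some_clip` is a placeholder for a CLIP loaded elsewhere (e.g. from a checkpoint) and the option values are arbitrary:

node = T5TokenizerOptions()
(clip,) = node.set_options(some_clip, min_padding=1, min_length=256)

# The returned clone carries the options for every listed T5 variant, so
# later tokenize() calls pad each T5 prompt to at least 256 tokens.
tokens = clip.tokenize("a photo of a cat")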