Add a T5TokenizerOptions node to set options for the T5 tokenizer. (#7803)

comfyanonymous 2025-04-25 16:36:00 -07:00 committed by GitHub
parent 78992c4b25
commit 23e39f2ba7
9 changed files with 60 additions and 22 deletions

View File

@@ -120,6 +120,7 @@ class CLIP:
         self.layer_idx = None
         self.use_clip_schedule = False
         logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
+        self.tokenizer_options = {}

     def clone(self):
         n = CLIP(no_init=True)
@@ -127,6 +128,7 @@ class CLIP:
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         n.layer_idx = self.layer_idx
+        n.tokenizer_options = self.tokenizer_options.copy()
         n.use_clip_schedule = self.use_clip_schedule
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
         return n
@@ -134,10 +136,18 @@ class CLIP:
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)

+    def set_tokenizer_option(self, option_name, value):
+        self.tokenizer_options[option_name] = value
+
     def clip_layer(self, layer_idx):
         self.layer_idx = layer_idx

     def tokenize(self, text, return_word_ids=False, **kwargs):
+        tokenizer_options = kwargs.get("tokenizer_options", {})
+        if len(self.tokenizer_options) > 0:
+            tokenizer_options = {**self.tokenizer_options, **tokenizer_options}
+        if len(tokenizer_options) > 0:
+            kwargs["tokenizer_options"] = tokenizer_options
         return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)

     def add_hooks_to_dict(self, pooled_dict: dict[str]):
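Note: the merge in CLIP.tokenize() gives per-call options precedence over options stored on the CLIP object via set_tokenizer_option(). A minimal sketch of that merge behavior, using made-up option values:

# Sketch only (not ComfyUI code): per-call options override stored defaults.
stored_options = {"t5xxl_min_length": 256}   # as set by clip.set_tokenizer_option(...)
call_options = {"t5xxl_min_length": 128}     # as passed via tokenize(..., tokenizer_options={...})

merged = {**stored_options, **call_options}
assert merged["t5xxl_min_length"] == 128     # the per-call value wins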

View File

@@ -457,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None
     return embed_out

 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
         self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
         self.min_length = min_length
         self.end_token = None
+        self.min_padding = min_padding

         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
@@ -518,13 +519,15 @@ class SDTokenizer:
         return (embed, leftover)

-    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+    def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
         '''
         Takes a prompt and converts it to a list of (token, weight, word id) elements.
         Tokens can both be integer tokens and pre computed CLIP tensors.
         Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
         Returned list has the dimensions NxM where M is the input size of CLIP
         '''
+        min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
+        min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)

         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
@@ -603,10 +606,12 @@ class SDTokenizer:
         #fill last batch
         if self.end_token is not None:
             batch.append((self.end_token, 1.0, 0))
-        if self.pad_to_max_length:
+        if min_padding is not None:
+            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+        if self.pad_to_max_length and len(batch) < self.max_length:
             batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
-        if self.min_length is not None and len(batch) < self.min_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
+        if min_length is not None and len(batch) < min_length:
+            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))

         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@@ -634,7 +639,7 @@ class SD1Tokenizer:
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
+        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):
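Note: the last-batch padding now runs in three stages: min_padding (when set) appends that many pad tokens first, pad_to_max_length then tops the batch up to max_length, and min_length finally enforces a lower bound. A standalone sketch of that order with assumed values (pad_token=0, 10 real tokens; not ComfyUI defaults):

# Sketch of the padding order above; values are illustrative only.
def pad_batch(batch, pad_token=0, min_padding=None, pad_to_max_length=False,
              max_length=77, min_length=None):
    if min_padding is not None:
        batch.extend([(pad_token, 1.0, 0)] * min_padding)
    if pad_to_max_length and len(batch) < max_length:
        batch.extend([(pad_token, 1.0, 0)] * (max_length - len(batch)))
    if min_length is not None and len(batch) < min_length:
        batch.extend([(pad_token, 1.0, 0)] * (min_length - len(batch)))
    return batch

print(len(pad_batch([(1, 1.0, 0)] * 10, min_padding=4, min_length=20)))  # 20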

View File

@@ -28,8 +28,8 @@ class SDXLTokenizer:
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):

View File

@@ -19,8 +19,8 @@ class FluxTokenizer:
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):

View File

@@ -16,11 +16,11 @@ class HiDreamTokenizer:
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
-        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
+        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):

View File

@@ -49,13 +49,13 @@ class HunyuanVideoTokenizer:
     def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
         out = {}
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)

         if llama_template is None:
             llama_text = self.llama_template.format(text)
         else:
             llama_text = llama_template.format(text)
-        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
+        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
         embed_count = 0
         for r in llama_text_tokens:
             for i in range(len(r)):

View File

@@ -41,8 +41,8 @@ class HyditTokenizer:
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
-        out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
+        out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):

View File

@@ -45,9 +45,9 @@ class SD3Tokenizer:
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):
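Note: every multi-encoder tokenizer above gets the same change: tokenize_with_weights() now forwards **kwargs, and with it tokenizer_options, to each per-encoder tokenizer, which only reads keys namespaced by its own embedding_key (e.g. "t5xxl_min_length"). A reduced sketch of that dispatch with simplified stand-in classes, not the actual ComfyUI classes:

# Simplified stand-ins to show how namespaced options reach only the matching encoder.
class SubTokenizer:
    def __init__(self, embedding_key, default_min_length=None):
        self.embedding_key = embedding_key
        self.min_length = default_min_length

    def tokenize_with_weights(self, text, return_word_ids=False, tokenizer_options={}, **kwargs):
        # Each sub-tokenizer only looks up options prefixed with its own key.
        min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
        return (self.embedding_key, min_length)

class MultiTokenizer:
    def __init__(self):
        self.clip_l = SubTokenizer("clip_l")
        self.t5xxl = SubTokenizer("t5xxl")

    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
        out = {}
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

print(MultiTokenizer().tokenize_with_weights("a photo", tokenizer_options={"t5xxl_min_length": 256}))
# -> {'l': ('clip_l', None), 't5xxl': ('t5xxl', 256)}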

View File

@@ -20,6 +20,29 @@ class CLIPTextEncodeControlnet:
             c.append(n)
         return (c, )

+class T5TokenizerOptions:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "clip": ("CLIP", ),
+                "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+                "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "set_options"
+
+    def set_options(self, clip, min_padding, min_length):
+        clip = clip.clone()
+        for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
+            clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
+            clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
+        return (clip, )
+
 NODE_CLASS_MAPPINGS = {
-    "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
+    "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
+    "T5TokenizerOptions": T5TokenizerOptions,
 }
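Note: the node clones the incoming CLIP and stores min_padding/min_length under every common T5 embedding key, so the options apply to whichever T5 variant the loaded text encoder actually uses. The same effect can be approximated directly against the CLIP API added above; a usage sketch (the clip object and the option values are illustrative only):

# Illustrative only: assumes `clip` is a comfy CLIP object obtained from a loader.
clip = clip.clone()
clip.set_tokenizer_option("t5xxl_min_padding", 0)
clip.set_tokenizer_option("t5xxl_min_length", 256)

# Later tokenize calls pick the stored options up automatically:
tokens = clip.tokenize("a photo of a cat")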