mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add support for textual inversion embedding for SD1.x CLIP.
This commit is contained in:
parent
702ac43d0c
commit
f73e57d881
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,3 +3,4 @@ __pycache__/
|
|||||||
output/
|
output/
|
||||||
models/checkpoints
|
models/checkpoints
|
||||||
models/vae
|
models/vae
|
||||||
|
models/embeddings
|
||||||
|
@ -66,6 +66,10 @@ Dragging a generated png on the webpage or loading one will give you the full wo
|
|||||||
|
|
||||||
You can use () to change emphasis of a word or phrase like: (good code:1.2) or (bad code:0.8). The default emphasis for () is 1.1. To use () characters in your actual prompt escape them like \\( or \\).
|
You can use () to change emphasis of a word or phrase like: (good code:1.2) or (bad code:0.8). The default emphasis for () is 1.1. To use () characters in your actual prompt escape them like \\( or \\).
|
||||||
|
|
||||||
|
To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension):
|
||||||
|
|
||||||
|
```embedding:embedding_filename.pt```
|
||||||
|
|
||||||
### Colab Notebook
|
### Colab Notebook
|
||||||
|
|
||||||
To run it on colab you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb)
|
To run it on colab you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb)
|
||||||
|
22
comfy/sd.py
22
comfy/sd.py
@ -53,19 +53,25 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
|
|||||||
|
|
||||||
|
|
||||||
class CLIP:
|
class CLIP:
|
||||||
def __init__(self, config):
|
def __init__(self, config, embedding_directory=None):
|
||||||
self.target_clip = config["target"]
|
self.target_clip = config["target"]
|
||||||
|
if "params" in config:
|
||||||
|
params = config["params"]
|
||||||
|
else:
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
tokenizer_params = {}
|
||||||
|
|
||||||
if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
|
if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
|
||||||
clip = sd2_clip.SD2ClipModel
|
clip = sd2_clip.SD2ClipModel
|
||||||
tokenizer = sd2_clip.SD2Tokenizer
|
tokenizer = sd2_clip.SD2Tokenizer
|
||||||
elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
|
elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
|
||||||
clip = sd1_clip.SD1ClipModel
|
clip = sd1_clip.SD1ClipModel
|
||||||
tokenizer = sd1_clip.SD1Tokenizer
|
tokenizer = sd1_clip.SD1Tokenizer
|
||||||
if "params" in config:
|
tokenizer_params['embedding_directory'] = embedding_directory
|
||||||
self.cond_stage_model = clip(**(config["params"]))
|
|
||||||
else:
|
self.cond_stage_model = clip(**(params))
|
||||||
self.cond_stage_model = clip()
|
self.tokenizer = tokenizer(**(tokenizer_params))
|
||||||
self.tokenizer = tokenizer()
|
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
tokens = self.tokenizer.tokenize_with_weights(text)
|
tokens = self.tokenizer.tokenize_with_weights(text)
|
||||||
@ -103,7 +109,7 @@ class VAE:
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
|
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
||||||
config = OmegaConf.load(config_path)
|
config = OmegaConf.load(config_path)
|
||||||
model_config_params = config['model']['params']
|
model_config_params = config['model']['params']
|
||||||
clip_config = model_config_params['cond_stage_config']
|
clip_config = model_config_params['cond_stage_config']
|
||||||
@ -124,7 +130,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
|
|||||||
load_state_dict_to = [w]
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
clip = CLIP(config=clip_config)
|
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
||||||
w.cond_stage_model = clip.cond_stage_model
|
w.cond_stage_model = clip.cond_stage_model
|
||||||
load_state_dict_to = [w]
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
|
@ -63,9 +63,38 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
self.layer = "hidden"
|
self.layer = "hidden"
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
|
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||||
|
out_tokens = []
|
||||||
|
next_new_token = token_dict_size = current_embeds.weight.shape[0]
|
||||||
|
embedding_weights = []
|
||||||
|
|
||||||
|
for x in tokens:
|
||||||
|
tokens_temp = []
|
||||||
|
for y in x:
|
||||||
|
if isinstance(y, int):
|
||||||
|
tokens_temp += [y]
|
||||||
|
else:
|
||||||
|
embedding_weights += [y]
|
||||||
|
tokens_temp += [next_new_token]
|
||||||
|
next_new_token += 1
|
||||||
|
out_tokens += [tokens_temp]
|
||||||
|
|
||||||
|
if len(embedding_weights) > 0:
|
||||||
|
new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
|
||||||
|
new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
|
||||||
|
n = token_dict_size
|
||||||
|
for x in embedding_weights:
|
||||||
|
new_embedding.weight[n] = x
|
||||||
|
n += 1
|
||||||
|
self.transformer.set_input_embeddings(new_embedding)
|
||||||
|
return out_tokens
|
||||||
|
|
||||||
def forward(self, tokens):
|
def forward(self, tokens):
|
||||||
|
backup_embeds = self.transformer.get_input_embeddings()
|
||||||
|
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||||
tokens = torch.LongTensor(tokens).to(self.device)
|
tokens = torch.LongTensor(tokens).to(self.device)
|
||||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
||||||
|
self.transformer.set_input_embeddings(backup_embeds)
|
||||||
|
|
||||||
if self.layer == "last":
|
if self.layer == "last":
|
||||||
z = outputs.last_hidden_state
|
z = outputs.last_hidden_state
|
||||||
@ -138,18 +167,49 @@ def unescape_important(text):
|
|||||||
text = text.replace("\0\2", "(")
|
text = text.replace("\0\2", "(")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def load_embed(embedding_name, embedding_directory):
|
||||||
|
embed_path = os.path.join(embedding_directory, embedding_name)
|
||||||
|
if not os.path.isfile(embed_path):
|
||||||
|
extensions = ['.safetensors', '.pt', '.bin']
|
||||||
|
valid_file = None
|
||||||
|
for x in extensions:
|
||||||
|
t = embed_path + x
|
||||||
|
if os.path.isfile(t):
|
||||||
|
valid_file = t
|
||||||
|
break
|
||||||
|
if valid_file is None:
|
||||||
|
print("warning, embedding {} does not exist, ignoring".format(embed_path))
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
embed_path = valid_file
|
||||||
|
|
||||||
|
if embed_path.lower().endswith(".safetensors"):
|
||||||
|
import safetensors.torch
|
||||||
|
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
||||||
|
else:
|
||||||
|
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
||||||
|
if 'string_to_param' in embed:
|
||||||
|
values = embed['string_to_param'].values()
|
||||||
|
else:
|
||||||
|
values = embed.values()
|
||||||
|
return next(iter(values))
|
||||||
|
|
||||||
class SD1Tokenizer:
|
class SD1Tokenizer:
|
||||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True):
|
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
|
||||||
if tokenizer_path is None:
|
if tokenizer_path is None:
|
||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
self.max_tokens_per_section = self.max_length - 2
|
||||||
|
|
||||||
empty = self.tokenizer('')["input_ids"]
|
empty = self.tokenizer('')["input_ids"]
|
||||||
self.start_token = empty[0]
|
self.start_token = empty[0]
|
||||||
self.end_token = empty[1]
|
self.end_token = empty[1]
|
||||||
self.pad_with_end = pad_with_end
|
self.pad_with_end = pad_with_end
|
||||||
vocab = self.tokenizer.get_vocab()
|
vocab = self.tokenizer.get_vocab()
|
||||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||||
|
self.embedding_directory = embedding_directory
|
||||||
|
self.max_word_length = 8
|
||||||
|
|
||||||
def tokenize_with_weights(self, text):
|
def tokenize_with_weights(self, text):
|
||||||
text = escape_important(text)
|
text = escape_important(text)
|
||||||
@ -157,13 +217,34 @@ class SD1Tokenizer:
|
|||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
for t in parsed_weights:
|
for t in parsed_weights:
|
||||||
tt = self.tokenizer(unescape_important(t[0]))["input_ids"][1:-1]
|
to_tokenize = unescape_important(t[0]).split(' ')
|
||||||
|
for word in to_tokenize:
|
||||||
|
temp_tokens = []
|
||||||
|
embedding_identifier = "embedding:"
|
||||||
|
if word.startswith(embedding_identifier) and self.embedding_directory is not None:
|
||||||
|
embedding_name = word[len(embedding_identifier):].strip('\n')
|
||||||
|
embed = load_embed(embedding_name, self.embedding_directory)
|
||||||
|
if embed is not None:
|
||||||
|
if len(embed.shape) == 1:
|
||||||
|
temp_tokens += [(embed, t[1])]
|
||||||
|
else:
|
||||||
|
for x in range(embed.shape[0]):
|
||||||
|
temp_tokens += [(embed[x], t[1])]
|
||||||
|
elif len(word) > 0:
|
||||||
|
tt = self.tokenizer(word)["input_ids"][1:-1]
|
||||||
for x in tt:
|
for x in tt:
|
||||||
tokens += [(x, t[1])]
|
temp_tokens += [(x, t[1])]
|
||||||
|
tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
|
||||||
|
|
||||||
|
#try not to split words in different sections
|
||||||
|
if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
|
||||||
|
for x in range(tokens_left):
|
||||||
|
tokens += [(self.end_token, 1.0)]
|
||||||
|
tokens += temp_tokens
|
||||||
|
|
||||||
out_tokens = []
|
out_tokens = []
|
||||||
for x in range(0, len(tokens), self.max_length - 2):
|
for x in range(0, len(tokens), self.max_tokens_per_section):
|
||||||
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_length - 2 + x, len(tokens))]
|
o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
|
||||||
o_token += [(self.end_token, 1.0)]
|
o_token += [(self.end_token, 1.0)]
|
||||||
if self.pad_with_end:
|
if self.pad_with_end:
|
||||||
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
|
o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
|
||||||
|
3
nodes.py
3
nodes.py
@ -127,7 +127,8 @@ class CheckpointLoader:
|
|||||||
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
|
||||||
config_path = os.path.join(self.config_dir, config_name)
|
config_path = os.path.join(self.config_dir, config_name)
|
||||||
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
||||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True)
|
embedding_directory = os.path.join(self.models_dir, "embeddings")
|
||||||
|
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=embedding_directory)
|
||||||
|
|
||||||
class VAELoader:
|
class VAELoader:
|
||||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||||
|
Loading…
Reference in New Issue
Block a user