Support multiple paths for embeddings.

This commit is contained in:
comfyanonymous 2023-03-18 03:08:43 -04:00
parent 51d6427ddf
commit 50099bcd96
3 changed files with 26 additions and 18 deletions

View File

@ -168,18 +168,27 @@ def unescape_important(text):
return text
def load_embed(embedding_name, embedding_directory):
embed_path = os.path.join(embedding_directory, embedding_name)
if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory]
valid_file = None
for embed_dir in embedding_directory:
embed_path = os.path.join(embed_dir, embedding_name)
if not os.path.isfile(embed_path):
extensions = ['.safetensors', '.pt', '.bin']
valid_file = None
for x in extensions:
t = embed_path + x
if os.path.isfile(t):
valid_file = t
break
else:
valid_file = embed_path
if valid_file is not None:
break
if valid_file is None:
return None
else:
embed_path = valid_file
if embed_path.lower().endswith(".safetensors"):

View File

@ -22,7 +22,7 @@ folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
# folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
@ -33,6 +33,8 @@ def add_model_folder_path(folder_name, full_folder_path):
if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path)
def get_folder_paths(folder_name):
return folder_names_and_paths[folder_name][0][:]
def recursive_search(directory):
result = []

View File

@ -188,9 +188,6 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
class CheckpointLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
embedding_directory = os.path.join(models_dir, "embeddings")
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
@ -203,7 +200,7 @@ class CheckpointLoader:
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple:
@classmethod
@ -217,7 +214,7 @@ class CheckpointLoaderSimple:
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
class CLIPSetLastLayer: