Support multiple paths for embeddings.

This commit is contained in:
comfyanonymous 2023-03-18 03:08:43 -04:00
parent 51d6427ddf
commit 50099bcd96
3 changed files with 26 additions and 18 deletions

View File

@ -168,18 +168,27 @@ def unescape_important(text):
return text return text
def load_embed(embedding_name, embedding_directory): def load_embed(embedding_name, embedding_directory):
embed_path = os.path.join(embedding_directory, embedding_name) if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory]
valid_file = None
for embed_dir in embedding_directory:
embed_path = os.path.join(embed_dir, embedding_name)
if not os.path.isfile(embed_path): if not os.path.isfile(embed_path):
extensions = ['.safetensors', '.pt', '.bin'] extensions = ['.safetensors', '.pt', '.bin']
valid_file = None
for x in extensions: for x in extensions:
t = embed_path + x t = embed_path + x
if os.path.isfile(t): if os.path.isfile(t):
valid_file = t valid_file = t
break break
else:
valid_file = embed_path
if valid_file is not None:
break
if valid_file is None: if valid_file is None:
return None return None
else:
embed_path = valid_file embed_path = valid_file
if embed_path.lower().endswith(".safetensors"): if embed_path.lower().endswith(".safetensors"):

View File

@ -22,7 +22,7 @@ folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions) folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions) folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
# folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
@ -33,6 +33,8 @@ def add_model_folder_path(folder_name, full_folder_path):
if folder_name in folder_names_and_paths: if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path) folder_names_and_paths[folder_name][0].append(full_folder_path)
def get_folder_paths(folder_name):
return folder_names_and_paths[folder_name][0][:]
def recursive_search(directory): def recursive_search(directory):
result = [] result = []

View File

@ -188,9 +188,6 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, ) return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
class CheckpointLoader: class CheckpointLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
embedding_directory = os.path.join(models_dir, "embeddings")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ), return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
@ -203,7 +200,7 @@ class CheckpointLoader:
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = folder_paths.get_full_path("configs", config_name) config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory) return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple: class CheckpointLoaderSimple:
@classmethod @classmethod
@ -217,7 +214,7 @@ class CheckpointLoaderSimple:
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out return out
class CLIPSetLastLayer: class CLIPSetLastLayer: