From 5a9ddf94eb302f4a7384f7e50eee29355de3dd4f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 29 Jun 2023 23:40:02 -0400 Subject: [PATCH] LoraLoader node now caches the lora file between executions. --- comfy/sd.py | 7 +++---- nodes.py | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 542f704a..8eac1f8e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -89,8 +89,7 @@ LORA_UNET_MAP_RESNET = { "skip_connection": "resnets_{}_conv_shortcut" } -def load_lora(path, to_load): - lora = utils.load_torch_file(path, safe_load=True) +def load_lora(lora, to_load): patch_dict = {} loaded_keys = set() for x in to_load: @@ -501,10 +500,10 @@ class ModelPatcher: self.backup = {} -def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip): +def load_lora_for_models(model, clip, lora, strength_model, strength_clip): key_map = model_lora_keys(model.model) key_map = model_lora_keys(clip.cond_stage_model, key_map) - loaded = load_lora(lora_path, key_map) + loaded = load_lora(lora, key_map) new_modelpatcher = model.clone() k = new_modelpatcher.add_patches(loaded, strength_model) new_clip = clip.clone() diff --git a/nodes.py b/nodes.py index f10515f8..a9f2e962 100644 --- a/nodes.py +++ b/nodes.py @@ -434,6 +434,9 @@ class CLIPSetLastLayer: return (clip,) class LoraLoader: + def __init__(self): + self.loaded_lora = None + @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), @@ -452,7 +455,18 @@ class LoraLoader: return (model, clip) lora_path = folder_paths.get_full_path("loras", lora_name) - model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + del self.loaded_lora + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) return (model_lora, clip_lora) class VAELoader: