# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# (synced 2025-04-13 15:03:33 +00:00; 382 lines, 17 KiB, Python)
"""
|
|
This file is part of ComfyUI.
|
|
Copyright (C) 2024 Comfy
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
import comfy.utils
|
|
import comfy.model_management
|
|
import comfy.model_base
|
|
import comfy.weight_adapter as weight_adapter
|
|
import logging
|
|
import torch
|
|
|
|
# Maps CLIP text-encoder submodule paths (as they appear inside state-dict keys
# such as "clip_l.transformer.text_model.encoder.layers.N.<path>.weight") to the
# underscore-joined names used in "lora_te*_text_model_encoder_layers_N_<name>"
# LoRA keys. Every name is just its path with "." replaced by "_".
LORA_CLIP_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "mlp.fc1",
        "mlp.fc2",
        "self_attn.k_proj",
        "self_attn.q_proj",
        "self_attn.v_proj",
        "self_attn.out_proj",
    )
}
|
|
|
|
|
|
def load_lora(lora, to_load, log_missing=True):
    """Convert a LoRA state dict into a patch dictionary for the model.

    Args:
        lora: Dict-like mapping of LoRA key -> tensor.
        to_load: Mapping of LoRA key prefix -> model weight key. Each prefix
            ``x`` is probed for ``x.alpha``, ``x.dora_scale``, any format
            recognised by ``weight_adapter.adapters``, and the ``w_norm`` /
            ``b_norm`` / ``diff`` / ``diff_b`` / ``set_weight`` suffixes.
        log_missing: If True, log a warning for every key of ``lora`` that was
            not consumed by any of the formats above.

    Returns:
        dict mapping model weight keys (and ``<name>.bias`` keys derived by
        stripping a trailing ``.weight``) to either a weight-adapter object or a
        ``("diff", (tensor,))`` / ``("set", (tensor,))`` tuple.
    """
    def bias_key_for(weight_key):
        # "<name>.weight" -> "<name>.bias"; assumes weight_key ends in ".weight".
        return "{}.bias".format(weight_key[:-len(".weight")])

    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        target_key = to_load[x]

        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        dora_scale_name = "{}.dora_scale".format(x)
        dora_scale = None
        if dora_scale_name in lora.keys():
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

        # The first adapter class that recognises this prefix wins. Stop there
        # so later adapters are neither probed nor able to overwrite the match.
        for adapter_cls in weight_adapter.adapters:
            adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
            if adapter is not None:
                patch_dict[target_key] = adapter
                loaded_keys.update(adapter.loaded_keys)
                break

        # Plain-tensor patch formats. These are checked even when an adapter
        # matched above, and may overwrite its entry for target_key.
        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[target_key] = ("diff", (w_norm,))
            # b_norm is only honoured together with w_norm.
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                patch_dict[bias_key_for(target_key)] = ("diff", (b_norm,))

        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[target_key] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict[bias_key_for(target_key)] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

        set_weight_name = "{}.set_weight".format(x)
        set_weight = lora.get(set_weight_name, None)
        if set_weight is not None:
            patch_dict[target_key] = ("set", (set_weight,))
            loaded_keys.add(set_weight_name)

    if log_missing:
        for x in lora.keys():
            if x not in loaded_keys:
                logging.warning("lora key not loaded: {}".format(x))

    return patch_dict
|
|
|
|
def model_lora_keys_clip(model, key_map=None):
    """Map LoRA key names onto text-encoder state-dict keys of ``model``.

    For every matching ``.weight`` entry in ``model.state_dict()`` this
    registers the naming conventions handled below: generic
    ``text_encoders.*`` keys, ``lora_te*_text_model_encoder_layers_*`` keys for
    clip_h / clip_l / clip_g layers, diffusers ``text_encoder*.`` keys, cascade
    ``lora_prior_te_*`` keys, OneTrainer T5 keys and HunyuanDiT BERT keys.

    Args:
        model: Object exposing ``state_dict()``; its keys are the map targets.
        key_map: Optional dict to extend in place. When omitted a fresh dict is
            created, so separate calls never share state.

    Returns:
        ``key_map`` (LoRA key -> model state-dict key).
    """
    if key_map is None:
        key_map = {}
    sdk = model.state_dict().keys()
    for k in sdk:
        if k.endswith(".weight"):
            key_map["text_encoders.{}".format(k[:-len(".weight")])] = k  # generic lora format without any weird key names

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    clip_g_present = False
    for b in range(32):  # TODO: clean up
        for c, lora_name in LORA_CLIP_MAP.items():
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                key_map[text_model_lora_key.format(b, lora_name)] = k
                key_map["lora_te1_text_model_encoder_layers_{}_{}".format(b, lora_name)] = k
                key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                key_map[text_model_lora_key.format(b, lora_name)] = k
                key_map["lora_te1_text_model_encoder_layers_{}_{}".format(b, lora_name)] = k  # SDXL base
                clip_l_present = True
                key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                clip_g_present = True
                if clip_l_present:
                    # clip_l seen as well: clip_g is the second text encoder.
                    key_map["lora_te2_text_model_encoder_layers_{}_{}".format(b, lora_name)] = k  # SDXL base
                    key_map["text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora
                else:
                    key_map["lora_te_text_model_encoder_layers_{}_{}".format(b, lora_name)] = k  # TODO: test if this is correct for SDXL-Refiner
                    key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora
                    key_map["lora_prior_te_text_model_encoder_layers_{}_{}".format(b, lora_name)] = k  # cascade lora: TODO put lora key prefix in the model config

    for k in sdk:
        if k.endswith(".weight"):
            if k.startswith("t5xxl.transformer."):  # OneTrainer SD3 and Flux lora
                l_key = k[len("t5xxl.transformer."):-len(".weight")]
                # T5 is numbered after whichever CLIP encoders are present.
                t5_index = 1
                if clip_g_present:
                    t5_index += 1
                if clip_l_present:
                    t5_index += 1
                    if t5_index == 2:
                        # clip_l without clip_g: also register the te2 name.
                        key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k  # OneTrainer Flux
                        t5_index += 1

                key_map["lora_te{}_{}".format(t5_index, l_key.replace(".", "_"))] = k
            elif k.startswith("hydit_clip.transformer.bert."):  # HunyuanDiT Lora
                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
                key_map["lora_te1_{}".format(l_key.replace(".", "_"))] = k

    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k  # cascade lora?
        # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
        key_map["lora_te2_text_projection"] = k  # OneTrainer SD3 lora

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k  # OneTrainer SD3 lora, not necessary but omits warning

    return key_map
|
|
|
|
def model_lora_keys_unet(model, key_map=None):
    """Map LoRA key names onto diffusion-model state-dict keys of ``model``.

    Always registers generic keys (the state-dict key without ``.weight``),
    ``lora_unet_*`` keys and keys derived from
    ``comfy.utils.unet_to_diffusers``. Further naming conventions are added
    depending on which ``comfy.model_base`` class ``model`` is an instance of.

    Args:
        model: Model exposing ``state_dict()`` and ``model_config.unet_config``.
        key_map: Optional dict to extend in place. When omitted a fresh dict is
            created, so separate calls never share state.

    Returns:
        ``key_map``. Values are state-dict keys, or whatever the
        ``comfy.utils.*_to_diffusers`` helpers map a diffusers key to.
    """
    if key_map is None:
        key_map = {}
    sdk = model.state_dict().keys()

    for k in sdk:
        if k.startswith("diffusion_model."):
            if k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                key_map["lora_unet_{}".format(key_lora)] = k
                key_map[k[:-len(".weight")]] = k  # generic lora format without any weird key names
            else:
                key_map[k] = k  # generic lora format for not .weight without any weird key names

    diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            base = k[:-len(".weight")]
            key_lora = base.replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key
            key_map["lycoris_{}".format(key_lora)] = unet_key  # simpletuner lycoris format

            # Diffusers attention-processor naming, with and without "unet.".
            for p in ("", "unet."):
                diffusers_lora_key = "{}{}".format(p, base.replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key

    if isinstance(model, comfy.model_base.StableCascade_C):
        for k in sdk:
            if k.startswith("diffusion_model."):
                if k.endswith(".weight"):
                    key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
                    key_map["lora_prior_unet_{}".format(key_lora)] = k

    if isinstance(model, comfy.model_base.SD3):  # Diffusers lora SD3
        diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                base = k[:-len(".weight")]
                key_map["transformer.{}".format(base)] = to  # regular diffusers sd3 lora format
                key_map["base_model.model.{}".format(base)] = to  # format for flash-sd3 lora and others?
                key_map["lora_transformer_{}".format(base.replace(".", "_"))] = to  # OneTrainer lora
                key_map["lycoris_{}".format(base.replace(".", "_"))] = to  # simpletuner lycoris format

    if isinstance(model, comfy.model_base.AuraFlow):  # Diffusers lora AuraFlow
        diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_map["transformer.{}".format(k[:-len(".weight")])] = to  # simpletrainer and probably regular diffusers lora format

    if isinstance(model, comfy.model_base.PixArt):
        diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                base = k[:-len(".weight")]
                key_map["transformer.{}".format(base)] = to  # default format
                key_map["base_model.model.{}".format(base)] = to  # diffusers training script
                key_map["unet.base_model.model.{}".format(base)] = to  # old reference peft script

    if isinstance(model, comfy.model_base.HunyuanDiT):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["base_model.model.{}".format(key_lora)] = k  # official hunyuan lora format

    if isinstance(model, comfy.model_base.Flux):  # Diffusers lora Flux
        diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                base = k[:-len(".weight")]
                key_map["transformer.{}".format(base)] = to  # simpletrainer and probably regular diffusers flux lora format
                key_map["lycoris_{}".format(base.replace(".", "_"))] = to  # simpletrainer lycoris
                key_map["lora_transformer_{}".format(base.replace(".", "_"))] = to  # onetrainer

    if isinstance(model, comfy.model_base.GenmoMochi):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):  # Official Mochi lora format
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map[key_lora] = k

    if isinstance(model, comfy.model_base.HunyuanVideo):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"):
                # diffusion-pipe lora format: rename submodules on the full key,
                # then strip the "diffusion_model." prefix and ".weight" suffix.
                key_lora = k
                key_lora = key_lora.replace("_mod.lin.", "_mod.linear.").replace("_attn.qkv.", "_attn_qkv.").replace("_attn.proj.", "_attn_proj.")
                key_lora = key_lora.replace("mlp.0.", "mlp.fc1.").replace("mlp.2.", "mlp.fc2.")
                key_lora = key_lora.replace(".modulation.lin.", ".modulation.linear.")
                key_lora = key_lora[len("diffusion_model."):-len(".weight")]
                key_map["transformer.{}".format(key_lora)] = k
                key_map["diffusion_model.{}".format(key_lora)] = k  # Old loras

    return key_map
|
|
|
|
|
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Zero-pad a tensor up to a new, at-least-as-large shape.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape. Must have the same number of
            dimensions as ``tensor`` and be at least as large in every dimension.

    Returns:
        torch.Tensor: A new tensor of shape ``new_shape`` with the same dtype and
        device as ``tensor``. ``tensor`` occupies the leading corner and every
        other entry is zero.

    Raises:
        ValueError: If the number of dimensions differs, or if ``new_shape`` is
            smaller than ``tensor`` in any dimension (the tensor is never
            truncated).
    """
    # Check the rank first; otherwise the per-dimension comparison below could
    # index past the end of tensor.shape and raise IndexError instead.
    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    if any(new_dim < old_dim for new_dim, old_dim in zip(new_shape, tensor.shape)):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
    # Copy the original values into the leading corner of the zero tensor.
    padded_tensor[tuple(slice(0, dim) for dim in tensor.shape)] = tensor
    return padded_tensor
|
|
|
|
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
    """Apply ``patches`` to ``weight`` (largely in place) and return the result.

    Each patch ``p`` is read as ``p[0]`` strength, ``p[1]`` payload ``v``,
    ``p[2]`` strength_model, ``p[3]`` offset and ``p[4]`` function:

    * ``strength_model`` (when != 1.0) first scales the target in place.
    * ``offset``, when not None, is ``(dim, start, length)``. The patch is
      applied to ``weight.narrow(dim, start, length)`` and the full tensor is
      restored afterwards.
    * ``function`` post-processes the computed delta (identity when None).
    * ``v`` may be:
        - a list ``[(base, fn), *sub_patches]``: ``sub_patches`` are applied
          recursively on top of ``fn(copy_of_base, inplace=True)`` (``base``
          cast to ``intermediate_dtype``), and the result is used as a diff;
        - a ``weight_adapter.WeightAdapterBase``, which computes the new weight;
        - ``(tensor,)``: shorthand for a ``"diff"`` patch;
        - ``(patch_type, payload)`` with patch_type ``"diff"``, ``"set"`` or
          ``"model_as_lora"``.
      Any other layout is logged and skipped.

    Args:
        patches: Iterable of patch tuples as described above.
        weight: Tensor to patch.
        key: Name of the weight; used in log messages and to index
            ``original_weights``.
        intermediate_dtype: dtype for intermediate tensors.
        original_weights: Required by ``"model_as_lora"`` patches;
            ``original_weights[key][0][0]`` is subtracted from the target tensor
            to form the diff.

    Returns:
        The patched tensor (the full tensor, even when offsets were used).
    """
    for p in patches:
        strength = p[0]
        v = p[1]
        strength_model = p[2]
        offset = p[3]
        function = p[4]
        if function is None:
            function = lambda a: a

        old_weight = None
        if offset is not None:
            # Patch only a view of one slice; keep a handle on the full tensor.
            old_weight = weight
            weight = weight.narrow(offset[0], offset[1], offset[2])

        if strength_model != 1.0:
            weight *= strength_model

        if isinstance(v, list):
            # Nested patches: build the base tensor, apply the remaining patches
            # to it recursively, and treat the result as a single diff.
            base = comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True)
            v = (calculate_weight(v[1:], v[0][1](base, inplace=True), key, intermediate_dtype=intermediate_dtype), )

        if isinstance(v, weight_adapter.WeightAdapterBase):
            output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
            if output is None:
                logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
            else:
                weight = output
            # Restore the full tensor whether or not the adapter succeeded, so a
            # failure cannot leave later patches (or the return value) on a slice.
            if old_weight is not None:
                weight = old_weight
            continue

        if len(v) == 1:
            patch_type = "diff"
        elif len(v) == 2:
            patch_type = v[0]
            v = v[1]
        else:
            # Unknown layout: let the "not recognized" branch below report it
            # instead of reusing a patch_type left over from an earlier patch.
            patch_type = None

        if patch_type == "diff":
            diff: torch.Tensor = v[0]
            # An extra flag to pad the weight if the diff's shape is larger than the weight
            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
            if do_pad_weight and diff.shape != weight.shape:
                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
                weight = pad_tensor_to_shape(weight, diff.shape)

            if strength != 0.0:
                if diff.shape != weight.shape:
                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                else:
                    weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
        elif patch_type == "set":
            weight.copy_(v[0])
        elif patch_type == "model_as_lora":
            target_weight: torch.Tensor = v[0]
            diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
                          comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
            weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            weight = old_weight

    return weight
|