mirror of https://github.com/comfyanonymous/ComfyUI.git

Speed up lora loading a bit.

commit 490771b7f4 (parent 50b1180dde)
comfy/model_management.py

@@ -258,15 +258,11 @@ def load_model_gpu(model):
     if model is current_loaded_model:
         return
     unload_model()
-    try:
-        real_model = model.patch_model()
-    except Exception as e:
-        model.unpatch_model()
-        raise e
 
     torch_dev = model.load_device
     model.model_patches_to(torch_dev)
     model.model_patches_to(model.model_dtype())
+    current_loaded_model = model
 
     if is_device_cpu(torch_dev):
         vram_set_state = VRAMState.DISABLED
@@ -280,8 +276,7 @@ def load_model_gpu(model):
         if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
             vram_set_state = VRAMState.LOW_VRAM
 
-    current_loaded_model = model
-
+    real_model = model.model
     if vram_set_state == VRAMState.DISABLED:
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
@@ -295,6 +290,14 @@ def load_model_gpu(model):
 
             accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
             model_accelerated = True
 
+    try:
+        real_model = model.patch_model()
+    except Exception as e:
+        model.unpatch_model()
+        unload_model()
+        raise e
+
     return current_loaded_model
 
 def load_controlnet_gpu(control_models):
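Taken together, the three hunks above reorder load_model_gpu so that model.patch_model() (the call that actually merges lora weights into the model) runs after the device and VRAM-state decisions rather than before them. A simplified sketch of the new flow, with the VRAM/accelerate branches condensed into a comment (the sketch name and structure are illustrative, not the exact function body):

def load_model_gpu_sketch(model):
    global current_loaded_model
    if model is current_loaded_model:
        return
    unload_model()

    # device / dtype bookkeeping happens first now
    torch_dev = model.load_device
    model.model_patches_to(torch_dev)
    model.model_patches_to(model.model_dtype())
    current_loaded_model = model

    real_model = model.model
    # ... decide vram_set_state and, for low vram, dispatch real_model with accelerate ...

    # patching (lora weight merging) is deferred to the very end
    try:
        real_model = model.patch_model()
    except Exception as e:
        model.unpatch_model()
        unload_model()
        raise e
    return current_loaded_model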
comfy/sd.py (33 lines changed)

@@ -340,7 +340,7 @@ class ModelPatcher:
             weight = model_sd[key]
 
             if key not in self.backup:
-                self.backup[key] = weight.clone()
+                self.backup[key] = weight.to(self.offload_device, copy=True)
 
             temp_weight = weight.to(torch.float32, copy=True)
             weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
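The backup of the original weight is now created directly on the patcher's offload device instead of being cloned next to the weight, so the pristine copy no longer sits in VRAM while the patched model runs. A small illustration of the difference (assuming a CUDA load device and a CPU offload device; the tensor and variable names here are made up for the example):

import torch

load_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
offload_device = torch.device("cpu")

weight = torch.randn(4, 4, device=load_device)

# old behaviour: the backup stays on the same device as the weight
backup_old = weight.clone()

# new behaviour: the backup is copied straight to the offload device,
# so it lives in system RAM instead of VRAM
backup_new = weight.to(offload_device, copy=True)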
@@ -367,15 +367,16 @@ class ModelPatcher:
                 else:
                     weight += alpha * w1.type(weight.dtype).to(weight.device)
             elif len(v) == 4: #lora/locon
-                mat1 = v[0]
-                mat2 = v[1]
+                mat1 = v[0].float().to(weight.device)
+                mat2 = v[1].float().to(weight.device)
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
                 if v[3] is not None:
                     #locon mid weights, hopefully the math is fine because I didn't properly test it
-                    final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
-                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
-                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                    mat3 = v[3].float().to(weight.device)
+                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
             elif len(v) == 8: #lokr
                 w1 = v[0]
                 w2 = v[1]
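The pattern in this hunk (and the ones below) is where the speedup comes from: the lora factor matrices are promoted to float32 and moved to the weight's device once, up front, so the matrix multiplications run on the GPU and only the final result is cast back to the weight's dtype. A stripped-down sketch of the plain lora/locon case, ignoring the mid-weight path (the function and argument names here are illustrative, not part of the repo):

import torch

def apply_lora_delta(weight: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, alpha: float) -> torch.Tensor:
    # move the low-rank factors to the weight's device and promote to float32
    mat1 = mat1.float().to(weight.device)   # "up" matrix, shape (out, rank)
    mat2 = mat2.float().to(weight.device)   # "down" matrix, shape (rank, in)
    # delta W = alpha * mat1 @ mat2, reshaped and cast back to the weight's dtype
    delta = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))
    weight += (alpha * delta).reshape(weight.shape).type(weight.dtype)
    return weight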
@@ -389,20 +390,24 @@ class ModelPatcher:
                 if w1 is None:
                     dim = w1_b.shape[0]
                     w1 = torch.mm(w1_a.float(), w1_b.float())
+                else:
+                    w1 = w1.float().to(weight.device)
 
                 if w2 is None:
                     dim = w2_b.shape[0]
                     if t2 is None:
-                        w2 = torch.mm(w2_a.float(), w2_b.float())
+                        w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
                     else:
-                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
+                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
+                else:
+                    w2 = w2.float().to(weight.device)
 
                 if len(w2.shape) == 4:
                     w1 = w1.unsqueeze(2).unsqueeze(2)
                 if v[2] is not None and dim is not None:
                     alpha *= v[2] / dim
 
-                weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
             else: #loha
                 w1a = v[0]
                 w1b = v[1]
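For lokr patches the delta is the Kronecker product of two small factors; after this change both factors are already float32 tensors on the weight's device, which is why the redundant .float()/.to() calls on the final line can be dropped. A toy example of the shape arithmetic (all sizes here are made up):

import torch

w1 = torch.randn(4, 4)          # small factor (e.g. reconstructed from w1_a @ w1_b)
w2 = torch.randn(8, 8)          # second factor
weight = torch.randn(32, 32)    # the layer weight being patched
alpha = 1.0

# torch.kron of a (4, 4) and an (8, 8) tensor gives a (32, 32) delta
delta = torch.kron(w1.float(), w2.float())
weight += alpha * delta.reshape(weight.shape).type(weight.dtype)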
@@ -413,13 +418,13 @@ class ModelPatcher:
                 if v[5] is not None: #cp decomposition
                     t1 = v[5]
                     t2 = v[6]
-                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
-                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
+                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
+                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
                 else:
-                    m1 = torch.mm(w1a.float(), w1b.float())
-                    m2 = torch.mm(w2a.float(), w2b.float())
+                    m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
+                    m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
 
-                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
         return weight
 
     def unpatch_model(self):
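The loha branch follows the same recipe: both low-rank products are built in float32 on the weight's device, and the delta is their element-wise (Hadamard) product, so the trailing .to(weight.device) on the accumulation line becomes unnecessary. A minimal numeric sketch (the shapes are arbitrary):

import torch

out_f, in_f, rank = 16, 16, 4
w1a, w1b = torch.randn(out_f, rank), torch.randn(rank, in_f)
w2a, w2b = torch.randn(out_f, rank), torch.randn(rank, in_f)
weight = torch.randn(out_f, in_f)
alpha = 1.0

m1 = torch.mm(w1a.float(), w1b.float())   # first low-rank product, (out_f, in_f)
m2 = torch.mm(w2a.float(), w2b.float())   # second low-rank product, (out_f, in_f)
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)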
comfy/utils.py

@@ -4,18 +4,20 @@ import struct
 import comfy.checkpoint_pickle
 import safetensors.torch
 
-def load_torch_file(ckpt, safe_load=False):
+def load_torch_file(ckpt, safe_load=False, device=None):
+    if device is None:
+        device = torch.device("cpu")
     if ckpt.lower().endswith(".safetensors"):
-        sd = safetensors.torch.load_file(ckpt, device="cpu")
+        sd = safetensors.torch.load_file(ckpt, device=device.type)
     else:
         if safe_load:
             if not 'weights_only' in torch.load.__code__.co_varnames:
                 print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
                 safe_load = False
         if safe_load:
-            pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
+            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
         else:
-            pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
+            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
         if "state_dict" in pl_sd:
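load_torch_file keeps its old CPU default but can now place the loaded state dict on an arbitrary device, which lets lora weights be read straight onto the GPU before patching. A hypothetical call (the file name is a placeholder):

import torch
import comfy.utils

load_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# load a lora checkpoint directly onto the target device instead of CPU
lora_sd = comfy.utils.load_torch_file("example_lora.safetensors", safe_load=True, device=load_device)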