mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Try to free enough vram for control lora inference.
This commit is contained in:
parent
e3d0a9a490
commit
51dde87e97
@ -394,6 +394,12 @@ def cleanup_models():
|
||||
x.model_unload()
|
||||
del x
|
||||
|
||||
def dtype_size(dtype):
|
||||
dtype_size = 4
|
||||
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||
dtype_size = 2
|
||||
return dtype_size
|
||||
|
||||
def unet_offload_device():
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
return get_torch_device()
|
||||
@ -409,11 +415,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
if DISABLE_SMART_MEMORY:
|
||||
return cpu_dev
|
||||
|
||||
dtype_size = 4
|
||||
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||
dtype_size = 2
|
||||
|
||||
model_size = dtype_size * parameters
|
||||
model_size = dtype_size(dtype) * parameters
|
||||
|
||||
mem_dev = get_free_memory(torch_dev)
|
||||
mem_cpu = get_free_memory(cpu_dev)
|
||||
|
@ -51,18 +51,20 @@ def get_models_from_cond(cond, model_type):
|
||||
models += [c[1][model_type]]
|
||||
return models
|
||||
|
||||
def get_additional_models(positive, negative):
|
||||
def get_additional_models(positive, negative, dtype):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
|
||||
|
||||
inference_memory = 0
|
||||
control_models = []
|
||||
for m in control_nets:
|
||||
control_models += m.get_models()
|
||||
inference_memory += m.inference_memory_requirements(dtype)
|
||||
|
||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
return models
|
||||
return models, inference_memory
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
"""cleanup additional models that were loaded"""
|
||||
@ -77,8 +79,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
noise_mask = prepare_mask(noise_mask, noise.shape, device)
|
||||
|
||||
real_model = None
|
||||
models = get_additional_models(positive, negative)
|
||||
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
|
||||
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
|
||||
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
noise = noise.to(device)
|
||||
|
19
comfy/sd.py
19
comfy/sd.py
@ -779,6 +779,11 @@ class ControlBase:
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
return self.previous_controlnet.inference_memory_requirements(dtype)
|
||||
return 0
|
||||
|
||||
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||
out = {'input':[], 'middle':[], 'output': []}
|
||||
|
||||
@ -985,6 +990,9 @@ class ControlLora(ControlNet):
|
||||
out = ControlBase.get_models(self)
|
||||
return out
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
@ -1323,13 +1331,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||
|
||||
def calculate_parameters(sd, prefix):
|
||||
params = 0
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
params += sd[k].nelement()
|
||||
return params
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
|
||||
sd = utils.load_torch_file(ckpt_path)
|
||||
sd_keys = sd.keys()
|
||||
@ -1339,7 +1340,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
model = None
|
||||
clip_target = None
|
||||
|
||||
parameters = calculate_parameters(sd, "model.diffusion_model.")
|
||||
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
@ -1390,7 +1391,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
def load_unet(unet_path): #load unet in diffusers format
|
||||
sd = utils.load_torch_file(unet_path)
|
||||
parameters = calculate_parameters(sd, "")
|
||||
parameters = utils.calculate_parameters(sd)
|
||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
|
||||
|
@ -32,6 +32,13 @@ def save_torch_file(sd, ckpt, metadata=None):
|
||||
else:
|
||||
safetensors.torch.save_file(sd, ckpt)
|
||||
|
||||
def calculate_parameters(sd, prefix=""):
|
||||
params = 0
|
||||
for k in sd.keys():
|
||||
if k.startswith(prefix):
|
||||
params += sd[k].nelement()
|
||||
return params
|
||||
|
||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}positional_embedding": "{}embeddings.position_embedding.weight",
|
||||
|
Loading…
Reference in New Issue
Block a user