Mirror of https://github.com/comfyanonymous/ComfyUI.git
Implement basic MemoryCounter system for determining which cached weights due to hooks should be offloaded in hook_backup
commit d3229cbba7
parent c422553b0b
@@ -221,6 +221,25 @@ class PatcherInjection:
         self.inject = inject
         self.eject = eject
 
+class MemoryCounter:
+    def __init__(self, initial: int, minimum=0):
+        self.value = initial
+        self.minimum = minimum
+        # TODO: add a safe limit besides 0
+
+    def use(self, weight: torch.Tensor):
+        weight_size = weight.nelement() * weight.element_size()
+        if self.is_useable(weight_size):
+            self.decrement(weight_size)
+            return True
+        return False
+
+    def is_useable(self, used: int):
+        return self.value - used > self.minimum
+
+    def decrement(self, used: int):
+        self.value -= used
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
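
To make the budget semantics concrete: use() debits the counter and returns True only while the remaining budget would stay above minimum. A quick illustration (not part of the commit), assuming the MemoryCounter class from the hunk above is in scope and picking a hypothetical 1 GiB budget:

import torch

counter = MemoryCounter(initial=1024 ** 3)  # hypothetical 1 GiB budget
w = torch.zeros(64, 1024, 1024)             # 64Mi float32 elements = 256 MiB

print(counter.use(w))  # True  -> budget now 768 MiB
print(counter.use(w))  # True  -> 512 MiB
print(counter.use(w))  # True  -> 256 MiB
print(counter.use(w))  # False -> 256 MiB - 256 MiB = 0 is not > minimum,
                       #          so nothing is debited and the caller
                       #          should fall back to offloading
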
@@ -1007,6 +1026,9 @@ class ModelPatcher:
         with self.use_ejected():
             self.unpatch_hooks()
             model_sd = self.model_state_dict()
+            memory_counter = None
+            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
+                memory_counter = MemoryCounter(comfy.model_management.get_free_memory(self.load_device))
             # if have cached weights for hooks, use it
             cached_weights = self.cached_hook_patches.get(hooks, None)
             if cached_weights is not None:
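
The counter is seeded once per patch pass with the free memory on the load device, so every weight kept on-device during this pass draws from the same budget. For CUDA devices, a rough torch-only approximation of that seeding query (comfy.model_management.get_free_memory also handles CPU and the caching allocator's reserved-but-inactive memory) might look like:

import torch

def approx_free_memory(device: torch.device) -> int:
    # Bytes the CUDA driver reports as free on this device; an
    # approximation of the value fed into MemoryCounter above.
    free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
    return free_bytes

# memory_counter = MemoryCounter(approx_free_memory(torch.device("cuda:0")))
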
@@ -1014,7 +1036,7 @@
                     if key not in model_sd:
                         print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
                         continue
-                    self.patch_cached_hook_weights(cached_weights=cached_weights, key=key)
+                    self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
             else:
                 relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                 original_weights = None
@@ -1024,15 +1046,18 @@
                     if key not in model_sd:
                         print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
                         continue
-                    self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights)
+                    self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
+                                                     memory_counter=memory_counter)
             self.current_hooks = hooks
 
-    def patch_cached_hook_weights(self, cached_weights: Dict, key: str):
+    def patch_cached_hook_weights(self, cached_weights: Dict, key: str, memory_counter: MemoryCounter):
         if key not in self.hook_backup:
             weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
             target_device = self.offload_device
             if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
-                target_device = weight.device
+                used = memory_counter.use(weight)
+                if used:
+                    target_device = weight.device
             self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
         if self.weight_inplace_update:
             comfy.utils.copy_to_param(self.model, key, cached_weights[key])
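
Both patching paths now share the same decision: in MaxSpeed mode the backup of the original weight stays on its current (fast) device only if the counter approves the spend; otherwise it goes to the offload device, exactly as in the other modes. A standalone sketch of that decision, with hypothetical names:

import torch

def choose_backup_device(weight: torch.Tensor,
                         offload_device: torch.device,
                         memory_counter,  # MemoryCounter, or None outside MaxSpeed
                         max_speed: bool) -> torch.device:
    # Hypothetical helper mirroring the logic inside both patched methods:
    # default to offloading, and keep the backup on the weight's own device
    # only when MaxSpeed is active and the budget still allows it.
    target_device = offload_device
    if max_speed and memory_counter.use(weight):
        target_device = weight.device
    return target_device
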
@@ -1043,14 +1068,16 @@
         self.cached_hook_patches.clear()
         self.current_hooks = None
 
-    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict):
+    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
         if key not in combined_patches:
             return
         weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
         if key not in self.hook_backup:
             target_device = self.offload_device
             if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
-                target_device = weight.device
+                used = memory_counter.use(weight)
+                if used:
+                    target_device = weight.device
             self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
 
         # TODO: properly handle lowvram situations for cached hook patches
@@ -1058,6 +1085,7 @@
         out_weight = comfy.lora.calculate_weight(combined_patches[key], temp_weight, key, original_weights=original_weights).to(weight.dtype)
         out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
         if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
+            # TODO: disable caching if not enough system RAM to do so
             self.cached_hook_patches.setdefault(hooks, {})
             self.cached_hook_patches[hooks][key] = out_weight
         if self.weight_inplace_update:
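
cached_hook_patches is a two-level mapping from hook group to per-key computed weights, created lazily with setdefault; the new TODO notes that this caching should eventually be skipped when system RAM is scarce. A hedged sketch of that shape (HookGroup here is a stand-in, not the real type):

from typing import Dict
import torch

HookGroup = str  # stand-in; the real key type is comfy.hooks.HookGroup
cached_hook_patches: Dict[HookGroup, Dict[str, torch.Tensor]] = {}

def cache_hook_weight(hooks: HookGroup, key: str, out_weight: torch.Tensor):
    # First access for a hook group creates its inner dict; later
    # applications of the same group can then reuse out_weight instead
    # of recomputing the patched weight.
    cached_hook_patches.setdefault(hooks, {})
    cached_hook_patches[hooks][key] = out_weight
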
@@ -1073,16 +1101,10 @@
         keys = list(self.hook_backup.keys())
         if self.weight_inplace_update:
             for k in keys:
-                if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # does not need to be cast; device already matches
-                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0])
-                else:
-                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
         else:
             for k in keys:
-                if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
-                    comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0])
-                else:
-                    comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
 
         self.hook_backup.clear()
         self.current_hooks = None
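
The MaxSpeed branches in the restore path could be deleted because the unconditional cast is already free when no move is needed: torch.Tensor.to returns the tensor itself when it is already on the requested device with the requested dtype. A quick check:

import torch

t = torch.ones(4)  # lives on the CPU
restored = t.to(device=torch.device("cpu"))
print(restored is t)  # True: .to() is a no-op returning self, so backups
                      # kept on their original device incur no copy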