Implement basic MemoryCounter system for determining whether cached weights due to hooks should be offloaded in hook_backup

kosinkadink1@gmail.com 2024-09-24 17:28:18 +09:00
parent c422553b0b
commit d3229cbba7


@@ -221,6 +221,25 @@ class PatcherInjection:
         self.inject = inject
         self.eject = eject
 
+class MemoryCounter:
+    def __init__(self, initial: int, minimum=0):
+        self.value = initial
+        self.minimum = minimum
+        # TODO: add a safe limit besides 0
+
+    def use(self, weight: torch.Tensor):
+        weight_size = weight.nelement() * weight.element_size()
+        if self.is_useable(weight_size):
+            self.decrement(weight_size)
+            return True
+        return False
+
+    def is_useable(self, used: int):
+        return self.value - used > self.minimum
+
+    def decrement(self, used: int):
+        self.value -= used
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
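
The counter is a plain byte budget: use() charges a tensor's footprint (nelement() * element_size()) against the remaining budget and reports whether it fit. A minimal standalone sketch of that accounting, with the class copied from the hunk above and a made-up 8 GiB starting budget standing in for the free-memory query this commit wires up later:

    import torch

    class MemoryCounter:  # as introduced above, minus the TODO
        def __init__(self, initial: int, minimum=0):
            self.value = initial
            self.minimum = minimum

        def use(self, weight: torch.Tensor) -> bool:
            weight_size = weight.nelement() * weight.element_size()
            if self.is_useable(weight_size):
                self.decrement(weight_size)
                return True
            return False

        def is_useable(self, used: int) -> bool:
            return self.value - used > self.minimum

        def decrement(self, used: int) -> None:
            self.value -= used

    counter = MemoryCounter(initial=8 * 1024**3)           # pretend 8 GiB are free
    weight = torch.zeros(4096, 4096, dtype=torch.float16)  # 4096*4096*2 bytes = 32 MiB
    assert counter.use(weight)                             # fits, so budget shrinks by 32 MiB
    assert counter.value == 8 * 1024**3 - 32 * 1024**2
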
@@ -1007,6 +1026,9 @@ class ModelPatcher:
         with self.use_ejected():
             self.unpatch_hooks()
             model_sd = self.model_state_dict()
+            memory_counter = None
+            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
+                memory_counter = MemoryCounter(comfy.model_management.get_free_memory(self.load_device))
             # if have cached weights for hooks, use it
             cached_weights = self.cached_hook_patches.get(hooks, None)
             if cached_weights is not None:
@@ -1014,7 +1036,7 @@ class ModelPatcher:
                     if key not in model_sd:
                         print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
                         continue
-                    self.patch_cached_hook_weights(cached_weights=cached_weights, key=key)
+                    self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
             else:
                 relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                 original_weights = None
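
In MaxSpeed mode the budget is seeded with the free memory on the load device at patch time, then threaded through to the per-key helpers; in every other mode memory_counter stays None and the old always-offload behavior is kept. A rough stand-in for what comfy.model_management.get_free_memory() measures on a CUDA device (free_vram_bytes is a hypothetical helper; the real function also folds in torch's reserved-but-unallocated cache, so treat this as an approximation):

    import torch

    def free_vram_bytes(device: torch.device) -> int:
        # Approximation only: free bytes as reported by the CUDA driver.
        # comfy's get_free_memory() computes a slightly different figure.
        free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
        return free_bytes

    # memory_counter = MemoryCounter(initial=free_vram_bytes(load_device))
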
@@ -1024,15 +1046,18 @@ class ModelPatcher:
                     if key not in model_sd:
                         print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
                         continue
-                    self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights)
+                    self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
+                                                     memory_counter=memory_counter)
             self.current_hooks = hooks
 
-    def patch_cached_hook_weights(self, cached_weights: Dict, key: str):
+    def patch_cached_hook_weights(self, cached_weights: Dict, key: str, memory_counter: MemoryCounter):
         if key not in self.hook_backup:
             weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
             target_device = self.offload_device
             if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
-                target_device = weight.device
+                used = memory_counter.use(weight)
+                if used:
+                    target_device = weight.device
             self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
         if self.weight_inplace_update:
             comfy.utils.copy_to_param(self.model, key, cached_weights[key])
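
The behavioral change in MaxSpeed mode: previously the backup always stayed on weight.device (typically the GPU); now it stays there only if the counter says it still fits in the budget, otherwise it falls back to offload_device. A hypothetical distillation of the decision both patch helpers now share (choose_target_device is not a function in the codebase, just a sketch of the common branch):

    from typing import Optional
    import torch

    def choose_target_device(max_speed: bool,
                             memory_counter: Optional[MemoryCounter],
                             weight: torch.Tensor,
                             offload_device: torch.device) -> torch.device:
        target_device = offload_device
        if max_speed:
            # patch_hooks() only constructs a counter in MaxSpeed mode,
            # so this dereference is safe exactly when the branch runs.
            if memory_counter.use(weight):
                target_device = weight.device
        return target_device
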
@@ -1043,14 +1068,16 @@ class ModelPatcher:
         self.cached_hook_patches.clear()
         self.current_hooks = None
 
-    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict):
+    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
         if key not in combined_patches:
             return
         weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
         if key not in self.hook_backup:
             target_device = self.offload_device
             if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
-                target_device = weight.device
+                used = memory_counter.use(weight)
+                if used:
+                    target_device = weight.device
             self.hook_backup[key] = (weight.to(device=target_device, copy=self.weight_inplace_update), weight.device)
 
         # TODO: properly handle lowvram situations for cached hook patches
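
Note that both helpers record the backup as a 2-tuple: the backup tensor (wherever the counter parked it) and the device the live weight came from. The second element is what lets unpatch_hooks restore weights no matter where the backup ended up. A toy CPU-only illustration of the tuple's lifecycle (the key name is hypothetical, and hook_backup here is a plain dict mirroring the attribute of the same name):

    import torch

    weight = torch.zeros(8)                  # stands in for a model weight
    hook_backup = {}
    hook_backup["some.layer.weight"] = (weight.to(device="cpu", copy=True),
                                        weight.device)
    backup_tensor, original_device = hook_backup["some.layer.weight"]
    restored = backup_tensor.to(device=original_device)  # back where it came from
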
@@ -1058,6 +1085,7 @@ class ModelPatcher:
         out_weight = comfy.lora.calculate_weight(combined_patches[key], temp_weight, key, original_weights=original_weights).to(weight.dtype)
         out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
         if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
+            # TODO: disable caching if not enough system RAM to do so
             self.cached_hook_patches.setdefault(hooks, {})
             self.cached_hook_patches[hooks][key] = out_weight
         if self.weight_inplace_update:
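
Patched weights are cached per hook group so switching back to the same hooks later is cheap, and the rounding seed is derived from the key via string_to_seed so the cast is reproducible. For readers unfamiliar with the technique, a toy sketch of stochastic rounding (rounding to integers for simplicity; comfy.float.stochastic_rounding rounds to a narrower float dtype instead): floor the value, then round up with probability equal to the fractional part, so the result is unbiased in expectation.

    import torch

    def stochastic_round(x: torch.Tensor, seed: int) -> torch.Tensor:
        # Toy version: integer targets rather than a float dtype.
        gen = torch.Generator().manual_seed(seed)
        floor = torch.floor(x)
        frac = x - floor                                  # in [0, 1)
        round_up = torch.rand(x.shape, generator=gen) < frac
        return floor + round_up.to(x.dtype)
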
@@ -1073,16 +1101,10 @@ class ModelPatcher:
             keys = list(self.hook_backup.keys())
             if self.weight_inplace_update:
                 for k in keys:
-                    if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # does not need to be cast; device already matches
-                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0])
-                    else:
-                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
             else:
                 for k in keys:
-                    if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
-                        comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0])
-                    else:
-                        comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                    comfy.utils.set_attr_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
             self.hook_backup.clear()
             self.current_hooks = None
 
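This simplification is safe because the restore path is now uniform: every backup tuple carries the original device, and Tensor.to() returns the tensor itself when no move or cast is needed, so the deleted MaxSpeed fast path cost nothing to lose. A quick check of that documented torch guarantee:

    import torch

    t = torch.zeros(4)
    assert t.to(device=t.device) is t   # same object: no copy when already in place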