From 019c7029ea324517ab88d7e61e79b739bc8f4e91 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 13 Feb 2025 20:34:03 -0500
Subject: [PATCH] Add a way to set a different compute dtype for the model at
 runtime.

Currently only works for diffusion models.
---
 comfy/model_patcher.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index aee0164c5..4dbe1b7aa 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -218,6 +218,7 @@ class ModelPatcher:
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
+        self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
 
@@ -277,6 +278,8 @@ class ModelPatcher:
         n.object_patches_backup = self.object_patches_backup
         n.parent = self
 
+        n.force_cast_weights = self.force_cast_weights
+
         # attachments
         n.attachments = {}
        for k in self.attachments:
@@ -424,6 +427,12 @@ class ModelPatcher:
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj
 
+    def set_model_compute_dtype(self, dtype):
+        self.add_object_patch("manual_cast_dtype", dtype)
+        if dtype is not None:
+            self.force_cast_weights = True
+        self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
+
     def add_weight_wrapper(self, name, function):
         self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
         self.patches_uuid = uuid.uuid4()
@@ -602,6 +611,7 @@ class ModelPatcher:
                         if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                             continue
 
+                cast_weight = self.force_cast_weights
                 if lowvram_weight:
                     if hasattr(m, "comfy_cast_weights"):
                         m.weight_function = []
@@ -620,8 +630,7 @@ class ModelPatcher:
                             m.bias_function = [LowVramPatch(bias_key, self.patches)]
                             patch_counter += 1
 
-                    m.prev_comfy_cast_weights = m.comfy_cast_weights
-                    m.comfy_cast_weights = True
+                    cast_weight = True
                 else:
                     if hasattr(m, "comfy_cast_weights"):
                         wipe_lowvram_weight(m)
@@ -630,6 +639,10 @@ class ModelPatcher:
                         mem_counter += module_mem
                         load_completely.append((module_mem, n, m, params))
 
+                if cast_weight:
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+
                 if weight_key in self.weight_wrapper_patches:
                     m.weight_function.extend(self.weight_wrapper_patches[weight_key])
 
@@ -766,6 +779,7 @@ class ModelPatcher:
                 weight_key = "{}.weight".format(n)
                 bias_key = "{}.bias".format(n)
                 if move_weight:
+                    cast_weight = self.force_cast_weights
                     m.to(device_to)
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
@@ -775,7 +789,9 @@ class ModelPatcher:
                         if bias_key in self.patches:
                             m.bias_function.append(LowVramPatch(bias_key, self.patches))
                             patch_counter += 1
+                        cast_weight = True
 
+                    if cast_weight:
                         m.prev_comfy_cast_weights = m.comfy_cast_weights
                         m.comfy_cast_weights = True
                     m.comfy_patched_weights = False
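
Usage note (not part of the patch): a minimal sketch of how the new
ModelPatcher.set_model_compute_dtype() method could be driven, assuming
"model" is a comfy.model_patcher.ModelPatcher wrapping a diffusion model
(the only case the patch currently supports). The helper name below is
hypothetical; only clone() and set_model_compute_dtype() come from ComfyUI.

    import torch

    def force_fp32_compute(model):
        # Clone first so the original patcher stays untouched, following the
        # usual ComfyUI pattern of patching a copy. set_model_compute_dtype()
        # then registers a "manual_cast_dtype" object patch and sets
        # force_cast_weights, so weights are cast to torch.float32 at compute
        # time the next time the model is loaded.
        m = model.clone()
        m.set_model_compute_dtype(torch.float32)
        return m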