diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2c683320..208fee06 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -97,7 +97,7 @@ def wipe_lowvram_weight(m): m.comfy_cast_weights = m.prev_comfy_cast_weights del m.prev_comfy_cast_weights - if not hasattr(m, "weight_function"): + if hasattr(m, "weight_function"): m.weight_function = [] if hasattr(m, "bias_function"): @@ -781,6 +781,8 @@ class ModelPatcher: if move_weight: cast_weight = self.force_cast_weights m.to(device_to) + if not hasattr(m, "weight_function"): + m.weight_function=[] module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: