diff --git a/comfy/model_management.py b/comfy/model_management.py
index b1afeb71..24170692 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -281,19 +281,23 @@ def load_model_gpu(model):
             vram_set_state = VRAMState.LOW_VRAM
 
     real_model = model.model
+    patch_model_to = None
     if vram_set_state == VRAMState.DISABLED:
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
-        real_model.to(torch_dev)
+        patch_model_to = torch_dev
 
     try:
-        real_model = model.patch_model()
+        real_model = model.patch_model(device_to=patch_model_to)
     except Exception as e:
         model.unpatch_model()
         unload_model()
         raise e
 
+    if patch_model_to is not None:
+        real_model.to(torch_dev)
+
     if vram_set_state == VRAMState.NO_VRAM:
         device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
diff --git a/comfy/sd.py b/comfy/sd.py
index ddafa0b5..1f364dd1 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -338,7 +338,7 @@ class ModelPatcher:
                     sd.pop(k)
         return sd
 
-    def patch_model(self):
+    def patch_model(self, device_to=None):
         model_sd = self.model_state_dict()
         for key in self.patches:
             if key not in model_sd:
@@ -350,10 +350,13 @@ class ModelPatcher:
             if key not in self.backup:
                 self.backup[key] = weight.to(self.offload_device)
 
-            temp_weight = weight.to(torch.float32, copy=True)
+            if device_to is not None:
+                temp_weight = weight.float().to(device_to, copy=True)
+            else:
+                temp_weight = weight.to(torch.float32, copy=True)
             out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
             set_attr(self.model, key, out_weight)
-            del weight
+            del temp_weight
         return self.model
 
     def calculate_weight(self, patches, weight, key):
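
For context, a minimal, self-contained sketch of the idea behind the new device_to parameter: when the caller already knows the target device, each weight is copied there in fp32 and patched on that device, instead of being patched in place and moved later by a whole-model .to() call. TinyPatcher, the simplified set_attr helper, and the usage names below are hypothetical stand-ins, not ComfyUI's actual ModelPatcher or comfy.sd API; the patch math (calculate_weight) is deliberately omitted.

import torch

def set_attr(obj, attr, value):
    # Hypothetical helper loosely mirroring the set_attr used in comfy/sd.py:
    # walk the dotted attribute path and replace the final parameter.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))

class TinyPatcher:
    # Hypothetical, trimmed-down stand-in for comfy.sd.ModelPatcher.
    def __init__(self, model, offload_device=torch.device("cpu")):
        self.model = model
        self.offload_device = offload_device
        self.backup = {}   # original weights, kept on the offload device
        self.patches = {}  # key -> patch data; the patch math itself is omitted here

    def patch_model(self, device_to=None):
        model_sd = dict(self.model.state_dict())
        for key in self.patches:
            if key not in model_sd:
                continue
            weight = model_sd[key]
            if key not in self.backup:
                self.backup[key] = weight.to(self.offload_device)
            if device_to is not None:
                # New path: copy the weight to the target device in fp32 and
                # patch it there, so the later real_model.to(torch_dev) has
                # little left to move.
                temp_weight = weight.float().to(device_to, copy=True)
            else:
                # Old path: patch in fp32 on whatever device the weight is on.
                temp_weight = weight.to(torch.float32, copy=True)
            # The real code would run calculate_weight(...) on temp_weight here.
            out_weight = temp_weight.to(weight.dtype)
            set_attr(self.model, key, out_weight)
            del temp_weight
        return self.model

# Usage mirroring the new call order in load_model_gpu: pick the target device,
# pass it to patch_model, then move anything that was not already placed there.
if __name__ == "__main__":
    torch_dev = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    patcher = TinyPatcher(torch.nn.Linear(8, 8))
    patcher.patches["weight"] = object()  # pretend there is a patch for this key
    real_model = patcher.patch_model(device_to=torch_dev)
    real_model.to(torch_dev)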