From e5396e98d8635218f4a979ca7c09f256e9b3126f Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Fri, 21 Mar 2025 14:57:05 -0500
Subject: [PATCH] Add VAELoaderDevice node to decide what device to load VAE
 on

---
 comfy/multigpu.py              | 18 ++++++++++++++++++
 comfy_extras/nodes_multigpu.py | 24 +++++++++++++++++++++++-
 nodes.py                       |  4 ++--
 3 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/comfy/multigpu.py b/comfy/multigpu.py
index aef0b68e..350a564f 100644
--- a/comfy/multigpu.py
+++ b/comfy/multigpu.py
@@ -56,5 +56,23 @@ class GPUOptionsGroup:
             value['relative_speed'] /= min_speed
         model.model_options['multigpu_options'] = opts_dict
 
+def get_torch_device_list():
+    # Node-facing list of selectable devices; "default" maps to the
+    # model-management default device in get_device_from_str.
+    devices = ["default"]
+    for device in comfy.model_management.get_all_torch_devices():
+        device: torch.device
+        devices.append(str(device.index))
+    return devices
+
+def get_device_from_str(device_str: str, throw_error_if_not_found=False):
+    if device_str == "default":
+        return comfy.model_management.get_torch_device()
+    for device in comfy.model_management.get_all_torch_devices():
+        device: torch.device
+        if str(device.index) == device_str:
+            return device
+    if throw_error_if_not_found:
+        raise Exception(f"Device with index '{device_str}' not found.")
+    logging.warning(f"Device with index '{device_str}' not found, using default device ({comfy.model_management.get_torch_device()}) instead.")
+    return comfy.model_management.get_torch_device()
 def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
     'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py
index 3b68c10f..85a43e3e 100644
--- a/comfy_extras/nodes_multigpu.py
+++ b/comfy_extras/nodes_multigpu.py
@@ -6,6 +6,27 @@
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
 import comfy.multigpu
+from nodes import VAELoader
+
+
+class VAELoaderDevice(VAELoader):
+    NodeId = "VAELoaderDevice"
+    NodeName = "Load VAE MultiGPU"
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "vae_name": (cls.vae_list(), ),
+                "load_device": (comfy.multigpu.get_torch_device_list(), ),
+            }
+        }
+
+    FUNCTION = "load_vae_device"
+    CATEGORY = "advanced/multigpu/loaders"
+
+    def load_vae_device(self, vae_name, load_device: str):
+        device = comfy.multigpu.get_device_from_str(load_device)
+        return self.load_vae(vae_name, device)
 
 class MultiGPUWorkUnitsNode:
     """
@@ -76,7 +97,8 @@ class MultiGPUOptionsNode:
 
 node_list = [
     MultiGPUWorkUnitsNode,
-    MultiGPUOptionsNode
+    MultiGPUOptionsNode,
+    VAELoaderDevice,
 ]
 NODE_CLASS_MAPPINGS = {}
 NODE_DISPLAY_NAME_MAPPINGS = {}
diff --git a/nodes.py b/nodes.py
index e2107c1d..cc74d733 100644
--- a/nodes.py
+++ b/nodes.py
@@ -763,13 +763,13 @@ class VAELoader:
     CATEGORY = "loaders"
 
     #TODO: scale factor?
-    def load_vae(self, vae_name):
+    def load_vae(self, vae_name, device=None):
         if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
             sd = self.load_taesd(vae_name)
         else:
             vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
             sd = comfy.utils.load_torch_file(vae_path)
-        vae = comfy.sd.VAE(sd=sd)
+        vae = comfy.sd.VAE(sd=sd, device=device)
         vae.throw_exception_if_invalid()
         return (vae,)