Add VAELoaderDevice node to device what device to load VAE on

This commit is contained in:
Jedrzej Kosinski 2025-03-21 14:57:05 -05:00
parent 4879b47648
commit e5396e98d8
3 changed files with 42 additions and 3 deletions

View File

@ -56,6 +56,23 @@ class GPUOptionsGroup:
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
def get_torch_device_list():
devices = ["default"]
for device in comfy.model_management.get_all_torch_devices():
device: torch.device
devices.append(str(device.index))
return devices
def get_device_from_str(device_str: str, throw_error_if_not_found=False):
if device_str == "default":
return comfy.model_management.get_torch_device()
for device in comfy.model_management.get_all_torch_devices():
device: torch.device
if str(device.index) == device_str:
return device
if throw_error_if_not_found:
raise Exception(f"Device with index '{device_str}' not found.")
logging.warning(f"Device with index '{device_str}' not found, using default device ({comfy.model_management.get_torch_device()}) instead.")
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'

View File

@ -6,6 +6,27 @@ if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.multigpu
from nodes import VAELoader
class VAELoaderDevice(VAELoader):
NodeId = "VAELoaderDevice"
NodeName = "Load VAE MultiGPU"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"vae_name": (cls.vae_list(), ),
"load_device": (comfy.multigpu.get_torch_device_list(), ),
}
}
FUNCTION = "load_vae_device"
CATEGORY = "advanced/multigpu/loaders"
def load_vae_device(self, vae_name, load_device: str):
device = comfy.multigpu.get_device_from_str(load_device)
return self.load_vae(vae_name, device)
class MultiGPUWorkUnitsNode:
"""
@ -76,7 +97,8 @@ class MultiGPUOptionsNode:
node_list = [
MultiGPUWorkUnitsNode,
MultiGPUOptionsNode
MultiGPUOptionsNode,
VAELoaderDevice,
]
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

View File

@ -763,13 +763,13 @@ class VAELoader:
CATEGORY = "loaders"
#TODO: scale factor?
def load_vae(self, vae_name):
def load_vae(self, vae_name, device=None):
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
vae = comfy.sd.VAE(sd=sd, device=device)
vae.throw_exception_if_invalid()
return (vae,)