mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-17 01:13:34 +00:00
Add VAELoaderDevice node to decide what device to load VAE on
This commit is contained in:
parent
4879b47648
commit
e5396e98d8
@ -56,6 +56,23 @@ class GPUOptionsGroup:
|
||||
value['relative_speed'] /= min_speed
|
||||
model.model_options['multigpu_options'] = opts_dict
|
||||
|
||||
def get_torch_device_list():
    """Return the selectable device names: "default" plus the string form
    of each available torch device's index."""
    indexed = [str(d.index) for d in comfy.model_management.get_all_torch_devices()]
    return ["default"] + indexed
|
||||
|
||||
def get_device_from_str(device_str: str, throw_error_if_not_found=False):
    """Resolve a device-selection string to a torch.device.

    Args:
        device_str: Either "default" or the string form of a device index,
            as produced by get_torch_device_list().
        throw_error_if_not_found: If True, raise when no device matches;
            otherwise log a warning and fall back to the default device.

    Returns:
        The matching torch.device; the default torch device when
        device_str is "default" or when no match is found (and no error
        was requested).

    Raises:
        Exception: If no device matches and throw_error_if_not_found is True.
    """
    if device_str == "default":
        return comfy.model_management.get_torch_device()
    for device in comfy.model_management.get_all_torch_devices():
        device: torch.device
        if str(device.index) == device_str:
            return device
    if throw_error_if_not_found:
        raise Exception(f"Device with index '{device_str}' not found.")
    logging.warning(f"Device with index '{device_str}' not found, using default device ({comfy.model_management.get_torch_device()}) instead.")
    # Bug fix: the original fell through here and implicitly returned None,
    # even though the warning promises the default device is used instead.
    return comfy.model_management.get_torch_device()
|
||||
|
||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||
|
@ -6,6 +6,27 @@ if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.multigpu
|
||||
|
||||
from nodes import VAELoader
|
||||
|
||||
|
||||
class VAELoaderDevice(VAELoader):
    """VAE loader node that additionally lets the user choose which torch
    device the VAE is loaded onto."""
    NodeId = "VAELoaderDevice"
    NodeName = "Load VAE MultiGPU"
    FUNCTION = "load_vae_device"
    CATEGORY = "advanced/multigpu/loaders"

    @classmethod
    def INPUT_TYPES(cls):
        # Same inputs as the base VAE loader, plus a device dropdown.
        required = {
            "vae_name": (cls.vae_list(), ),
            "load_device": (comfy.multigpu.get_torch_device_list(), ),
        }
        return {"required": required}

    def load_vae_device(self, vae_name, load_device: str):
        """Resolve the selected device string and delegate to VAELoader.load_vae."""
        target = comfy.multigpu.get_device_from_str(load_device)
        return self.load_vae(vae_name, target)
|
||||
|
||||
class MultiGPUWorkUnitsNode:
|
||||
"""
|
||||
@ -76,7 +97,8 @@ class MultiGPUOptionsNode:
|
||||
|
||||
# Nodes exported by this module. NOTE: the diff render showed both the removed
# `MultiGPUOptionsNode` line (missing its trailing comma) and its replacement;
# this is the reconstructed post-commit state with the new loader registered.
node_list = [
    MultiGPUWorkUnitsNode,
    MultiGPUOptionsNode,
    VAELoaderDevice,
]
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
|
4
nodes.py
4
nodes.py
@ -763,13 +763,13 @@ class VAELoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name, device=None):
    """Load a VAE by name, optionally onto a specific torch device.

    Args:
        vae_name: A filename from the "vae" folder, or one of the TAESD
            presets ("taesd", "taesdxl", "taesd3", "taef1").
        device: Optional torch device to construct the VAE on; None keeps
            the previous default behavior (backward compatible).

    Returns:
        A single-element tuple containing the constructed VAE, as
        required by the node FUNCTION convention.
    """
    # TODO: scale factor?
    if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
        sd = self.load_taesd(vae_name)
    else:
        vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
        sd = comfy.utils.load_torch_file(vae_path)
    # Diff render contained both the old call (no device kwarg) and this new
    # one; reconstructed here as the post-commit version.
    vae = comfy.sd.VAE(sd=sd, device=device)
    vae.throw_exception_if_invalid()
    return (vae,)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user