diff --git a/comfy/samplers.py b/comfy/samplers.py index 27d87570..b8b30f2c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -471,7 +471,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t transformer_options["uuids"] = uuids[:] transformer_options["sigmas"] = timestep transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + transformer_options["multigpu_thread_device"] = device + cast_transformer_options(transformer_options, device=device) c['transformer_options'] = transformer_options if control is not None: @@ -1045,7 +1047,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): to_load_options = model_options.get("to_load_options", None) if to_load_options is None: return + cast_transformer_options(to_load_options, device, dtype) +def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None): casts = [] if device is not None: casts.append(device) @@ -1054,18 +1058,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # if nothing to apply, do nothing if len(casts) == 0: return - # try to call .to on patches - if "patches" in to_load_options: - patches = to_load_options["patches"] + if "patches" in transformer_options: + patches = transformer_options["patches"] for name in patches: patch_list = patches[name] for i in range(len(patch_list)): if hasattr(patch_list[i], "to"): for cast in casts: patch_list[i] = patch_list[i].to(cast) - if "patches_replace" in to_load_options: - patches = to_load_options["patches_replace"] + if "patches_replace" in transformer_options: + patches = transformer_options["patches_replace"] for name in patches: patch_list = patches[name] for k in patch_list: @@ -1075,8 +1078,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: - if wc_name in to_load_options: - wc: dict[str, list] = to_load_options[wc_name] + if wc_name in transformer_options: + wc: dict[str, list] = transformer_options[wc_name] for wc_dict in wc.values(): for wc_list in wc_dict.values(): for i in range(len(wc_list)):