diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index ba505221..50b725b8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -107,6 +107,10 @@ class ModelPatcher: for k in patch_list: if hasattr(patch_list[k], "to"): patch_list[k] = patch_list[k].to(device) + if "unet_wrapper_function" in self.model_options: + wrap_func = self.model_options["unet_wrapper_function"] + if hasattr(wrap_func, "to"): + self.model_options["unet_wrapper_function"] = wrap_func.to(device) def model_dtype(self): if hasattr(self.model, "get_dtype"):