Let unet wrapper functions have .to attributes.

This commit is contained in:
comfyanonymous 2023-10-11 01:34:38 -04:00
parent 5e885bd9c8
commit 8cc75c64ff

View File

@ -107,6 +107,10 @@ class ModelPatcher:
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
if "unet_wrapper_function" in self.model_options:
wrap_func = self.model_options["unet_wrapper_function"]
if hasattr(wrap_func, "to"):
self.model_options["unet_wrapper_function"] = wrap_func.to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):