diff --git a/comfy_api/torch_helpers/__init__.py b/comfy_api/torch_helpers/__init__.py new file mode 100644 index 000000000..be7ae7a61 --- /dev/null +++ b/comfy_api/torch_helpers/__init__.py @@ -0,0 +1,5 @@ +from .torch_compile import set_torch_compile_wrapper + +__all__ = [ + "set_torch_compile_wrapper", +] diff --git a/comfy_api/torch_helpers/torch_compile.py b/comfy_api/torch_helpers/torch_compile.py new file mode 100644 index 000000000..9223f58db --- /dev/null +++ b/comfy_api/torch_helpers/torch_compile.py @@ -0,0 +1,69 @@ +from __future__ import annotations +import torch + +import comfy.utils +from comfy.patcher_extension import WrappersMP +from typing import TYPE_CHECKING, Callable, Optional +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.patcher_extension import WrapperExecutor + + +COMPILE_KEY = "torch.compile" +TORCH_COMPILE_KWARGS = "torch_compile_kwargs" + + +def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable: + ''' + Create a wrapper that will refer to the compiled_diffusion_model. + ''' + def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs): + try: + orig_modules = {} + for key, value in compiled_module_dict.items(): + orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key) + comfy.utils.set_attr(executor.class_obj, key, value) + return executor(*args, **kwargs) + finally: + for key, value in orig_modules.items(): + comfy.utils.set_attr(executor.class_obj, key, value) + return apply_torch_compile_wrapper + + +def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None, + mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None, + keys: list[str]=["diffusion_model"], *args, **kwargs): + ''' + Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance. + + When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model. + When a list of keys is provided, it will perform torch.compile on only the selected modules. + ''' + # clear out any other torch.compile wrappers + model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY) + # if no keys, default to 'diffusion_model' + if not keys: + keys = ["diffusion_model"] + # create kwargs dict that can be referenced later + compile_kwargs = { + "backend": backend, + "options": options, + "mode": mode, + "fullgraph": fullgraph, + "dynamic": dynamic, + } + # get a dict of compiled keys + compiled_modules = {} + for key in keys: + compiled_modules[key] = torch.compile( + model=model.get_model_object(key), + **compile_kwargs, + ) + # add torch.compile wrapper + wrapper_func = apply_torch_compile_factory( + compiled_module_dict=compiled_modules, + ) + # store wrapper to run on BaseModel's apply_model function + model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func) + # keep compile kwargs for reference + model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py index 1fe6f42c7..605536678 100644 --- a/comfy_extras/nodes_torch_compile.py +++ b/comfy_extras/nodes_torch_compile.py @@ -1,4 +1,5 @@ -import torch +from comfy_api.torch_helpers import set_torch_compile_wrapper + class TorchCompileModel: @classmethod @@ -14,7 +15,7 @@ class TorchCompileModel: def patch(self, model, backend): m = model.clone() - m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend)) + set_torch_compile_wrapper(model=m, backend=backend) return (m, ) NODE_CLASS_MAPPINGS = {