mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-02 01:22:11 +08:00

* Make torch compile node use wrapper instead of object_patch for the entire diffusion_models object, allowing key associations on diffusion_models to not break (loras, getting attributes, etc.) * Moved torch compile code into comfy_api so it can be used by custom nodes with a degree of confidence * Refactor set_torch_compile_wrapper to support a list of keys instead of just diffusion_model, as well as additional torch.compile args * remove unused import * Moved torch compile kwargs to be stored in model_options instead of attachments; attachments are more intended for things to be 'persisted', AKA not deepcopied * Add some comments * Remove random line of code, not sure how it got there
70 lines
2.7 KiB
Python
70 lines
2.7 KiB
Python
from __future__ import annotations
|
|
import torch
|
|
|
|
import comfy.utils
|
|
from comfy.patcher_extension import WrappersMP
|
|
from typing import TYPE_CHECKING, Callable, Optional
|
|
if TYPE_CHECKING:
|
|
from comfy.model_patcher import ModelPatcher
|
|
from comfy.patcher_extension import WrapperExecutor
|
|
|
|
|
|
# Key under which the torch.compile wrapper is registered on the model patcher
# (used with remove_wrappers_with_key / add_wrapper_with_key in this module).
COMPILE_KEY = "torch.compile"
|
|
# Key in model.model_options under which the kwargs passed to torch.compile
# are stored for later reference.
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
|
|
|
|
|
|
def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
    '''
    Build a wrapper that temporarily swaps compiled modules onto the wrapped object.

    compiled_module_dict maps an attribute path (as understood by
    comfy.utils.get_attr / comfy.utils.set_attr) to the compiled callable that
    should stand in for that attribute while the wrapped call runs.

    The returned function takes the executor followed by the call's own
    arguments, and returns whatever the executor returns.
    '''
    def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
        # Originals are recorded as each one is replaced, so the cleanup below
        # only ever restores attributes that were actually swapped out.
        replaced = {}
        try:
            for attr_path, compiled in compiled_module_dict.items():
                replaced[attr_path] = comfy.utils.get_attr(executor.class_obj, attr_path)
                comfy.utils.set_attr(executor.class_obj, attr_path, compiled)
            return executor(*args, **kwargs)
        finally:
            # Put the original modules back whether the call succeeded or raised.
            for attr_path, original in replaced.items():
                comfy.utils.set_attr(executor.class_obj, attr_path, original)

    return apply_torch_compile_wrapper
|
|
|
|
|
|
def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
                              keys: Optional[list[str]]=None, *args, **kwargs):
    '''
    Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.

    When keys is None (or empty), it will default to using ["diffusion_model"], compiling the whole diffusion_model.
    When a list of keys is provided, it will perform torch.compile on only the selected modules.

    Args:
        model: patcher whose model objects are compiled and on which the wrapper is registered.
        backend, options, mode, fullgraph, dynamic: forwarded unchanged to torch.compile.
        keys: names passed to model.get_model_object to select what gets compiled.
        *args, **kwargs: accepted for call compatibility; not used by this function.

    Side effects:
        Replaces any wrapper previously registered under COMPILE_KEY and stores the
        compile kwargs in model.model_options[TORCH_COMPILE_KWARGS].
    '''
    # clear out any other torch.compile wrappers
    model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
    # if no keys, default to 'diffusion_model'
    # (a fresh list per call; the default is None rather than a shared mutable list)
    if not keys:
        keys = ["diffusion_model"]
    # create kwargs dict that can be referenced later
    compile_kwargs = {
        "backend": backend,
        "options": options,
        "mode": mode,
        "fullgraph": fullgraph,
        "dynamic": dynamic,
    }
    # get a dict of compiled keys
    compiled_modules = {
        key: torch.compile(
            model=model.get_model_object(key),
            **compile_kwargs,
        )
        for key in keys
    }
    # add torch.compile wrapper
    wrapper_func = apply_torch_compile_factory(
        compiled_module_dict=compiled_modules,
    )
    # store wrapper to run on BaseModel's apply_model function
    model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
    # keep compile kwargs for reference
    model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
|