Compare commits

...

9 Commits

Author SHA1 Message Date
Ethan Yang
d4caee69fc
Merge e14b8dfec5 into 98bdca4cb2 2025-04-10 15:51:47 +04:00
Ethan Yang
e14b8dfec5
Update nodes_torch_compile.py 2025-03-07 13:23:19 +08:00
ethan
bc9eb9dfdb fix the memory leakage issue 2025-03-05 18:28:47 -08:00
Ethan Yang
8558803f44
Merge pull request #2 from comfyanonymous/master
rebase
2025-03-03 14:36:03 +08:00
ethan
ebfe7a5679 fix the issue for model first inference with lora 2025-02-10 19:54:44 -08:00
ethan
77e9294c08 add Query Device 2025-01-30 00:20:58 -08:00
ethan
317af7201f remove history commit
remove history commit

remove history commit
2025-01-29 07:11:24 -08:00
ethan
d1f61cca5e add openvino to torch compile 2025-01-29 07:03:31 -08:00
ethan
33e71e0e79 update openvino backend 2025-01-24 01:37:44 -08:00
2 changed files with 58 additions and 7 deletions

View File

@ -498,12 +498,20 @@ class ModelPatcher:
key = k[0]
if len(k) > 2:
function = k[2]
org_key=key.replace("diffusion_model", "diffusion_model._orig_mod")
if key in model_sd:
p.add(k)
current_patches = self.patches.get(key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
self.patches[key] = current_patches
self.patches[org_key] = current_patches
elif org_key in model_sd:
if key in self.patches:
self.patches.pop(key)
p.add(k)
current_patches = self.patches.get(org_key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
self.patches[org_key] = current_patches
self.patches_uuid = uuid.uuid4()
return list(p)

View File

@ -1,21 +1,64 @@
import torch
import importlib
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"backend": (["inductor", "cudagraphs"],),
}}
if importlib.util.find_spec("openvino") is not None:
import openvino as ov
core = ov.Core()
available_devices = core.available_devices
else:
available_devices = []
return {
"required": {
"model": ("MODEL",),
"backend": (["inductor", "cudagraphs", "openvino"],),
},
"optional": {
"openvino_device": (available_devices,),
},
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model, backend):
def patch(self, model, backend, openvino_device):
print(model.__class__.__name__)
if backend == "openvino":
options = {"device": openvino_device}
try:
import openvino.torch
except ImportError:
raise ImportError(
"Could not import openvino python package. "
"Please install it with `pip install openvino`."
)
import openvino.frontend.pytorch.torchdynamo.execute as ov_ex
torch._dynamo.reset()
ov_ex.compiled_cache.clear()
ov_ex.req_cache.clear()
ov_ex.partitioned_modules.clear()
else:
options = None
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
return (m, )
m.add_object_patch(
"diffusion_model",
torch.compile(
model=m.get_model_object("diffusion_model"),
backend=backend,
options=options,
),
)
return (m,)
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,