This commit is contained in:
Ethan Yang 2025-04-11 09:46:27 -04:00 committed by GitHub
commit 5afbcc9309
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 7 deletions

View File

@ -498,12 +498,20 @@ class ModelPatcher:
key = k[0] key = k[0]
if len(k) > 2: if len(k) > 2:
function = k[2] function = k[2]
org_key=key.replace("diffusion_model", "diffusion_model._orig_mod")
if key in model_sd: if key in model_sd:
p.add(k) p.add(k)
current_patches = self.patches.get(key, []) current_patches = self.patches.get(key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset, function)) current_patches.append((strength_patch, patches[k], strength_model, offset, function))
self.patches[key] = current_patches self.patches[key] = current_patches
self.patches[org_key] = current_patches
elif org_key in model_sd:
if key in self.patches:
self.patches.pop(key)
p.add(k)
current_patches = self.patches.get(org_key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
self.patches[org_key] = current_patches
self.patches_uuid = uuid.uuid4() self.patches_uuid = uuid.uuid4()
return list(p) return list(p)

View File

@ -1,21 +1,64 @@
import torch import torch
import importlib
class TorchCompileModel: class TorchCompileModel:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), if importlib.util.find_spec("openvino") is not None:
"backend": (["inductor", "cudagraphs"],), import openvino as ov
}}
core = ov.Core()
available_devices = core.available_devices
else:
available_devices = []
return {
"required": {
"model": ("MODEL",),
"backend": (["inductor", "cudagraphs", "openvino"],),
},
"optional": {
"openvino_device": (available_devices,),
},
}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
EXPERIMENTAL = True EXPERIMENTAL = True
def patch(self, model, backend): def patch(self, model, backend, openvino_device):
print(model.__class__.__name__)
if backend == "openvino":
options = {"device": openvino_device}
try:
import openvino.torch
except ImportError:
raise ImportError(
"Could not import openvino python package. "
"Please install it with `pip install openvino`."
)
import openvino.frontend.pytorch.torchdynamo.execute as ov_ex
torch._dynamo.reset()
ov_ex.compiled_cache.clear()
ov_ex.req_cache.clear()
ov_ex.partitioned_modules.clear()
else:
options = None
m = model.clone() m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend)) m.add_object_patch(
return (m, ) "diffusion_model",
torch.compile(
model=m.get_model_object("diffusion_model"),
backend=backend,
options=options,
),
)
return (m,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel, "TorchCompileModel": TorchCompileModel,