mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-13 15:03:33 +00:00
[feat]: add easy cache implementation for load_torch_file
This commit is contained in:
parent
35504e2f93
commit
715d345170
@ -19,6 +19,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
|
import functools
|
||||||
import struct
|
import struct
|
||||||
import comfy.checkpoint_pickle
|
import comfy.checkpoint_pickle
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@ -46,6 +47,32 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
|||||||
else:
|
else:
|
||||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||||
|
|
||||||
|
def cache_model(func):
|
||||||
|
keys_count = {}
|
||||||
|
def wrapper(*args, **kw):
|
||||||
|
# the returned torch_file is previous cached
|
||||||
|
torch_file = func(*args, **kw)
|
||||||
|
sd = torch_file[0] if kw.get("return_metadata") else torch_file
|
||||||
|
hashed_key = functools._make_key(args, kw, typed=True)
|
||||||
|
count = keys_count.get(hashed_key)
|
||||||
|
if count is None:
|
||||||
|
keys_count[hashed_key] = len(sd.keys())
|
||||||
|
count = len(sd.keys())
|
||||||
|
# tensors on device are not released
|
||||||
|
if len(sd.keys()) == count:
|
||||||
|
logging.info(f"{func.__name__} with args: {args} returned cached torch file")
|
||||||
|
return torch_file
|
||||||
|
# tensors have been released and need to be loaded again
|
||||||
|
logging.info(f"{func.__name__} with args: {args} cache clear")
|
||||||
|
func.cache_clear()
|
||||||
|
torch_file = func(*args, **kw)
|
||||||
|
sd = torch_file[0] if kw.get("return_metadata") else torch_file
|
||||||
|
keys_count[hashed_key] = len(sd.keys())
|
||||||
|
return torch_file
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
@cache_model
|
||||||
|
@functools.lru_cache
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||||
if device is None:
|
if device is None:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
Loading…
Reference in New Issue
Block a user