[feat]: add a simple cache implementation for load_torch_file

zhaoyingzhuo 2025-03-14 07:24:30 +00:00
parent 35504e2f93
commit 715d345170


@@ -19,6 +19,7 @@
import torch
import math
import functools
import struct
import comfy.checkpoint_pickle
import safetensors.torch
@@ -46,6 +47,32 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")

def cache_model(func):
    keys_count = {}

    def wrapper(*args, **kw):
        # func is wrapped with functools.lru_cache, so this may return a previously cached result
        torch_file = func(*args, **kw)
        # when return_metadata is passed as a keyword, the state dict is the first element of the tuple
        sd = torch_file[0] if kw.get("return_metadata") else torch_file
        # build the same cache key that lru_cache uses for these arguments
        hashed_key = functools._make_key(args, kw, typed=True)
        count = keys_count.get(hashed_key)
        if count is None:
            # first call for this key: remember how many entries the state dict has
            keys_count[hashed_key] = len(sd.keys())
            count = len(sd.keys())
        # the key count is unchanged, so the cached tensors have not been released
        if len(sd.keys()) == count:
            logging.info(f"{func.__name__} with args: {args} returned cached torch file")
            return torch_file
        # some tensors were removed from the cached state dict, so clear the cache and reload from disk
        logging.info(f"{func.__name__} with args: {args} cache clear")
        func.cache_clear()
        torch_file = func(*args, **kw)
        sd = torch_file[0] if kw.get("return_metadata") else torch_file
        keys_count[hashed_key] = len(sd.keys())
        return torch_file
    return wrapper
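
# Illustrative usage sketch (not part of this change; "model.safetensors" is a hypothetical path):
# load_torch_file below is wrapped by functools.lru_cache and then cache_model, so repeated calls
# with the same arguments return the cached state dict as long as its key count is unchanged, and
# trigger a cache clear plus reload once keys have been removed:
#
#   sd1 = load_torch_file("model.safetensors")   # loads from disk and caches the result
#   sd2 = load_torch_file("model.safetensors")   # returns the cached state dict (sd1 is sd2)
#   sd1.pop(next(iter(sd1)))                     # simulate tensors being released after use
#   sd3 = load_torch_file("model.safetensors")   # key count changed -> cache cleared, reloaded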
@cache_model
@functools.lru_cache
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    if device is None:
        device = torch.device("cpu")