diff --git a/comfy/utils.py b/comfy/utils.py index a826e41bf..be55ae3bf 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -19,6 +19,7 @@ import torch import math +import functools import struct import comfy.checkpoint_pickle import safetensors.torch @@ -46,6 +47,32 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in else: logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") +def cache_model(func): + keys_count = {} + def wrapper(*args, **kw): + # the returned torch_file is previous cached + torch_file = func(*args, **kw) + sd = torch_file[0] if kw.get("return_metadata") else torch_file + hashed_key = functools._make_key(args, kw, typed=True) + count = keys_count.get(hashed_key) + if count is None: + keys_count[hashed_key] = len(sd.keys()) + count = len(sd.keys()) + # tensors on device are not released + if len(sd.keys()) == count: + logging.info(f"{func.__name__} with args: {args} returned cached torch file") + return torch_file + # tensors have been released and need to be loaded again + logging.info(f"{func.__name__} with args: {args} cache clear") + func.cache_clear() + torch_file = func(*args, **kw) + sd = torch_file[0] if kw.get("return_metadata") else torch_file + keys_count[hashed_key] = len(sd.keys()) + return torch_file + return wrapper + +@cache_model +@functools.lru_cache def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if device is None: device = torch.device("cpu")