mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-12 18:33:35 +00:00
[feat]: add easy cache implementation for load_torch_file
This commit is contained in:
parent
35504e2f93
commit
715d345170
@ -19,6 +19,7 @@
|
||||
|
||||
import torch
|
||||
import math
|
||||
import functools
|
||||
import struct
|
||||
import comfy.checkpoint_pickle
|
||||
import safetensors.torch
|
||||
@ -46,6 +47,32 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
else:
|
||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||
|
||||
def cache_model(func):
|
||||
keys_count = {}
|
||||
def wrapper(*args, **kw):
|
||||
# the returned torch_file is previous cached
|
||||
torch_file = func(*args, **kw)
|
||||
sd = torch_file[0] if kw.get("return_metadata") else torch_file
|
||||
hashed_key = functools._make_key(args, kw, typed=True)
|
||||
count = keys_count.get(hashed_key)
|
||||
if count is None:
|
||||
keys_count[hashed_key] = len(sd.keys())
|
||||
count = len(sd.keys())
|
||||
# tensors on device are not released
|
||||
if len(sd.keys()) == count:
|
||||
logging.info(f"{func.__name__} with args: {args} returned cached torch file")
|
||||
return torch_file
|
||||
# tensors have been released and need to be loaded again
|
||||
logging.info(f"{func.__name__} with args: {args} cache clear")
|
||||
func.cache_clear()
|
||||
torch_file = func(*args, **kw)
|
||||
sd = torch_file[0] if kw.get("return_metadata") else torch_file
|
||||
keys_count[hashed_key] = len(sd.keys())
|
||||
return torch_file
|
||||
return wrapper
|
||||
|
||||
@cache_model
|
||||
@functools.lru_cache
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
|
Loading…
Reference in New Issue
Block a user