From a47bb4cc0c06cabf58d19d5fefa208d7803a2d11 Mon Sep 17 00:00:00 2001 From: FE-xiaoJiang <3401384168@qq.com> Date: Wed, 26 Feb 2025 22:40:52 +0800 Subject: [PATCH] Add disable_mmap arg in method load_torch_file (cherry picked from commit 2ffcc72a5fdf86e5e9340c9d7b86d18f638e64d1) --- comfy/cli_args.py | 1 + comfy/utils.py | 25 +++++++++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a864205be..bf0328ab1 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -135,6 +135,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult") +parser.add_argument("--disable-mmap", action="store_true", help="When load .safetensors or .sft model sometimes.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") diff --git a/comfy/utils.py b/comfy/utils.py index a826e41bf..790ec6b52 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,6 +28,7 @@ import logging import itertools from torch.nn.functional import interpolate from einops import rearrange +from comfy.cli_args import args ALWAYS_SAFE_LOAD = False if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated @@ -46,18 +47,30 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in else: logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") -def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): +def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False, disable_mmap=None): if device is None: device = torch.device("cpu") metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: - with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: - sd = {} - for k in f.keys(): - sd[k] = f.get_tensor(k) + if disable_mmap is None: + disable_mmap_decision = args.disable_mmap + else: + disable_mmap_decision = True + + if disable_mmap_decision: + pl_sd = safetensors.torch.load(open(ckpt, 'rb').read()) + sd = {k: v.to(device) for k, v in pl_sd.items()} if return_metadata: - metadata = f.metadata() + with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: + metadata = f.metadata() + else: + with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: + sd = {} + for k in f.keys(): + sd[k] = f.get_tensor(k) + if return_metadata: + metadata = f.metadata() except Exception as e: if len(e.args) > 0: message = e.args[0]