Compare commits

...

3 Commits

Author        SHA1        Message                                            Date
jiangxuefeng  6df3471d1d  Merge a47bb4cc0c into 98bdca4cb2                   2025-04-10 16:28:35 -04:00
Chenlei Hu    98bdca4cb2  Deprecate InputTypeOptions.defaultInput (#7551)    2025-04-10 06:57:06 -04:00
                          * Deprecate InputTypeOptions.defaultInput
                          * nit
                          * nit
FE-xiaoJiang  a47bb4cc0c  Add disable_mmap arg in method load_torch_file     2025-03-07 10:24:40 +08:00
                          (cherry picked from commit 2ffcc72a5fdf86e5e9340c9d7b86d18f638e64d1)
3 changed files with 26 additions and 8 deletions


@@ -137,6 +137,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
+parser.add_argument("--disable-mmap", action="store_true", help="Disable mmap when loading .safetensors or .sft model files.")
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")


@@ -102,9 +102,13 @@ class InputTypeOptions(TypedDict):
     default: bool | str | float | int | list | tuple
     """The default value of the widget"""
     defaultInput: bool
-    """Defaults to an input slot rather than a widget"""
+    """@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
+    - defaultInput on required inputs should be dropped.
+    - defaultInput on optional inputs should be replaced with forceInput.
+    Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
+    """
     forceInput: bool
-    """`defaultInput` and also don't allow converting to a widget"""
+    """Forces the input to be an input slot rather than a widget, even if a widget is available for the input type."""
     lazy: bool
     """Declares that this input uses lazy evaluation"""
     rawLink: bool
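In custom-node code, the migration described by the new docstring looks roughly like the sketch below. The node, its input names, and the types are hypothetical; only the defaultInput/forceInput handling follows the guidance above.

# Hypothetical custom node, trimmed to its input declaration.
class ExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Before: "steps": ("INT", {"default": 20, "defaultInput": True})
                # After:  defaultInput is dropped; the v1.16 frontend lets the
                #         socket and the widget co-exist on required inputs.
                "steps": ("INT", {"default": 20}),
            },
            "optional": {
                # Before: "strength": ("FLOAT", {"defaultInput": True})
                # After:  replaced with forceInput, so the value must arrive
                #         through an input slot rather than a widget.
                "strength": ("FLOAT", {"forceInput": True}),
            },
        }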


@@ -28,6 +28,7 @@
 import logging
 import itertools
 from torch.nn.functional import interpolate
 from einops import rearrange
+from comfy.cli_args import args
 ALWAYS_SAFE_LOAD = False
 if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
@@ -46,12 +47,24 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
 else:
     logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
-def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
+def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False, disable_mmap=None):
     if device is None:
         device = torch.device("cpu")
     metadata = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
+            if disable_mmap is None:
+                disable_mmap_decision = args.disable_mmap
+            else:
+                disable_mmap_decision = True
+            if disable_mmap_decision:
+                pl_sd = safetensors.torch.load(open(ckpt, 'rb').read())
+                sd = {k: v.to(device) for k, v in pl_sd.items()}
+                if return_metadata:
+                    with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
+                        metadata = f.metadata()
+            else:
-            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
-                sd = {}
-                for k in f.keys():
+                with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
+                    sd = {}
+                    for k in f.keys():
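Taken together, the new parameter would be called along these lines. This is a usage sketch, not part of the diff: it assumes the patched function is comfy.utils.load_torch_file (where load_torch_file lives upstream), and the checkpoint path is made up.

import comfy.utils

# Force the in-memory (no-mmap) path for this one load, regardless of --disable-mmap.
sd = comfy.utils.load_torch_file("models/checkpoints/example.safetensors", disable_mmap=True)

# Leave disable_mmap at its default (None) to fall back to the global --disable-mmap flag.
sd = comfy.utils.load_torch_file("models/checkpoints/example.safetensors")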