diff --git a/comfy/model_management.py b/comfy/model_management.py index 3eeee67dc..90cf83d84 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -93,10 +93,9 @@ def get_total_memory(dev=None, torch_total_too=False): dev = get_torch_device() if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): - if os.path.isfile('/sys/fs/cgroup/memory/memory.limit_in_bytes'): - with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f: - mem_total = int(f.read()) - mem_total_torch = mem_total + mem_total = get_containerd_memory_limit() + if mem_total > 0: + mem_total_torch = mem_total else: mem_total = psutil.virtual_memory().total mem_total_torch = mem_total @@ -656,12 +655,11 @@ def get_free_memory(dev=None, torch_free_too=False): dev = get_torch_device() if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): - if os.path.isfile('/sys/fs/cgroup/memory/memory.limit_in_bytes'): - with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f: - mem_used_total = psutil.virtual_memory().used - mem_total = int(f.read()) - mem_free_total = mem_total - mem_used_total - mem_free_torch = mem_free_total + mem_total = get_containerd_memory_limit() + if mem_total > 0: + mem_used_total = psutil.virtual_memory().used + mem_free_total = mem_total - mem_used_total + mem_free_torch = mem_free_total else: mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total @@ -709,6 +707,13 @@ def is_device_mps(device): return True return False +def get_containerd_memory_limit(): + cgroup_memory_limit = '/sys/fs/cgroup/memory/memory.limit_in_bytes' + if os.path.isfile(cgroup_memory_limit): + with open(cgroup_memory_limit, 'r') as f: + return int(f.read()) + return 0 + def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled