[feat] upgrade

This commit is contained in:
admin 2024-02-12 02:17:30 +08:00
parent eaf3e70556
commit 777ee90206

View File

@ -93,10 +93,9 @@ def get_total_memory(dev=None, torch_total_too=False):
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
if os.path.isfile('/sys/fs/cgroup/memory/memory.limit_in_bytes'):
with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
mem_total = int(f.read())
mem_total_torch = mem_total
mem_total = get_containerd_memory_limit()
if mem_total > 0:
mem_total_torch = mem_total
else:
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
@ -656,12 +655,11 @@ def get_free_memory(dev=None, torch_free_too=False):
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
if os.path.isfile('/sys/fs/cgroup/memory/memory.limit_in_bytes'):
with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
mem_used_total = psutil.virtual_memory().used
mem_total = int(f.read())
mem_free_total = mem_total - mem_used_total
mem_free_torch = mem_free_total
mem_total = get_containerd_memory_limit()
if mem_total > 0:
mem_used_total = psutil.virtual_memory().used
mem_free_total = mem_total - mem_used_total
mem_free_torch = mem_free_total
else:
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
@ -709,6 +707,13 @@ def is_device_mps(device):
return True
return False
def get_containerd_memory_limit():
cgroup_memory_limit = '/sys/fs/cgroup/memory/memory.limit_in_bytes'
if os.path.isfile(cgroup_memory_limit):
with open(cgroup_memory_limit, 'r') as f:
return int(f.read())
return 0
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled