mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Merge branch 'generalize_fixes' of https://github.com/simonlui/ComfyUI
This commit is contained in:
commit
7746bdf7b0
@ -323,8 +323,7 @@ class CrossAttentionDoggettx(nn.Module):
|
|||||||
break
|
break
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
if first_op_done == False:
|
if first_op_done == False:
|
||||||
torch.cuda.empty_cache()
|
model_management.soft_empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
if cleared_cache == False:
|
if cleared_cache == False:
|
||||||
cleared_cache = True
|
cleared_cache = True
|
||||||
print("out of memory error, emptying cache and trying again")
|
print("out of memory error, emptying cache and trying again")
|
||||||
|
@ -58,8 +58,15 @@ except:
|
|||||||
if args.cpu:
|
if args.cpu:
|
||||||
cpu_state = CPUState.CPU
|
cpu_state = CPUState.CPU
|
||||||
|
|
||||||
def get_torch_device():
|
def is_intel_xpu():
|
||||||
|
global cpu_state
|
||||||
global xpu_available
|
global xpu_available
|
||||||
|
if cpu_state == CPUState.GPU:
|
||||||
|
if xpu_available:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_torch_device():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
@ -70,13 +77,12 @@ def get_torch_device():
|
|||||||
if cpu_state == CPUState.CPU:
|
if cpu_state == CPUState.CPU:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
else:
|
else:
|
||||||
if xpu_available:
|
if is_intel_xpu():
|
||||||
return torch.device("xpu")
|
return torch.device("xpu")
|
||||||
else:
|
else:
|
||||||
return torch.device(torch.cuda.current_device())
|
return torch.device(torch.cuda.current_device())
|
||||||
|
|
||||||
def get_total_memory(dev=None, torch_total_too=False):
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
global xpu_available
|
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if dev is None:
|
if dev is None:
|
||||||
dev = get_torch_device()
|
dev = get_torch_device()
|
||||||
@ -88,7 +94,7 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
mem_total = 1024 * 1024 * 1024 #TODO
|
mem_total = 1024 * 1024 * 1024 #TODO
|
||||||
mem_total_torch = mem_total
|
mem_total_torch = mem_total
|
||||||
elif xpu_available:
|
elif is_intel_xpu():
|
||||||
stats = torch.xpu.memory_stats(dev)
|
stats = torch.xpu.memory_stats(dev)
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||||
@ -146,11 +152,11 @@ def is_nvidia():
|
|||||||
if cpu_state == CPUState.GPU:
|
if cpu_state == CPUState.GPU:
|
||||||
if torch.version.cuda:
|
if torch.version.cuda:
|
||||||
return True
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
|
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
|
||||||
VAE_DTYPE = torch.float32
|
VAE_DTYPE = torch.float32
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
torch_version = torch.version.__version__
|
torch_version = torch.version.__version__
|
||||||
@ -162,6 +168,9 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if is_intel_xpu():
|
||||||
|
VAE_DTYPE = torch.bfloat16
|
||||||
|
|
||||||
if args.fp16_vae:
|
if args.fp16_vae:
|
||||||
VAE_DTYPE = torch.float16
|
VAE_DTYPE = torch.float16
|
||||||
elif args.bf16_vae:
|
elif args.bf16_vae:
|
||||||
@ -220,7 +229,6 @@ if DISABLE_SMART_MEMORY:
|
|||||||
print("Disabling smart memory management")
|
print("Disabling smart memory management")
|
||||||
|
|
||||||
def get_torch_device_name(device):
|
def get_torch_device_name(device):
|
||||||
global xpu_available
|
|
||||||
if hasattr(device, 'type'):
|
if hasattr(device, 'type'):
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
try:
|
try:
|
||||||
@ -230,7 +238,7 @@ def get_torch_device_name(device):
|
|||||||
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
|
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
|
||||||
else:
|
else:
|
||||||
return "{}".format(device.type)
|
return "{}".format(device.type)
|
||||||
elif xpu_available:
|
elif is_intel_xpu():
|
||||||
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
||||||
else:
|
else:
|
||||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||||
@ -260,7 +268,6 @@ class LoadedModel:
|
|||||||
return self.model_memory()
|
return self.model_memory()
|
||||||
|
|
||||||
def model_load(self, lowvram_model_memory=0):
|
def model_load(self, lowvram_model_memory=0):
|
||||||
global xpu_available
|
|
||||||
patch_model_to = None
|
patch_model_to = None
|
||||||
if lowvram_model_memory == 0:
|
if lowvram_model_memory == 0:
|
||||||
patch_model_to = self.device
|
patch_model_to = self.device
|
||||||
@ -281,7 +288,7 @@ class LoadedModel:
|
|||||||
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
|
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
|
||||||
self.model_accelerated = True
|
self.model_accelerated = True
|
||||||
|
|
||||||
if xpu_available and not args.disable_ipex_optimize:
|
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||||
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
||||||
|
|
||||||
return self.real_model
|
return self.real_model
|
||||||
@ -471,12 +478,11 @@ def get_autocast_device(dev):
|
|||||||
|
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global xpu_available
|
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state != CPUState.GPU:
|
if cpu_state != CPUState.GPU:
|
||||||
return False
|
return False
|
||||||
if xpu_available:
|
if is_intel_xpu():
|
||||||
return False
|
return False
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
return False
|
return False
|
||||||
@ -503,7 +509,6 @@ def pytorch_attention_flash_attention():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def get_free_memory(dev=None, torch_free_too=False):
|
def get_free_memory(dev=None, torch_free_too=False):
|
||||||
global xpu_available
|
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if dev is None:
|
if dev is None:
|
||||||
dev = get_torch_device()
|
dev = get_torch_device()
|
||||||
@ -515,7 +520,7 @@ def get_free_memory(dev=None, torch_free_too=False):
|
|||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
mem_free_total = 1024 * 1024 * 1024 #TODO
|
mem_free_total = 1024 * 1024 * 1024 #TODO
|
||||||
mem_free_torch = mem_free_total
|
mem_free_torch = mem_free_total
|
||||||
elif xpu_available:
|
elif is_intel_xpu():
|
||||||
stats = torch.xpu.memory_stats(dev)
|
stats = torch.xpu.memory_stats(dev)
|
||||||
mem_active = stats['active_bytes.all.current']
|
mem_active = stats['active_bytes.all.current']
|
||||||
mem_allocated = stats['allocated_bytes.all.current']
|
mem_allocated = stats['allocated_bytes.all.current']
|
||||||
@ -577,7 +582,6 @@ def is_device_mps(device):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
||||||
global xpu_available
|
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
|
|
||||||
if device is not None:
|
if device is not None:
|
||||||
@ -600,7 +604,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
|||||||
if cpu_mode() or mps_mode():
|
if cpu_mode() or mps_mode():
|
||||||
return False #TODO ?
|
return False #TODO ?
|
||||||
|
|
||||||
if xpu_available:
|
if is_intel_xpu():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if torch.cuda.is_bf16_supported():
|
if torch.cuda.is_bf16_supported():
|
||||||
@ -636,11 +640,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def soft_empty_cache():
|
def soft_empty_cache():
|
||||||
global xpu_available
|
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
elif xpu_available:
|
elif is_intel_xpu():
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
|
if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
|
||||||
|
Loading…
Reference in New Issue
Block a user