Merge branch 'generalize_fixes' of https://github.com/simonlui/ComfyUI

This commit is contained in:
comfyanonymous 2023-09-04 00:43:11 -04:00
commit 7746bdf7b0
2 changed files with 21 additions and 19 deletions

View File

@ -323,8 +323,7 @@ class CrossAttentionDoggettx(nn.Module):
break break
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
if first_op_done == False: if first_op_done == False:
torch.cuda.empty_cache() model_management.soft_empty_cache()
torch.cuda.ipc_collect()
if cleared_cache == False: if cleared_cache == False:
cleared_cache = True cleared_cache = True
print("out of memory error, emptying cache and trying again") print("out of memory error, emptying cache and trying again")

View File

@ -58,8 +58,15 @@ except:
if args.cpu: if args.cpu:
cpu_state = CPUState.CPU cpu_state = CPUState.CPU
def get_torch_device(): def is_intel_xpu():
global cpu_state
global xpu_available global xpu_available
if cpu_state == CPUState.GPU:
if xpu_available:
return True
return False
def get_torch_device():
global directml_enabled global directml_enabled
global cpu_state global cpu_state
if directml_enabled: if directml_enabled:
@ -70,13 +77,12 @@ def get_torch_device():
if cpu_state == CPUState.CPU: if cpu_state == CPUState.CPU:
return torch.device("cpu") return torch.device("cpu")
else: else:
if xpu_available: if is_intel_xpu():
return torch.device("xpu") return torch.device("xpu")
else: else:
return torch.device(torch.cuda.current_device()) return torch.device(torch.cuda.current_device())
def get_total_memory(dev=None, torch_total_too=False): def get_total_memory(dev=None, torch_total_too=False):
global xpu_available
global directml_enabled global directml_enabled
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()
@ -88,7 +94,7 @@ def get_total_memory(dev=None, torch_total_too=False):
if directml_enabled: if directml_enabled:
mem_total = 1024 * 1024 * 1024 #TODO mem_total = 1024 * 1024 * 1024 #TODO
mem_total_torch = mem_total mem_total_torch = mem_total
elif xpu_available: elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev) stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
mem_total = torch.xpu.get_device_properties(dev).total_memory mem_total = torch.xpu.get_device_properties(dev).total_memory
@ -146,11 +152,11 @@ def is_nvidia():
if cpu_state == CPUState.GPU: if cpu_state == CPUState.GPU:
if torch.version.cuda: if torch.version.cuda:
return True return True
return False
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
VAE_DTYPE = torch.float32 VAE_DTYPE = torch.float32
try: try:
if is_nvidia(): if is_nvidia():
torch_version = torch.version.__version__ torch_version = torch.version.__version__
@ -162,6 +168,9 @@ try:
except: except:
pass pass
if is_intel_xpu():
VAE_DTYPE = torch.bfloat16
if args.fp16_vae: if args.fp16_vae:
VAE_DTYPE = torch.float16 VAE_DTYPE = torch.float16
elif args.bf16_vae: elif args.bf16_vae:
@ -220,7 +229,6 @@ if DISABLE_SMART_MEMORY:
print("Disabling smart memory management") print("Disabling smart memory management")
def get_torch_device_name(device): def get_torch_device_name(device):
global xpu_available
if hasattr(device, 'type'): if hasattr(device, 'type'):
if device.type == "cuda": if device.type == "cuda":
try: try:
@ -230,7 +238,7 @@ def get_torch_device_name(device):
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend) return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
else: else:
return "{}".format(device.type) return "{}".format(device.type)
elif xpu_available: elif is_intel_xpu():
return "{} {}".format(device, torch.xpu.get_device_name(device)) return "{} {}".format(device, torch.xpu.get_device_name(device))
else: else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@ -260,7 +268,6 @@ class LoadedModel:
return self.model_memory() return self.model_memory()
def model_load(self, lowvram_model_memory=0): def model_load(self, lowvram_model_memory=0):
global xpu_available
patch_model_to = None patch_model_to = None
if lowvram_model_memory == 0: if lowvram_model_memory == 0:
patch_model_to = self.device patch_model_to = self.device
@ -281,7 +288,7 @@ class LoadedModel:
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
self.model_accelerated = True self.model_accelerated = True
if xpu_available and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
return self.real_model return self.real_model
@ -471,12 +478,11 @@ def get_autocast_device(dev):
def xformers_enabled(): def xformers_enabled():
global xpu_available
global directml_enabled global directml_enabled
global cpu_state global cpu_state
if cpu_state != CPUState.GPU: if cpu_state != CPUState.GPU:
return False return False
if xpu_available: if is_intel_xpu():
return False return False
if directml_enabled: if directml_enabled:
return False return False
@ -503,7 +509,6 @@ def pytorch_attention_flash_attention():
return False return False
def get_free_memory(dev=None, torch_free_too=False): def get_free_memory(dev=None, torch_free_too=False):
global xpu_available
global directml_enabled global directml_enabled
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()
@ -515,7 +520,7 @@ def get_free_memory(dev=None, torch_free_too=False):
if directml_enabled: if directml_enabled:
mem_free_total = 1024 * 1024 * 1024 #TODO mem_free_total = 1024 * 1024 * 1024 #TODO
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
elif xpu_available: elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev) stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current'] mem_active = stats['active_bytes.all.current']
mem_allocated = stats['allocated_bytes.all.current'] mem_allocated = stats['allocated_bytes.all.current']
@ -577,7 +582,6 @@ def is_device_mps(device):
return False return False
def should_use_fp16(device=None, model_params=0, prioritize_performance=True): def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
global xpu_available
global directml_enabled global directml_enabled
if device is not None: if device is not None:
@ -600,7 +604,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
if cpu_mode() or mps_mode(): if cpu_mode() or mps_mode():
return False #TODO ? return False #TODO ?
if xpu_available: if is_intel_xpu():
return True return True
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():
@ -636,11 +640,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
return True return True
def soft_empty_cache(): def soft_empty_cache():
global xpu_available
global cpu_state global cpu_state
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:
torch.mps.empty_cache() torch.mps.empty_cache()
elif xpu_available: elif is_intel_xpu():
torch.xpu.empty_cache() torch.xpu.empty_cache()
elif torch.cuda.is_available(): elif torch.cuda.is_available():
if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda