From 4a0c4ce4ef3c1e0f2b777dcd20a8864be1420f19 Mon Sep 17 00:00:00 2001
From: Simon Lui <502929+simonlui@users.noreply.github.com>
Date: Sat, 2 Sep 2023 18:22:10 -0700
Subject: [PATCH 1/2] Some fixes to generalize CUDA specific functionality to Intel or other GPUs.

---
 comfy/ldm/modules/attention.py             |  3 +-
 comfy/ldm/modules/diffusionmodules/util.py | 24 ++++++++++----
 comfy/model_management.py                  | 37 ++++++++++++----------
 3 files changed, 38 insertions(+), 26 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 9fdfbd21..8f953d33 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -323,8 +323,7 @@ class CrossAttentionDoggettx(nn.Module):
                 break
             except model_management.OOM_EXCEPTION as e:
                 if first_op_done == False:
-                    torch.cuda.empty_cache()
-                    torch.cuda.ipc_collect()
+                    model_management.soft_empty_cache()
                     if cleared_cache == False:
                         cleared_cache = True
                         print("out of memory error, emptying cache and trying again")
diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py
index d890c804..9d07d935 100644
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -15,6 +15,7 @@ import torch.nn as nn
 import numpy as np
 from einops import repeat
 
+from comfy import model_management
 from comfy.ldm.util import instantiate_from_config
 import comfy.ops
 
@@ -139,13 +140,22 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *output_grads):
         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad(), \
-                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
-            # Fixes a bug where the first op in run_function modifies the
-            # Tensor storage in place, which is not allowed for detach()'d
-            # Tensors.
-            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-            output_tensors = ctx.run_function(*shallow_copies)
+        if model_management.is_nvidia():
+            with torch.enable_grad(), \
+                    torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+                # Fixes a bug where the first op in run_function modifies the
+                # Tensor storage in place, which is not allowed for detach()'d
+                # Tensors.
+                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+                output_tensors = ctx.run_function(*shallow_copies)
+        elif model_management.is_intel_xpu():
+            with torch.enable_grad(), \
+                    torch.xpu.amp.autocast(**ctx.gpu_autocast_kwargs):
+                # Fixes a bug where the first op in run_function modifies the
+                # Tensor storage in place, which is not allowed for detach()'d
+                # Tensors.
+                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+                output_tensors = ctx.run_function(*shallow_copies)
         input_grads = torch.autograd.grad(
             output_tensors,
             ctx.input_tensors + ctx.input_params,
diff --git a/comfy/model_management.py b/comfy/model_management.py
index aca8af99..bdbbbd84 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -58,8 +58,15 @@ except:
 if args.cpu:
     cpu_state = CPUState.CPU
 
-def get_torch_device():
+def is_intel_xpu():
+    global cpu_state
     global xpu_available
+    if cpu_state == CPUState.GPU:
+        if xpu_available:
+            return True
+    return False
+
+def get_torch_device():
     global directml_enabled
     global cpu_state
     if directml_enabled:
@@ -70,13 +77,12 @@ def get_torch_device():
     if cpu_state == CPUState.CPU:
         return torch.device("cpu")
     else:
-        if xpu_available:
+        if is_intel_xpu():
             return torch.device("xpu")
         else:
             return torch.device(torch.cuda.current_device())
 
 def get_total_memory(dev=None, torch_total_too=False):
-    global xpu_available
     global directml_enabled
     if dev is None:
         dev = get_torch_device()
@@ -88,7 +94,7 @@ def get_total_memory(dev=None, torch_total_too=False):
         if directml_enabled:
             mem_total = 1024 * 1024 * 1024 #TODO
             mem_total_torch = mem_total
-        elif xpu_available:
+        elif is_intel_xpu():
             stats = torch.xpu.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
             mem_total = torch.xpu.get_device_properties(dev).total_memory
@@ -146,11 +152,11 @@ def is_nvidia():
     if cpu_state == CPUState.GPU:
         if torch.version.cuda:
             return True
+    return False
 
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
 
 VAE_DTYPE = torch.float32
-
 try:
     if is_nvidia():
         torch_version = torch.version.__version__
@@ -162,6 +168,9 @@ try:
 except:
     pass
 
+if is_intel_xpu():
+    VAE_DTYPE = torch.bfloat16
+
 if args.fp16_vae:
     VAE_DTYPE = torch.float16
 elif args.bf16_vae:
@@ -220,7 +229,6 @@ if DISABLE_SMART_MEMORY:
     print("Disabling smart memory management")
 
 def get_torch_device_name(device):
-    global xpu_available
     if hasattr(device, 'type'):
         if device.type == "cuda":
             try:
@@ -230,7 +238,7 @@ def get_torch_device_name(device):
             return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
         else:
             return "{}".format(device.type)
-    elif xpu_available:
+    elif is_intel_xpu():
         return "{} {}".format(device, torch.xpu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -260,7 +268,6 @@ class LoadedModel:
         return self.model_memory()
 
     def model_load(self, lowvram_model_memory=0):
-        global xpu_available
         patch_model_to = None
         if lowvram_model_memory == 0:
             patch_model_to = self.device
@@ -281,7 +288,7 @@
             accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
             self.model_accelerated = True
 
-        if xpu_available and not args.disable_ipex_optimize:
+        if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
 
         return self.real_model
@@ -471,12 +478,11 @@ def get_autocast_device(dev):
 
 
 def xformers_enabled():
-    global xpu_available
     global directml_enabled
     global cpu_state
     if cpu_state != CPUState.GPU:
         return False
-    if xpu_available:
+    if is_intel_xpu():
         return False
     if directml_enabled:
         return False
@@ -503,7 +509,6 @@ def pytorch_attention_flash_attention():
     return False
 
 def get_free_memory(dev=None, torch_free_too=False):
-    global xpu_available
     global directml_enabled
     if dev is None:
         dev = get_torch_device()
@@ -515,7 +520,7 @@ def get_free_memory(dev=None, torch_free_too=False):
         if directml_enabled:
             mem_free_total = 1024 * 1024 * 1024 #TODO
             mem_free_torch = mem_free_total
-        elif xpu_available:
+        elif is_intel_xpu():
             stats = torch.xpu.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
             mem_allocated = stats['allocated_bytes.all.current']
@@ -577,7 +582,6 @@ def is_device_mps(device):
     return False
 
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
-    global xpu_available
     global directml_enabled
 
     if device is not None:
@@ -600,7 +604,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     if cpu_mode() or mps_mode():
         return False #TODO ?
 
-    if xpu_available:
+    if is_intel_xpu():
         return True
 
     if torch.cuda.is_bf16_supported():
@@ -636,11 +640,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     return True
 
 def soft_empty_cache():
-    global xpu_available
     global cpu_state
     if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
-    elif xpu_available:
+    elif is_intel_xpu():
         torch.xpu.empty_cache()
    elif torch.cuda.is_available():
         if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
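
Note (not part of the patches): the heart of this first patch is the new model_management.is_intel_xpu() helper and the switch from direct torch.cuda.empty_cache()/torch.cuda.ipc_collect() calls to the backend-agnostic model_management.soft_empty_cache(). The snippet below is a minimal, illustrative sketch of the retry-on-OOM pattern that attention.py follows after this change; run_with_oom_retry and its run_slice argument are hypothetical names used only for this example, not code from the patch.

    import comfy.model_management as model_management

    def run_with_oom_retry(run_slice, max_retries=1):
        # On an out-of-memory error, free cached memory on whichever backend
        # is active (CUDA, XPU or MPS) and try the operation once more.
        for attempt in range(max_retries + 1):
            try:
                return run_slice()
            except model_management.OOM_EXCEPTION:
                if attempt == max_retries:
                    raise
                print("out of memory error, emptying cache and trying again")
                model_management.soft_empty_cache()
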
From 2da73b7073dc520ee480dee8ff911b9aa83ff70a Mon Sep 17 00:00:00 2001
From: Simon Lui <502929+simonlui@users.noreply.github.com>
Date: Sat, 2 Sep 2023 20:07:52 -0700
Subject: [PATCH 2/2] Revert changes in comfy/ldm/modules/diffusionmodules/util.py, which is unused.

---
 comfy/ldm/modules/diffusionmodules/util.py | 24 +++++++-----------------
 1 file changed, 7 insertions(+), 17 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py
index 9d07d935..d890c804 100644
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -15,7 +15,6 @@ import torch.nn as nn
 import numpy as np
 from einops import repeat
 
-from comfy import model_management
 from comfy.ldm.util import instantiate_from_config
 import comfy.ops
 
@@ -140,22 +139,13 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *output_grads):
         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        if model_management.is_nvidia():
-            with torch.enable_grad(), \
-                    torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
-                # Fixes a bug where the first op in run_function modifies the
-                # Tensor storage in place, which is not allowed for detach()'d
-                # Tensors.
-                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-                output_tensors = ctx.run_function(*shallow_copies)
-        elif model_management.is_intel_xpu():
-            with torch.enable_grad(), \
-                    torch.xpu.amp.autocast(**ctx.gpu_autocast_kwargs):
-                # Fixes a bug where the first op in run_function modifies the
-                # Tensor storage in place, which is not allowed for detach()'d
-                # Tensors.
-                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-                output_tensors = ctx.run_function(*shallow_copies)
+        with torch.enable_grad(), \
+                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
         input_grads = torch.autograd.grad(
             output_tensors,
             ctx.input_tensors + ctx.input_params,
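
Note (not part of the patches): with the util.py change reverted by the second patch, what the series keeps is the attention.py cache fix and the model_management.py refactor around is_intel_xpu(). The following self-contained sketch restates that dispatch logic for reference; the CPUState enum and the cpu_state/xpu_available flags are simplified stand-ins for module-level state that comfy/model_management.py actually derives from command-line arguments and an intel_extension_for_pytorch import probe, so treat it as an illustration rather than a copy of the file.

    import enum

    import torch

    class CPUState(enum.Enum):
        GPU = 0
        CPU = 1
        MPS = 2

    # Simplified stand-ins for module-level state set up at import time.
    cpu_state = CPUState.GPU
    xpu_available = False

    def is_intel_xpu():
        # Report XPU only when running in GPU mode and an Intel XPU was detected.
        return cpu_state == CPUState.GPU and xpu_available

    def is_nvidia():
        # CUDA builds of PyTorch set torch.version.cuda; ROCm builds set torch.version.hip.
        return cpu_state == CPUState.GPU and torch.version.cuda is not None

    def get_torch_device():
        # DirectML handling omitted for brevity.
        if cpu_state == CPUState.CPU:
            return torch.device("cpu")
        if is_intel_xpu():
            return torch.device("xpu")
        return torch.device(torch.cuda.current_device())

    def soft_empty_cache():
        if cpu_state == CPUState.MPS:
            torch.mps.empty_cache()
        elif is_intel_xpu():
            torch.xpu.empty_cache()
        elif torch.cuda.is_available():
            if is_nvidia():  # skipped on ROCm, where emptying the cache tends to hurt
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()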