From 0229228f3f75fc4b0d0d4cf3658138eedc2cc2eb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 25 Dec 2024 04:50:34 -0500 Subject: [PATCH] Clean up the VAE dtypes code. --- comfy/model_management.py | 27 ++++++++++++--------------- comfy/sd.py | 4 ++-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 33891b92..8320c6ec 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -188,6 +188,12 @@ def is_nvidia(): return True return False +def is_amd(): + global cpu_state + if cpu_state == CPUState.GPU: + if torch.version.hip: + return True + return False MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia(): @@ -198,27 +204,17 @@ if args.use_pytorch_cross_attention: ENABLE_PYTORCH_ATTENTION = True XFORMERS_IS_AVAILABLE = False -VAE_DTYPES = [torch.float32] - try: if is_nvidia(): if int(torch_version[0]) >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: - VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES if is_intel_xpu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True except: pass -if is_intel_xpu(): - VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES - -if args.cpu_vae: - VAE_DTYPES = [torch.float32] - if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) @@ -754,7 +750,6 @@ def vae_offload_device(): return torch.device("cpu") def vae_dtype(device=None, allowed_dtypes=[]): - global VAE_DTYPES if args.fp16_vae: return torch.float16 elif args.bf16_vae: @@ -763,12 +758,14 @@ def vae_dtype(device=None, allowed_dtypes=[]): return torch.float32 for d in allowed_dtypes: - if d == torch.float16 and should_use_fp16(device, prioritize_performance=False): - return d - if d in VAE_DTYPES: + if d == torch.float16 and should_use_fp16(device): return d - return VAE_DTYPES[0] + # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 + if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): + return d + + return torch.float32 def get_autocast_device(dev): if hasattr(dev, 'type'): diff --git a/comfy/sd.py b/comfy/sd.py index de3ce677..55f91116 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -111,7 +111,7 @@ class CLIP: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None self.use_clip_schedule = False - logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) + logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) def clone(self): n = CLIP(no_init=True) @@ -402,7 +402,7 @@ class VAE: self.output_device = model_management.intermediate_device() self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) - logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) def vae_encode_crop_pixels(self, pixels): downscale_ratio = self.spacial_compression_encode()