2024-08-08 07:27:37 +00:00
"""
This file is part of ComfyUI .
Copyright ( C ) 2024 Comfy
This program is free software : you can redistribute it and / or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation , either version 3 of the License , or
( at your option ) any later version .
This program is distributed in the hope that it will be useful ,
but WITHOUT ANY WARRANTY ; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE . See the
GNU General Public License for more details .
You should have received a copy of the GNU General Public License
along with this program . If not , see < https : / / www . gnu . org / licenses / > .
"""
2023-04-06 03:41:23 +00:00
import psutil
2024-03-10 15:37:08 +00:00
import logging
2023-04-06 03:41:23 +00:00
from enum import Enum
2023-05-05 04:19:35 +00:00
from comfy . cli_args import args
2023-06-02 19:05:25 +00:00
import torch
2023-08-17 05:06:34 +00:00
import sys
2024-05-21 20:56:33 +00:00
import platform
2023-02-08 08:17:54 +00:00
2023-04-06 03:41:23 +00:00
class VRAMState(Enum):
    """Model-placement policy with respect to device (V)RAM."""
    DISABLED = 0 #No vram present: no need to move models to vram
    NO_VRAM = 1 #Very low vram: enable all the options to save vram
    LOW_VRAM = 2  # partial ("lowvram") loading is used when a model does not fit
    NORMAL_VRAM = 3  # default; may also fall back to partial loading when memory is short
    HIGH_VRAM = 4  # models are kept on the compute device even when idle
    SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
2023-06-03 15:05:37 +00:00
class CPUState(Enum):
    """Which compute backend this process targets."""
    GPU = 0  # an accelerator path (CUDA, Intel XPU or DirectML)
    CPU = 1  # set when args.cpu is true
    MPS = 2  # selected when torch.backends.mps reports availability
2023-02-08 16:37:10 +00:00
2023-04-06 03:41:23 +00:00
# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU

total_vram = 0

# Stock-pytorch XPU probe, only trusted for torch <= 2.4 here. The IPEX probe
# that follows in this module may overwrite xpu_available.
xpu_available = False
torch_version = ""
try:
    torch_version = torch.version.__version__
    # Parse the numeric "major.minor" components instead of indexing single
    # characters: torch_version[2] misreads two-digit minors such as "2.10".
    _torch_version_parts = torch_version.split(".")
    _torch_major = int(_torch_version_parts[0])
    _torch_minor = int(_torch_version_parts[1])
    xpu_available = (_torch_major < 2 or (_torch_major == 2 and _torch_minor <= 4)) and torch.xpu.is_available()
except Exception:
    # Best effort: unparsable versions or builds without torch.xpu simply
    # leave xpu_available as False.
    pass
2024-08-23 07:59:57 +00:00
2023-05-30 16:36:41 +00:00
# Partial ("lowvram") model loading is supported unless a backend below
# turns it off.
lowvram_available = True

if args.deterministic:
    logging.info("Using deterministic algorithms for pytorch")
    torch.use_deterministic_algorithms(True, warn_only=True)

# Optional DirectML backend; args.directml holds the requested device index
# (a negative value selects the library's default device).
directml_enabled = False
if args.directml is not None:
    import torch_directml
    directml_enabled = True
    device_index = args.directml
    directml_device = torch_directml.device() if device_index < 0 else torch_directml.device(device_index)
    logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
    # torch_directml.disable_tiled_resources(True)
    # TODO: need to find a way to get free memory in directml before this can be enabled by default.
    lowvram_available = False
2023-04-28 18:28:57 +00:00
2023-02-08 19:05:31 +00:00
# Optional Intel Extension for PyTorch. When the import succeeds, the
# module-level name `ipex` is later used to optimize models on XPU.
try:
    import intel_extension_for_pytorch as ipex
    # NOTE(review): presumably touches the XPU runtime so an unusable
    # installation fails here rather than later.
    _ = torch.xpu.device_count()
    xpu_available = torch.xpu.is_available()
except Exception:
    # IPEX missing or unusable: keep any earlier positive probe, otherwise
    # ask stock pytorch's XPU backend. (A bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit.)
    xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
2023-02-08 19:05:31 +00:00
2023-06-03 15:05:37 +00:00
# Prefer Apple's MPS backend when pytorch reports it is available.
try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps
except Exception:
    # Builds without a usable MPS backend keep the previously chosen state.
    pass

# An explicit CPU request overrides any accelerator detection.
if args.cpu:
    cpu_state = CPUState.CPU
2023-09-03 01:22:10 +00:00
def is_intel_xpu():
    """Return True when the accelerator path is active and an Intel XPU was detected."""
    on_accelerator_path = cpu_state == CPUState.GPU
    return on_accelerator_path and bool(xpu_available)
def get_torch_device():
    """Return the torch device models should run on.

    Checked in order: DirectML (when enabled), MPS, forced CPU, Intel XPU,
    and finally the current CUDA device.
    """
    if directml_enabled:
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    if is_intel_xpu():
        return torch.device("xpu", torch.xpu.current_device())
    return torch.device(torch.cuda.current_device())
def get_total_memory(dev=None, torch_total_too=False):
    """Return the total memory of `dev` in bytes.

    Args:
        dev: device to query; defaults to get_torch_device().
        torch_total_too: when True, return a tuple (total, torch_reserved)
            where torch_reserved is the allocator's current reserved bytes
            on CUDA/XPU, and equals the total for CPU/MPS/DirectML.
    """
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and dev.type in ('cpu', 'mps'):
        # CPU and MPS share system RAM.
        total = psutil.virtual_memory().total
        torch_total = total
    elif directml_enabled:
        total = 1024 * 1024 * 1024  # TODO: fixed 1 GiB placeholder for DirectML
        torch_total = total
    elif is_intel_xpu():
        torch_total = torch.xpu.memory_stats(dev)['reserved_bytes.all.current']
        total = torch.xpu.get_device_properties(dev).total_memory
    else:
        torch_total = torch.cuda.memory_stats(dev)['reserved_bytes.all.current']
        _, total = torch.cuda.mem_get_info(dev)

    return (total, torch_total) if torch_total_too else total
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))

# torch_version is always a str here, so formatting it cannot fail; the former
# try/bare-except around this call could only hide unrelated errors.
logging.info("pytorch version: {}".format(torch_version))
2023-03-22 18:49:00 +00:00
# Exception type raised when a device allocation fails. torch.cuda.OutOfMemoryError
# may be missing from some pytorch builds, in which case any Exception is treated
# as a potential OOM.
OOM_EXCEPTION = getattr(torch.cuda, "OutOfMemoryError", Exception)
2023-04-09 05:31:47 +00:00
# xformers detection. XFORMERS_IS_AVAILABLE says whether xformers attention
# can be used; XFORMERS_ENABLED_VAE is cleared for a known-bad release.
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
    XFORMERS_IS_AVAILABLE = False
else:
    try:
        import xformers
        import xformers.ops
        # Builds without the compiled extension expose _has_cpp_library = False;
        # if the attribute is missing, assume the library is usable.
        XFORMERS_IS_AVAILABLE = getattr(xformers, "_has_cpp_library", True)
        try:
            XFORMERS_VERSION = xformers.version.__version__
            logging.info("xformers version: {}".format(XFORMERS_VERSION))
            if XFORMERS_VERSION.startswith("0.0.18"):
                logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
                logging.warning("Please downgrade or upgrade xformers to a different version.\n")
                XFORMERS_ENABLED_VAE = False
        except Exception:
            # Version information is optional; failing to read it is not fatal.
            pass
    except Exception:
        # Import failures of any kind mean xformers cannot be used. Narrowed from
        # a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        XFORMERS_IS_AVAILABLE = False
2023-03-13 15:36:48 +00:00
2023-06-26 16:55:07 +00:00
def is_nvidia():
    """Return True on the accelerator path with a CUDA-enabled pytorch build."""
    if cpu_state != CPUState.GPU:
        return False
    return bool(torch.version.cuda)
2023-06-26 16:55:07 +00:00
2023-10-12 01:29:03 +00:00
# An explicit request for pytorch cross attention also disables xformers.
ENABLE_PYTORCH_ATTENTION = bool(args.use_pytorch_cross_attention)
if ENABLE_PYTORCH_ATTENTION:
    XFORMERS_IS_AVAILABLE = False

# Candidate VAE dtypes, most preferred first; float32 is the fallback.
VAE_DTYPES = [torch.float32]
2023-06-26 16:55:07 +00:00
2023-08-28 03:06:19 +00:00
# Pick default attention / VAE settings for the detected hardware.
try:
    if is_nvidia():
        # Parse the major version numerically; torch_version[0] would only
        # read the first character.
        if int(torch_version.split(".")[0]) >= 2:
            if not ENABLE_PYTORCH_ATTENTION and not args.use_split_cross_attention and not args.use_quad_cross_attention:
                ENABLE_PYTORCH_ATTENTION = True
            # bf16 VAE only on devices with compute capability >= 8.
            if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
                VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
    if is_intel_xpu():
        if not args.use_split_cross_attention and not args.use_quad_cross_attention:
            ENABLE_PYTORCH_ATTENTION = True
except Exception:
    # Best effort: any failure while probing keeps the current defaults.
    pass
2023-09-03 01:22:10 +00:00
# On Intel XPU, bfloat16 is tried first for the VAE.
if is_intel_xpu():
    VAE_DTYPES = [torch.bfloat16, *VAE_DTYPES]

# A CPU-hosted VAE is restricted to float32.
if args.cpu_vae:
    VAE_DTYPES = [torch.float32]
2023-08-28 03:06:19 +00:00
2023-06-26 16:55:07 +00:00
2023-04-06 03:41:23 +00:00
if ENABLE_PYTORCH_ATTENTION:
    # Turn on the math, flash and memory-efficient SDP backends, in that order.
    for _sdp_toggle_name in ("enable_math_sdp", "enable_flash_sdp", "enable_mem_efficient_sdp"):
        getattr(torch.backends.cuda, _sdp_toggle_name)(True)
2023-03-12 19:44:16 +00:00
2023-04-06 03:41:23 +00:00
# Command-line vram requests. low/no-vram requests are only recorded in
# set_vram_to and applied once lowvram support is confirmed; high-vram
# requests set vram_state directly.
if args.lowvram:
    set_vram_to = VRAMState.LOW_VRAM
    # An explicit lowvram request re-enables partial loading even where the
    # backend disabled it by default.
    lowvram_available = True
elif args.novram:
    set_vram_to = VRAMState.NO_VRAM
elif args.highvram or args.gpu_only:
    vram_state = VRAMState.HIGH_VRAM
2023-03-24 18:30:43 +00:00
2023-04-07 04:27:54 +00:00
# Global precision overrides taken from the command line.
FORCE_FP32 = bool(args.force_fp32)
FORCE_FP16 = bool(args.force_fp16)
if FORCE_FP32:
    logging.info("Forcing FP32, if this improves things please report it.")
if FORCE_FP16:
    logging.info("Forcing FP16.")
2023-05-30 16:36:41 +00:00
# Apply a low/no-vram request only when partial loading is supported.
if lowvram_available and set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
    vram_state = set_vram_to

# Non-accelerator backends override the request: MPS shares memory with the
# host, and the plain CPU path has no vram at all.
if cpu_state == CPUState.MPS:
    vram_state = VRAMState.SHARED
elif cpu_state != CPUState.GPU:
    vram_state = VRAMState.DISABLED

logging.info(f"Set vram state to: {vram_state.name}")

DISABLE_SMART_MEMORY = args.disable_smart_memory
if DISABLE_SMART_MEMORY:
    logging.info("Disabling smart memory management")
2023-06-03 15:05:37 +00:00
2023-05-13 21:11:27 +00:00
def get_torch_device_name(device):
    """Return a human-readable description of `device` for log output.

    For CUDA devices this includes the GPU name and, when pytorch exposes
    it, the CUDA allocator backend. Other torch devices report their type.
    """
    if hasattr(device, 'type'):
        if device.type == "cuda":
            try:
                allocator_backend = torch.cuda.get_allocator_backend()
            except Exception:
                # Not every pytorch build provides get_allocator_backend().
                allocator_backend = ""
            return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
        else:
            return "{}".format(device.type)
    elif is_intel_xpu():
        return "{} {}".format(device, torch.xpu.get_device_name(device))
    else:
        return "CUDA {} : {}".format(device, torch.cuda.get_device_name(device))
2023-05-13 21:11:27 +00:00
try:
    logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except Exception:
    # Device probing failures are reported but not fatal at import time.
    logging.warning("Could not pick default device.")


# LoadedModel entries currently tracked; load_models_gpu inserts new entries
# at index 0, so the list runs from most to least recently loaded.
current_loaded_models = []
2023-02-08 08:17:54 +00:00
2023-12-29 02:41:10 +00:00
def module_size(module):
    """Return the number of bytes occupied by every tensor in `module.state_dict()`."""
    tensors = module.state_dict().values()
    return sum(tensor.nelement() * tensor.element_size() for tensor in tensors)
2023-08-17 05:06:34 +00:00
class LoadedModel:
    """Bookkeeping entry for a model patcher tracked in current_loaded_models.

    NOTE(review): `model` is presumably a ModelPatcher-like object; this class
    only relies on the attributes/methods it uses below (load_device,
    offload_device, model_size, loaded_size, current_loaded_device,
    patch_model, unpatch_model, partially_load, partially_unload,
    model_patches_to, model_dtype, lowvram_patch_counter).
    """
    def __init__(self, model):
        self.model = model
        self.device = model.load_device  # device the model is loaded onto
        self.weights_loaded = False  # whether patched weights are currently applied
        self.real_model = None  # value returned by patch_model(); reset on unload
        self.currently_used = True  # cleared by free_memory for models not kept loaded

    def model_memory(self):
        """Return the full model size as reported by model_size() (presumably bytes)."""
        return self.model.model_size()

    def model_offloaded_memory(self):
        """Return the part of the model not currently loaded: model_size() - loaded_size()."""
        return self.model.model_size() - self.model.loaded_size()

    def model_memory_required(self, device):
        """Return the memory still needed to have the model on `device`.

        If the model is already (partially) on `device`, only the offloaded
        part is counted; otherwise the full model size is.
        """
        if device == self.model.current_loaded_device():
            return self.model_offloaded_memory()
        else:
            return self.model_memory()

    def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
        """Load the model onto self.device and return self.real_model.

        Args:
            lowvram_model_memory: memory budget for partial loading; 0 means
                load the whole model.
            force_patch_weights: forwarded to patch_model().

        Raises:
            Whatever patch_model() raises, after the model has been unpatched
            and unloaded.
        """
        patch_model_to = self.device

        self.model.model_patches_to(self.device)
        self.model.model_patches_to(self.model.model_dtype())

        # Weights only need patching again if a previous unload removed them.
        load_weights = not self.weights_loaded

        if self.model.loaded_size() > 0:
            # Already partially loaded: extend the loaded part instead of
            # re-patching. A budget of 0 means "no limit".
            # NOTE(review): this branch does not assign self.real_model, so
            # the return value may be None here.
            use_more_vram = lowvram_model_memory
            if use_more_vram == 0:
                use_more_vram = 1e32
            self.model_use_more_vram(use_more_vram)
        else:
            try:
                self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
            except Exception as e:
                # Roll back to the offload device before propagating the error.
                self.model.unpatch_model(self.model.offload_device)
                self.model_unload()
                raise e

        # IPEX optimization only applies when the module-level `ipex` import succeeded.
        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
            with torch.no_grad():
                self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)

        self.weights_loaded = True
        return self.real_model

    def should_reload_model(self, force_patch_weights=False):
        """Return True when weights must be force-patched but lowvram patches are active."""
        if force_patch_weights and self.model.lowvram_patch_counter() > 0:
            return True
        return False

    def model_unload(self, memory_to_free=None, unpatch_weights=True):
        """Free memory held by this model.

        If `memory_to_free` is given and smaller than what is loaded, a
        partial unload is attempted first; if that frees enough, the model
        stays tracked and False is returned. Otherwise the model is fully
        moved to its offload device and True is returned.
        """
        if memory_to_free is not None:
            if memory_to_free < self.model.loaded_size():
                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
                if freed >= memory_to_free:
                    return False
        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
        self.model.model_patches_to(self.model.offload_device)
        # Weights remain applied only if this unload kept them.
        self.weights_loaded = self.weights_loaded and not unpatch_weights
        self.real_model = None
        return True

    def model_use_more_vram(self, extra_memory):
        """Load up to `extra_memory` more of the model onto self.device.

        Returns partially_load()'s result, which use_more_memory subtracts
        from its remaining budget (presumably the amount actually loaded).
        """
        return self.model.partially_load(self.device, extra_memory)

    def __eq__(self, other):
        # Entries are equal when they wrap the very same model object; used by
        # list.index / `in` lookups. Defining __eq__ without __hash__ makes
        # instances unhashable.
        return self.model is other.model
2023-07-15 17:24:05 +00:00
2024-08-08 07:27:37 +00:00
def use_more_memory(extra_memory, loaded_models, device):
    """Distribute an `extra_memory` budget over the models on `device`, in list order.

    Each model on `device` consumes part of the budget through
    model_use_more_vram(); iteration stops once the budget is used up.
    """
    remaining = extra_memory
    for candidate in loaded_models:
        if candidate.device != device:
            continue
        remaining -= candidate.model_use_more_vram(remaining)
        if remaining <= 0:
            return
def offloaded_memory(loaded_models, device):
    """Return the total offloaded (not yet loaded) memory of the models on `device`."""
    return sum(entry.model_offloaded_memory() for entry in loaded_models if entry.device == device)
2024-09-01 21:29:31 +00:00
WINDOWS = any(platform.win32_ver())

# Bytes of VRAM left for other applications. Windows reserves more because of
# the shared vram issue.
EXTRA_RESERVED_VRAM = (600 if WINDOWS else 400) * 1024 * 1024
if args.reserve_vram is not None:
    # args.reserve_vram is interpreted in GiB.
    EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
    logging.debug("Reserving {} MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
def extra_reserved_memory():
    """Return EXTRA_RESERVED_VRAM: bytes of VRAM kept free for other applications."""
    return EXTRA_RESERVED_VRAM
2024-09-01 05:01:54 +00:00
def minimum_inference_memory():
    """Return the minimum free memory (bytes) kept for inference: 0.8 GiB plus the extra reserve."""
    base_inference_bytes = (1024 * 1024 * 1024) * 0.8
    return base_inference_bytes + extra_reserved_memory()
2024-03-20 17:53:45 +00:00
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
    """Unload entries of current_loaded_models that are clones of `model`.

    Returns:
        True when no clones are tracked.
        None when nothing was unloaded because force_unload is False,
        unload_weights_only is True and every clone shares weights with `model`.
        Otherwise, the unpatch_weights flag used for the unloaded clones.
    """
    # Highest index first so popping does not shift the indices still pending.
    clone_indices = [idx for idx, entry in enumerate(current_loaded_models) if model.is_clone(entry.model)]
    clone_indices.reverse()

    if not clone_indices:
        return True

    shared_count = sum(1 for idx in clone_indices if model.clone_has_same_weights(current_loaded_models[idx].model))
    unload_weight = shared_count != len(clone_indices)

    if force_unload:
        unload_weight = True
    elif unload_weights_only and not unload_weight:
        return None

    for idx in clone_indices:
        logging.debug("unload clone {} {}".format(idx, unload_weight))
        current_loaded_models.pop(idx).model_unload(unpatch_weights=unload_weight)

    return unload_weight
2023-08-17 05:06:34 +00:00
def free_memory(memory_required, device, keep_loaded=()):
    """Unload models on `device` until at least `memory_required` bytes are free.

    Every tracked model on `device` that is not in `keep_loaded` is marked as
    not currently used and becomes an unload candidate. Candidates are tried
    in sorted order: most offloaded memory first, then fewest references
    (sys.getrefcount), then smallest size, then list position. With smart
    memory disabled, every candidate is fully unloaded.

    Args:
        memory_required: bytes of free memory wanted on `device`.
        device: device to free memory on.
        keep_loaded: LoadedModel entries that must stay loaded. The default is
            an immutable tuple; the former mutable `[]` default was never mutated.

    Returns:
        List of LoadedModel entries fully removed from current_loaded_models.
    """
    candidates = []
    for i in range(len(current_loaded_models) - 1, -1, -1):
        shift_model = current_loaded_models[i]
        if shift_model.device == device:
            if shift_model not in keep_loaded:
                # Do not bind shift_model.model to a local: that would change its refcount.
                candidates.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
                shift_model.currently_used = False

    unloaded_indices = []
    for candidate in sorted(candidates):
        i = candidate[-1]
        memory_to_free = None
        if not DISABLE_SMART_MEMORY:
            free_mem = get_free_memory(device)
            if free_mem > memory_required:
                break
            memory_to_free = memory_required - free_mem
        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
        # model_unload returns False when a partial unload freed enough.
        if current_loaded_models[i].model_unload(memory_to_free):
            unloaded_indices.append(i)

    # Pop highest indices first so the remaining indices stay valid.
    unloaded_models = [current_loaded_models.pop(i) for i in sorted(unloaded_indices, reverse=True)]

    if unloaded_indices:
        soft_empty_cache()
    elif vram_state != VRAMState.HIGH_VRAM:
        # NOTE(review): presumably mem_free_torch is memory pytorch has cached
        # but not used; release it when it exceeds 25% of free memory.
        mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
        if mem_free_torch > mem_free_total * 0.25:
            soft_empty_cache()
    return unloaded_models
2023-08-17 05:06:34 +00:00
2024-08-13 03:42:21 +00:00
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    """Make sure every model in `models` is loaded on its load device.

    Models already tracked in current_loaded_models are reused and marked as
    currently used. The others are wrapped in LoadedModel, memory is freed on
    their devices, and each is loaded either fully or partially ("lowvram")
    depending on vram_state and the free memory available.

    Args:
        models: iterable of model patchers (deduplicated with set()).
        memory_required: additional memory the caller needs for inference
            (presumably bytes); the extra reserve is added on top.
        force_patch_weights: forwarded to should_reload_model / model_load.
        minimum_memory_required: free memory to keep available; defaults to
            the value derived from memory_required.
        force_full_load: when True, never use partial (lowvram) loading.
    """
    global vram_state
    inference_memory = minimum_inference_memory()
    extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
    if minimum_memory_required is None:
        minimum_memory_required = extra_mem
    else:
        minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())

    models = set(models)

    models_to_load = []
    models_already_loaded = []
    for x in models:
        loaded_model = LoadedModel(x)
        loaded = None

        # LoadedModel.__eq__ compares wrapped models, so this finds an existing entry for x.
        try:
            loaded_model_index = current_loaded_models.index(loaded_model)
        except:
            loaded_model_index = None

        if loaded_model_index is not None:
            loaded = current_loaded_models[loaded_model_index]
            if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic
                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
                loaded = None
            else:
                loaded.currently_used = True
                models_already_loaded.append(loaded)

        if loaded is None:
            if hasattr(x, "model"):
                logging.info(f"Requested to load {x.model.__class__.__name__}")
            models_to_load.append(loaded_model)

    if len(models_to_load) == 0:
        # Everything is already loaded: just rebalance memory on each device.
        devs = set(map(lambda a: a.device, models_already_loaded))
        for d in devs:
            if d != torch.device("cpu"):
                free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
                free_mem = get_free_memory(d)
                if free_mem < minimum_memory_required:
                    logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
                    # The models unloaded here become the ones (re)loaded below.
                    models_to_load = free_memory(minimum_memory_required, d)
                    logging.info("{} models unloaded.".format(len(models_to_load)))
                else:
                    use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
        if len(models_to_load) == 0:
            return

    logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")

    # Per-device memory needed by the models to load plus the already-loaded ones.
    total_memory_required = {}
    for loaded_model in models_to_load:
        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

    for loaded_model in models_already_loaded:
        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

    for loaded_model in models_to_load:
        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
        if weights_unloaded is not None:
            loaded_model.weights_loaded = not weights_unloaded

    # Free memory with a 10% margin on top of the computed requirement.
    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)

    for loaded_model in models_to_load:
        model = loaded_model.model
        torch_dev = model.load_device
        if is_device_cpu(torch_dev):
            vram_set_state = VRAMState.DISABLED
        else:
            vram_set_state = vram_state
        # 0 means "load the whole model"; a positive value is a partial-load budget.
        lowvram_model_memory = 0
        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
            model_size = loaded_model.model_memory_required(torch_dev)
            current_free_mem = get_free_memory(torch_dev)
            # Budget: at least 64 MiB, otherwise the larger of free memory beyond the
            # required minimum and min(40% of free, free minus the inference reserve).
            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
            if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
                lowvram_model_memory = 0

        if vram_set_state == VRAMState.NO_VRAM:
            lowvram_model_memory = 64 * 1024 * 1024

        cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
        current_loaded_models.insert(0, loaded_model)

    # Give any memory left over to the models that were already loaded.
    devs = set(map(lambda a: a.device, models_already_loaded))
    for d in devs:
        if d != torch.device("cpu"):
            free_mem = get_free_memory(d)
            if free_mem > minimum_memory_required:
                use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
    return
def load_model_gpu(model):
    """Convenience wrapper: load a single model via load_models_gpu."""
    single = (model,)
    return load_models_gpu(single)
2024-06-05 23:14:56 +00:00
def loaded_models(only_currently_used=False):
    """Return the model patchers in current_loaded_models.

    When `only_currently_used` is True, entries whose currently_used flag is
    falsy are skipped.
    """
    return [
        entry.model
        for entry in current_loaded_models
        if not only_currently_used or entry.currently_used
    ]
2024-03-28 22:01:04 +00:00
def cleanup_models(keep_clone_weights_loaded=False):
    """Unload tracked models that appear to be unreferenced outside this module.

    A refcount of <= 2 for entry.model means the only references are the
    LoadedModel attribute and getrefcount's own argument. With
    `keep_clone_weights_loaded`, such an entry is only dropped when its
    real_model is also nearly unreferenced.

    NOTE(review): relies on exact sys.getrefcount values; adding local
    references to these objects inside this function changes its behavior.
    """
    to_delete = []
    for i in range(len(current_loaded_models)):
        #TODO: very fragile function needs improvement
        num_refs = sys.getrefcount(current_loaded_models[i].model)
        if num_refs <= 2:
            if not keep_clone_weights_loaded:
                to_delete = [i] + to_delete
            #TODO: find a less fragile way to do this.
            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
                to_delete = [i] + to_delete

    # to_delete is in descending index order, so pops do not shift pending indices.
    for i in to_delete:
        x = current_loaded_models.pop(i)
        x.model_unload()
        del x
2023-02-17 20:45:29 +00:00
2023-08-24 21:20:54 +00:00
def dtype_size(dtype):
    """Return the size in bytes of one element of `dtype`.

    float16/bfloat16 -> 2 and float32 -> 4. Other dtypes use `dtype.itemsize`,
    falling back to 4 when the attribute is unavailable (older pytorch
    releases lack it, as do non-dtype values).
    """
    if dtype == torch.float16 or dtype == torch.bfloat16:
        return 2
    if dtype == torch.float32:
        return 4
    # Only a missing attribute is expected here, so getattr replaces the
    # former bare except (which also hid unrelated errors).
    return getattr(dtype, "itemsize", 4)
2023-07-01 17:22:51 +00:00
def unet_offload_device():
    """Return where an idle diffusion model is kept: the compute device in HIGH_VRAM mode, otherwise the CPU."""
    keep_on_compute_device = vram_state == VRAMState.HIGH_VRAM
    return get_torch_device() if keep_on_compute_device else torch.device("cpu")
2023-08-17 05:06:34 +00:00
def unet_inital_load_device(parameters, dtype):
    """Choose the device on which a diffusion model's weights are first created.

    Always the compute device in HIGH_VRAM mode and always the CPU when smart
    memory is disabled. Otherwise the compute device is chosen only if it has
    more free memory than the CPU and the weights (parameters * element size)
    fit in it.
    """
    compute_device = get_torch_device()
    if vram_state == VRAMState.HIGH_VRAM:
        return compute_device

    host_device = torch.device("cpu")
    if DISABLE_SMART_MEMORY:
        return host_device

    weight_bytes = dtype_size(dtype) * parameters
    free_on_compute = get_free_memory(compute_device)
    free_on_host = get_free_memory(host_device)
    fits_on_compute = free_on_compute > free_on_host and weight_bytes < free_on_compute
    return compute_device if fits_on_compute else host_device
2024-08-03 17:45:19 +00:00
def maximum_vram_for_weights(device=None):
    """Estimate the memory usable for model weights: 88% of total memory minus the inference reserve."""
    usable_fraction = 0.88
    return get_total_memory(device) * usable_fraction - minimum_inference_memory()
2024-08-03 17:45:19 +00:00
2024-02-16 15:55:08 +00:00
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    """Pick the weight dtype for a diffusion model.

    Precedence:
      1. explicit command-line overrides (bf16, fp16, fp8 e4m3fn, fp8 e5m2);
      2. the first fp8 dtype present in `supported_dtypes` when 2-byte weights
         (model_params * 2) would exceed maximum_vram_for_weights(device);
      3. the first of fp16/bf16, in `supported_dtypes` order, that
         should_use_fp16/should_use_bf16 accept; then the same check again
         with manual_cast=True;
      4. float32.

    Args:
        device: target device, forwarded to the helper checks.
        model_params: parameter count; a negative value means "unknown" and is
            treated as very large.
        supported_dtypes: dtypes the model can use. The default is now an
            immutable tuple; the old list default was never mutated.
    """
    if model_params < 0:
        model_params = 1000000000000000000000
    if args.bf16_unet:
        return torch.bfloat16
    if args.fp16_unet:
        return torch.float16
    if args.fp8_e4m3fn_unet:
        return torch.float8_e4m3fn
    if args.fp8_e5m2_unet:
        return torch.float8_e5m2

    # torch.float8_* dtypes may be missing from older pytorch builds.
    fp8_dtype = None
    for fp8_name in ("float8_e4m3fn", "float8_e5m2"):
        candidate = getattr(torch, fp8_name, None)
        if candidate is not None and candidate in supported_dtypes:
            fp8_dtype = candidate
            break
    if fp8_dtype is not None:
        free_model_memory = maximum_vram_for_weights(device)
        if model_params * 2 > free_model_memory:
            return fp8_dtype

    # `dt` comes from supported_dtypes, so no second membership check is needed.
    for dt in supported_dtypes:
        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
            return torch.float16
        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
            return torch.bfloat16
    for dt in supported_dtypes:
        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
            return torch.float16
        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
            return torch.bfloat16

    return torch.float32
2023-12-11 23:24:44 +00:00
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    """Return the dtype to cast weights to during inference, or None for no cast.

    No cast is needed for float32 weights, or for fp16/bf16 weights when the
    device handles that dtype (should_use_fp16 with prioritize_performance=False,
    or should_use_bf16). Otherwise, the first of fp16/bf16 in `supported_dtypes`
    order that the device accepts is returned (fp16 re-checked with
    prioritize_performance=True), falling back to float32.

    `supported_dtypes` now defaults to an immutable tuple; the old list default
    was never mutated.
    """
    if weight_dtype == torch.float32:
        return None

    fp16_native = should_use_fp16(inference_device, prioritize_performance=False)
    if fp16_native and weight_dtype == torch.float16:
        return None

    bf16_supported = should_use_bf16(inference_device)
    if bf16_supported and weight_dtype == torch.bfloat16:
        return None

    fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_supported:
            return torch.float16
        if dt == torch.bfloat16 and bf16_supported:
            return torch.bfloat16

    return torch.float32
2023-12-11 23:24:44 +00:00
2023-07-01 16:37:23 +00:00
def text_encoder_offload_device():
    """Return where idle text encoders are kept: the compute device with args.gpu_only, otherwise the CPU."""
    return get_torch_device() if args.gpu_only else torch.device("cpu")
2023-07-01 16:37:23 +00:00
def text_encoder_device():
    """Return the device text encoders run on.

    The compute device is used with args.gpu_only, or in HIGH/NORMAL vram
    mode when should_use_fp16(prioritize_performance=False) approves;
    otherwise the CPU is used.
    """
    if args.gpu_only:
        return get_torch_device()
    roomy_vram = vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM)
    if roomy_vram and should_use_fp16(prioritize_performance=False):
        return get_torch_device()
    return torch.device("cpu")
2024-08-12 03:50:01 +00:00
def text_encoder_initial_device(load_device, offload_device, model_size=0):
    """Pick the device a text encoder is first materialized on.

    Stays on ``offload_device`` when both devices are the same, when the model
    is at most 1 GiB, or when ``load_device`` is MPS. Otherwise it goes to
    ``load_device`` only if that device has more than half the free memory of
    the offload device and more than 1.2x ``model_size`` free.

    Args:
        load_device: device used for compute.
        offload_device: device used for storage when idle.
        model_size: model size in bytes.
    """
    one_gib = 1024 * 1024 * 1024
    if load_device == offload_device or model_size <= one_gib:
        return offload_device

    if is_device_mps(load_device):
        return offload_device

    free_on_load = get_free_memory(load_device)
    free_on_offload = get_free_memory(offload_device)
    has_room = free_on_load > (free_on_offload * 0.5) and model_size * 1.2 < free_on_load
    return load_device if has_room else offload_device
2023-11-17 07:56:59 +00:00
def text_encoder_dtype(device=None):
    """Return the dtype text-encoder weights should be stored in.

    Command-line overrides win, checked in this priority order: fp8 e4m3fn,
    fp8 e5m2, fp16, fp32. Without an override the result is float16.

    Args:
        device: accepted for API compatibility; the result does not currently
            depend on it. A CPU-specific branch used to exist here, but it
            returned float16 just like the default path, so it was redundant.
    """
    if args.fp8_e4m3fn_text_enc:
        return torch.float8_e4m3fn
    elif args.fp8_e5m2_text_enc:
        return torch.float8_e5m2
    elif args.fp16_text_enc:
        return torch.float16
    elif args.fp32_text_enc:
        return torch.float32

    return torch.float16
2023-11-17 07:56:59 +00:00
2023-12-08 07:35:45 +00:00
def intermediate_device():
    """Return the device intermediate results are kept on.

    The main torch device under ``args.gpu_only``; otherwise the CPU.
    """
    if not args.gpu_only:
        return torch.device("cpu")
    return get_torch_device()
2023-07-01 19:22:40 +00:00
def vae_device():
    """Return the device the VAE runs on: CPU if ``args.cpu_vae``, else the main torch device."""
    return torch.device("cpu") if args.cpu_vae else get_torch_device()
def vae_offload_device():
    """Return where the VAE is parked when idle.

    The main torch device under ``args.gpu_only``; otherwise the CPU.
    """
    if not args.gpu_only:
        return torch.device("cpu")
    return get_torch_device()
2024-06-16 17:12:54 +00:00
def vae_dtype(device=None, allowed_dtypes=()):
    """Return the dtype to run the VAE in.

    Command-line overrides (fp16, bf16, fp32, in that priority order) win.
    Otherwise the first entry of ``allowed_dtypes`` is chosen that is either
    float16 with ``should_use_fp16(device, prioritize_performance=False)`` true,
    or present in ``VAE_DTYPES``. If none qualifies, ``VAE_DTYPES[0]`` is used.

    Args:
        device: device the VAE will run on (forwarded to ``should_use_fp16``).
        allowed_dtypes: dtypes the VAE model accepts, in preference order. Any
            iterable works; the default is an immutable empty tuple instead of a
            shared mutable list.
    """
    if args.fp16_vae:
        return torch.float16
    elif args.bf16_vae:
        return torch.bfloat16
    elif args.fp32_vae:
        return torch.float32

    for d in allowed_dtypes:
        if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
            return d
        if d in VAE_DTYPES:
            return d
    return VAE_DTYPES[0]
2023-07-06 22:04:28 +00:00
2023-03-06 15:50:50 +00:00
def get_autocast_device(dev):
    """Return the device-type string for autocast.

    Uses ``dev.type`` when the object has a ``type`` attribute, else "cuda".
    """
    return getattr(dev, 'type', "cuda")
2023-02-17 20:45:29 +00:00
2023-12-04 16:10:00 +00:00
def supports_dtype(device, dtype): #TODO
    """Return whether ``device`` can compute in ``dtype``.

    float32 is always supported. On CPU devices nothing else is. Elsewhere
    float16 and bfloat16 are also reported as supported.
    """
    if dtype == torch.float32:
        return True
    if is_device_cpu(device):
        return False
    return dtype in (torch.float16, torch.bfloat16)
2024-06-11 21:03:26 +00:00
def supports_cast(device, dtype): #TODO
    """Return whether weights stored as ``dtype`` can be cast on ``device``.

    float32/float16 are always accepted. With DirectML nothing else is.
    bfloat16 is accepted otherwise. The fp8 formats (e4m3fn, e5m2) are accepted
    except on MPS devices. Any other dtype is rejected.
    """
    if dtype in (torch.float32, torch.float16):
        return True
    if directml_enabled: #TODO: test this
        return False
    if dtype == torch.bfloat16:
        return True
    if is_device_mps(device):
        return False
    return dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2
2024-08-01 15:05:56 +00:00
def pick_weight_dtype(dtype, fallback_dtype, device=None):
    """Return the dtype to store weights in.

    ``fallback_dtype`` replaces ``dtype`` when ``dtype`` is None, when it is
    wider than the fallback (per ``dtype_size``), or when ``supports_cast``
    rejects it for ``device``.
    """
    if dtype is None or dtype_size(dtype) > dtype_size(fallback_dtype):
        dtype = fallback_dtype
    return dtype if supports_cast(device, dtype) else fallback_dtype
2023-12-22 19:24:04 +00:00
def device_supports_non_blocking(device):
    """Return whether non-blocking tensor copies may be used with ``device``.

    Disabled for MPS devices, Intel XPU, deterministic mode and DirectML.
    """
    return not (
        is_device_mps(device)  # pytorch bug? mps doesn't support non blocking
        or is_intel_xpu()
        # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
        or args.deterministic
        or directml_enabled
    )
def device_should_use_non_blocking(device):
    """Return whether non-blocking copies should be used for ``device``.

    Currently always False. Support is still queried via
    ``device_supports_non_blocking``, but non-blocking stays off even when
    supported.
    """
    supported = device_supports_non_blocking(device)
    if supported:
        # Would return True here.
        # TODO: figure out why this causes memory issues on Nvidia and possibly others
        return False
    return False
2024-06-15 05:08:12 +00:00
def force_channels_last():
    """Return True when channels-last memory format is forced via ``args.force_channels_last``."""
    # TODO: decide automatically when the flag is not given.
    return bool(args.force_channels_last)
2023-12-22 19:24:04 +00:00
2024-10-17 21:25:56 +00:00
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
    """Return ``weight`` converted to ``dtype`` and/or moved to ``device``.

    When no device change is needed (``device`` is None or equals
    ``weight.device``) and ``copy`` is false, ``weight`` itself is returned if
    its dtype already matches (or ``dtype`` is None). Otherwise a converted
    tensor comes from ``Tensor.to``. In every other case a new tensor is
    allocated with ``torch.empty_like`` on the target dtype/device and filled
    with ``copy_`` (non-blocking if requested).
    """
    stays_put = device is None or weight.device == device
    if stays_put and not copy:
        if dtype is None or weight.dtype == dtype:
            return weight
        return weight.to(dtype=dtype, copy=copy)

    out = torch.empty_like(weight, dtype=dtype, device=device)
    out.copy_(weight, non_blocking=non_blocking)
    return out
def cast_to_device(tensor, device, dtype, copy=False):
    """Cast ``tensor`` to ``dtype`` on ``device`` via ``cast_to``.

    Non-blocking copies are requested exactly when
    ``device_supports_non_blocking(device)`` allows them.
    """
    use_non_blocking = device_supports_non_blocking(device)
    return cast_to(
        tensor,
        dtype=dtype,
        device=device,
        non_blocking=use_non_blocking,
        copy=copy,
    )
2023-12-10 06:30:35 +00:00
2023-04-05 02:22:02 +00:00
2023-03-12 19:44:16 +00:00
def xformers_enabled():
    """Return whether xformers attention should be used.

    Never on non-GPU CPU states, Intel XPU or DirectML. Otherwise mirrors
    ``XFORMERS_IS_AVAILABLE``.
    """
    if cpu_state != CPUState.GPU or is_intel_xpu() or directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE
2023-03-12 19:44:16 +00:00
2023-04-05 02:22:02 +00:00
def xformers_enabled_vae():
    """Return ``XFORMERS_ENABLED_VAE`` when xformers is enabled at all, else False."""
    return XFORMERS_ENABLED_VAE if xformers_enabled() else False
2023-04-05 02:22:02 +00:00
2023-03-13 16:25:19 +00:00
def pytorch_attention_enabled():
    """Return the module-level ``ENABLE_PYTORCH_ATTENTION`` setting."""
    # Read-only access: no ``global`` declaration is needed.
    enabled = ENABLE_PYTORCH_ATTENTION
    return enabled
2023-05-06 23:58:54 +00:00
def pytorch_attention_flash_attention():
    """Return whether PyTorch attention is enabled and assumed to use flash attention.

    Only Nvidia (``is_nvidia()``) and Intel XPU (``is_intel_xpu()``) qualify.
    """
    if not ENABLE_PYTORCH_ATTENTION:
        return False
    # TODO: more reliable way of checking for flash attention?
    # pytorch flash attention only works on Nvidia (and is also enabled for Intel XPU).
    return bool(is_nvidia() or is_intel_xpu())
2024-05-21 20:56:33 +00:00
def force_upcast_attention_dtype():
    """Return ``torch.float32`` if attention must be upcast, else None.

    Upcasting is forced by ``args.force_upcast_attention``, or automatically on
    macOS versions 14.5 through 15.0.1 (black image bug).
    """
    upcast = args.force_upcast_attention
    try:
        macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
        if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
            upcast = True
    except Exception:
        # Best effort only. Off macOS, mac_ver() reports an empty version
        # string, so int("") raises ValueError and no automatic upcast happens.
        # A narrow ``except Exception`` (instead of a bare except) keeps
        # KeyboardInterrupt/SystemExit propagating.
        pass
    return torch.float32 if upcast else None
2023-03-03 08:27:33 +00:00
def get_free_memory(dev=None, torch_free_too=False):
    """Estimate free memory, in bytes, on ``dev`` (default: the main torch device).

    CPU and MPS devices report system available RAM for both figures. DirectML
    uses a 1 GiB placeholder. XPU and CUDA add the memory torch has reserved but
    not actively used to the device's free memory.

    Returns:
        ``mem_free_total``, or ``(mem_free_total, mem_free_torch)`` when
        ``torch_free_too`` is true.
    """
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and dev.type in ('cpu', 'mps'):
        mem_free_total = psutil.virtual_memory().available
        mem_free_torch = mem_free_total
    elif directml_enabled:
        mem_free_total = 1024 * 1024 * 1024 #TODO
        mem_free_torch = mem_free_total
    elif is_intel_xpu():
        stats = torch.xpu.memory_stats(dev)
        reserved_bytes = stats['reserved_bytes.all.current']
        mem_free_torch = reserved_bytes - stats['active_bytes.all.current']
        unreserved_bytes = torch.xpu.get_device_properties(dev).total_memory - reserved_bytes
        mem_free_total = unreserved_bytes + mem_free_torch
    else:
        stats = torch.cuda.memory_stats(dev)
        reserved_bytes = stats['reserved_bytes.all.current']
        mem_free_torch = reserved_bytes - stats['active_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
        mem_free_total = mem_free_cuda + mem_free_torch

    return (mem_free_total, mem_free_torch) if torch_free_too else mem_free_total
2023-02-08 19:05:31 +00:00
2023-03-03 16:07:10 +00:00
def cpu_mode():
    """Return True when the module-level ``cpu_state`` is ``CPUState.CPU``."""
    is_cpu = cpu_state == CPUState.CPU
    return is_cpu
2023-03-03 16:07:10 +00:00
2023-03-24 12:04:50 +00:00
def mps_mode():
    """Return True when the module-level ``cpu_state`` is ``CPUState.MPS``."""
    is_mps = cpu_state == CPUState.MPS
    return is_mps
2023-03-24 12:04:50 +00:00
2024-02-16 02:10:10 +00:00
def is_device_type(device, type):
    """Return True if ``device`` has a ``type`` attribute equal to ``type``."""
    if not hasattr(device, 'type'):
        return False
    return bool(device.type == type)
2024-02-16 02:10:10 +00:00
def is_device_cpu(device):
    """Return True if ``device`` is a CPU device (``device.type == 'cpu'``)."""
    cpu_type = 'cpu'
    return is_device_type(device, cpu_type)
2023-07-04 06:09:02 +00:00
def is_device_mps(device):
    """Return True if ``device`` is an Apple MPS device (``device.type == 'mps'``)."""
    mps_type = 'mps'
    return is_device_type(device, mps_type)
def is_device_cuda(device):
    """Return True if ``device`` is a CUDA device (``device.type == 'cuda'``)."""
    cuda_type = 'cuda'
    return is_device_type(device, cuda_type)
2023-07-01 17:22:51 +00:00
2024-02-04 18:23:43 +00:00
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    """Decide whether float16 should be used for computation on ``device``.

    Checks run in a fixed priority order: CPU device, FORCE_FP16, MPS device,
    FORCE_FP32/DirectML, global MPS/CPU modes, Intel XPU/ROCm, then CUDA
    compute capability and known GPU families.

    Args:
        device: target device, or None for the current default.
        model_params: parameter count. With ``manual_cast``, fp16 is chosen
            when ``model_params * 4`` exceeds ``maximum_vram_for_weights``.
        prioritize_performance: when False (and ``manual_cast``), fp16 is
            accepted on cards where it is merely usable.
        manual_cast: weights will be cast at compute time.
    """
    if device is not None and is_device_cpu(device):
        return False
    if FORCE_FP16:
        return True
    if device is not None and is_device_mps(device):
        return True
    if FORCE_FP32 or directml_enabled:
        return False
    if mps_mode():
        return True
    if cpu_mode():
        return False
    if is_intel_xpu() or torch.version.hip:
        return True

    props = torch.cuda.get_device_properties(device)
    if props.major >= 8:
        return True
    if props.major < 6:
        return False

    # FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
    pascal_tags = ("1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4")
    lowered_name = props.name.lower()
    if any(tag in lowered_name for tag in pascal_tags):
        # weird linux behavior where fp32 is faster
        return bool(WINDOWS or manual_cast)

    if manual_cast:
        free_model_memory = maximum_vram_for_weights(device)
        if not prioritize_performance or model_params * 4 > free_model_memory:
            return True

    if props.major < 7:
        return False

    # FP16 is just broken on these cards
    broken_fp16_tags = ("1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200")
    return not any(tag in props.name for tag in broken_fp16_tags)
2024-02-17 13:13:17 +00:00
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    """Decide whether bfloat16 should be used for computation on ``device``.

    CPU devices: no (bf16 works on CPU but is extremely slow). MPS devices and
    MPS mode: yes. FORCE_FP32, DirectML and CPU mode: no. Intel XPU: yes. CUDA
    compute capability 8+: yes. Otherwise yes only if the GPU reports bf16
    support (or ``manual_cast``) AND either performance is not prioritized or
    ``model_params * 4`` exceeds ``maximum_vram_for_weights(device)``.
    """
    if device is not None:
        if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
            return False
        if is_device_mps(device):
            return True

    if FORCE_FP32 or directml_enabled:
        return False
    if mps_mode():
        return True
    if cpu_mode():
        return False
    if is_intel_xpu():
        return True

    props = torch.cuda.get_device_properties(device)
    if props.major >= 8:
        return True

    if not (torch.cuda.is_bf16_supported() or manual_cast):
        return False

    free_model_memory = maximum_vram_for_weights(device)
    if not prioritize_performance or model_params * 4 > free_model_memory:
        return True
    return False
2024-08-20 15:49:33 +00:00
def supports_fp8_compute(device=None):
    """Return whether fp8 compute should be used on ``device``.

    Requires an Nvidia GPU whose CUDA compute capability
    (``props.major``/``props.minor``) is at least 8.9, and a torch version of
    at least 2.3 (at least 2.4 on Windows). If the module-level
    ``torch_version`` string has no leading "major.minor" number pair, fp8
    compute is reported as unsupported.
    """
    import re

    if not is_nvidia():
        return False

    props = torch.cuda.get_device_properties(device)
    if props.major >= 9:
        return True
    if props.major < 8:
        return False
    if props.minor < 9:
        return False

    # Parse "major.minor" numerically. Indexing single characters of the
    # version string misreads two-digit components ("2.10.0" looked like 2.1)
    # and raised IndexError on an empty version string.
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        return False
    version = (int(match.group(1)), int(match.group(2)))

    if version < (2, 3):
        return False
    if WINDOWS and version < (2, 4):
        return False
    return True
2023-09-04 04:58:18 +00:00
def soft_empty_cache(force=False):
    """Ask the active backend to release its cached allocator memory.

    MPS and Intel XPU caches are always emptied. CUDA caches are emptied (plus
    ``ipc_collect``) only on Nvidia, or on any CUDA build when ``force`` is set.
    """
    if cpu_state == CPUState.MPS:
        torch.mps.empty_cache()
        return
    if is_intel_xpu():
        torch.xpu.empty_cache()
        return
    if not torch.cuda.is_available():
        return
    # This seems to make things worse on ROCm so it is only done for cuda unless forced.
    if force or is_nvidia():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
2023-12-23 09:25:06 +00:00
def unload_all_models():
    """Request an effectively unbounded amount of free memory on the main device."""
    target_device = get_torch_device()
    free_memory(1e30, target_device)
2023-12-22 19:24:04 +00:00
def resolve_lowvram_weight(weight, model, key): #TODO: remove
    """Deprecated pass-through kept for backward compatibility.

    Emits a deprecation warning through the ``logging`` module (consistent with
    the rest of this module, instead of ``print``) and returns ``weight``
    unchanged. ``model`` and ``key`` are ignored.
    """
    logging.warning("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
    return weight
2023-03-02 19:42:03 +00:00
#TODO: might be cleaner to put this somewhere else
import threading


class InterruptProcessingException(Exception):
    """Raised by throw_exception_if_processing_interrupted() when an interrupt is pending."""


# Re-entrant lock guarding every read/write of the interrupt flag below.
interrupt_processing_mutex = threading.RLock()
# Set to True to request that the current processing be interrupted.
interrupt_processing = False
def interrupt_current_processing(value=True):
    """Set (default) or clear the interrupt flag while holding the interrupt mutex."""
    global interrupt_processing
    new_state = value
    with interrupt_processing_mutex:
        interrupt_processing = new_state
def processing_interrupted():
    """Return the current interrupt flag, read while holding the interrupt mutex."""
    with interrupt_processing_mutex:
        flag = interrupt_processing
    return flag
def throw_exception_if_processing_interrupted():
    """Raise InterruptProcessingException if an interrupt is pending.

    The flag is cleared under the mutex before raising, so a single interrupt
    request triggers only one exception.
    """
    global interrupt_processing
    with interrupt_processing_mutex:
        if not interrupt_processing:
            return
        interrupt_processing = False
        raise InterruptProcessingException()