Mirror of https://github.com/comfyanonymous/ComfyUI.git
Use fp16 if checkpoint weights are fp16 and the model supports it.
commit 1804397952
parent f4dac8ab6f
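Before this commit, callers appended the checkpoint's storage dtype to the model's supported dtype list and let the heuristic sort it out; now they pass it explicitly as a weight_dtype hint to comfy.model_management.unet_dtype(), which prefers the checkpoint's own precision when the model config supports it. Below is a minimal, self-contained sketch of the two rules the commit adds; pick_unet_dtype is a hypothetical stand-in for the real unet_dtype() and omits its VRAM checks, command-line overrides, and bf16 fallback logic.

import torch

def pick_unet_dtype(supported_dtypes, weight_dtype=None):
    # Rule 1 (the @@ -692 hunk below): fp8 checkpoint weights become the fp8
    # candidate directly instead of scanning a fixed fp8 list. For simplicity
    # this sketch gates on model support; the real function instead gates the
    # fp8 return on free VRAM.
    fp8_dtypes = ()
    try:
        fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
    except AttributeError:
        pass  # torch builds that predate float8 dtypes
    if weight_dtype in fp8_dtypes and weight_dtype in supported_dtypes:
        return weight_dtype

    # Rule 2 (the @@ -707 hunk below): fp16 checkpoint weights now take the
    # same branch as the PRIORITIZE_FP16 flag; the real code additionally
    # consults should_use_fp16() for the device.
    if weight_dtype == torch.float16 and torch.float16 in supported_dtypes:
        return torch.float16

    # Otherwise fall back to the first dtype the model config declares.
    return supported_dtypes[0]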
comfy/controlnet.py

@@ -418,10 +418,7 @@ def controlnet_config(sd, model_options={}):
         weight_dtype = comfy.utils.weight_dtype(sd)
 
         supported_inference_dtypes = list(model_config.supported_inference_dtypes)
-        if weight_dtype is not None:
-            supported_inference_dtypes.append(weight_dtype)
-
-        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
 
     load_device = comfy.model_management.get_torch_device()
     manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
@@ -689,10 +686,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
         if supported_inference_dtypes is None:
             supported_inference_dtypes = [comfy.model_management.unet_dtype()]
 
-        if weight_dtype is not None:
-            supported_inference_dtypes.append(weight_dtype)
-
-        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
 
         load_device = comfy.model_management.get_torch_device()
 
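Both controlnet call sites now follow the same pattern. A hedged usage sketch, assuming a toy fp16 state dict; comfy.utils.weight_dtype() and comfy.model_management.unet_dtype() are the repo's own helpers, but the tensor key name is invented for illustration.

import torch
import comfy.utils
import comfy.model_management

# Toy fp16 state dict standing in for a real checkpoint.
sd = {"input_blocks.0.0.weight": torch.zeros(16, dtype=torch.float16)}

weight_dtype = comfy.utils.weight_dtype(sd)  # torch.float16 for this sd
unet_dtype = comfy.model_management.unet_dtype(
    model_params=-1,  # per the hunk below, -1 means "assume an arbitrarily large model"
    supported_dtypes=[torch.float16, torch.bfloat16, torch.float32],
    weight_dtype=weight_dtype,
)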
comfy/model_management.py

@@ -674,7 +674,7 @@ def unet_inital_load_device(parameters, dtype):
 def maximum_vram_for_weights(device=None):
     return (get_total_memory(device) * 0.88 - minimum_inference_memory())
 
-def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
     if model_params < 0:
         model_params = 1000000000000000000000
     if args.fp32_unet:
@@ -692,10 +692,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
 
     fp8_dtype = None
     try:
-        for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-            if dtype in supported_dtypes:
-                fp8_dtype = dtype
-                break
+        if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            fp8_dtype = weight_dtype
     except:
         pass
 
@@ -707,7 +705,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         if model_params * 2 > free_model_memory:
             return fp8_dtype
 
-    if PRIORITIZE_FP16:
+    if PRIORITIZE_FP16 or weight_dtype == torch.float16:
         if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
             return torch.float16
 
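A side note on the try/except kept around the fp8 branch: torch builds that predate the float8 dtypes raise AttributeError on torch.float8_e4m3fn, so the lookup has to be guarded (the bare except in the source catches anything, but AttributeError is the case that matters). A minimal sketch of the same feature detection; available_fp8_dtypes is a hypothetical helper, not part of comfy.

import torch

def available_fp8_dtypes():
    # Mirrors the guard above: probing the attributes raises AttributeError
    # on torch releases without float8 support.
    try:
        return (torch.float8_e4m3fn, torch.float8_e5m2)
    except AttributeError:
        return ()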
comfy/sd.py
@@ -896,14 +896,14 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         return None
 
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if weight_dtype is not None and model_config.scaled_fp8 is None:
-        unet_weight_dtype.append(weight_dtype)
+    if model_config.scaled_fp8 is not None:
+        weight_dtype = None
 
     model_config.custom_operations = model_options.get("custom_operations", None)
     unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
 
     if unet_dtype is None:
-        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
 
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
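Note the inverted condition in this hunk: previously the checkpoint dtype was appended to the supported list unless the model used scaled fp8; now the hint itself is cleared for scaled-fp8 checkpoints, presumably because those tensors go through dedicated scaled-fp8 operations rather than the plain fp8 path. A runnable restatement with a hypothetical helper; resolve_weight_dtype_hint is not part of comfy.

import torch

def resolve_weight_dtype_hint(weight_dtype, scaled_fp8):
    # Condensed view of the sd.py change: a scaled-fp8 checkpoint invalidates
    # the raw storage dtype as a selection hint.
    if scaled_fp8 is not None:
        return None
    return weight_dtype

# fp16 checkpoint without fp8 scaling: the hint survives.
assert resolve_weight_dtype_hint(torch.float16, None) == torch.float16
# Scaled-fp8 checkpoint: the hint is dropped (needs a torch build with float8 dtypes).
assert resolve_weight_dtype_hint(torch.float8_e4m3fn, object()) is None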
@@ -994,11 +994,11 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
 
     offload_device = model_management.unet_offload_device()
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if weight_dtype is not None and model_config.scaled_fp8 is None:
-        unet_weight_dtype.append(weight_dtype)
+    if model_config.scaled_fp8 is not None:
+        weight_dtype = None
 
     if dtype is None:
-        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
     else:
         unet_dtype = dtype
 
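Exercising the pick_unet_dtype sketch from the top of this page (assumes that hypothetical helper and import torch are in scope):

print(pick_unet_dtype([torch.float16, torch.bfloat16, torch.float32],
                      weight_dtype=torch.float16))  # torch.float16
print(pick_unet_dtype([torch.bfloat16, torch.float32],
                      weight_dtype=torch.float16))  # torch.bfloat16: fp16 unsupported, fall back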