diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py
index 251483131..f982d648c 100644
--- a/comfy/cldm/cldm.py
+++ b/comfy/cldm/cldm.py
@@ -34,8 +34,7 @@ class ControlNet(nn.Module):
         dims=2,
         num_classes=None,
         use_checkpoint=False,
-        use_fp16=False,
-        use_bf16=False,
+        dtype=torch.float32,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@@ -108,8 +107,7 @@ class ControlNet(nn.Module):
         self.conv_resample = conv_resample
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
-        self.dtype = th.bfloat16 if use_bf16 else self.dtype
+        self.dtype = dtype
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index ea219c7e5..73a40acfa 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):
 
     controlnet_config = None
     if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
-        use_fp16 = comfy.model_management.should_use_fp16()
-        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
+        unet_dtype = comfy.model_management.unet_dtype()
+        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
         diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
         diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
         diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
             return net
 
     if controlnet_config is None:
-        use_fp16 = comfy.model_management.should_use_fp16()
-        controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
+        unet_dtype = comfy.model_management.unet_dtype()
+        controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
     control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
     missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
     print(missing, unexpected)
 
-    if use_fp16:
-        control_model = control_model.half()
+    control_model = control_model.to(unet_dtype)
 
     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index b42637c82..bf58a4045 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -296,8 +296,7 @@ class UNetModel(nn.Module):
         dims=2,
         num_classes=None,
         use_checkpoint=False,
-        use_fp16=False,
-        use_bf16=False,
+        dtype=th.float32,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@@ -370,8 +369,7 @@ class UNetModel(nn.Module):
         self.conv_resample = conv_resample
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
-        self.dtype = th.bfloat16 if use_bf16 else self.dtype
+        self.dtype = dtype
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 787c78575..0ff2e7fb5 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
             count += 1
     return count
 
-def detect_unet_config(state_dict, key_prefix, use_fp16):
+def detect_unet_config(state_dict, key_prefix, dtype):
     state_dict_keys = list(state_dict.keys())
 
     unet_config = {
@@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
     else:
         unet_config["adm_in_channels"] = None
 
-    unet_config["use_fp16"] = use_fp16
+    unet_config["dtype"] = dtype
 
     model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
     in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
     print("no match", unet_config)
     return None
 
-def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
     model_config = model_config_from_unet_config(unet_config)
     if model_config is None and use_base_if_no_match:
         return comfy.supported_models_base.BASE(unet_config)
     else:
         return model_config
 
-def unet_config_from_diffusers_unet(state_dict, use_fp16):
+def unet_config_from_diffusers_unet(state_dict, dtype):
     match = {}
     attention_resolutions = []
 
@@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
         match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
 
     SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
             'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
             'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
 
     SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
                     'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
                     'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
 
     SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
             'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
             'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
 
     SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
                     'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
 
     SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
                     'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
 
     SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
             'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
             'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
 
     SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                     'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                     'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                      'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
                      'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
 
     SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                       'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                       'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                        'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
                        'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
 
     SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                              'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
+                              'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
                               'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
                               'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
 
@@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
             return unet_config
     return None
 
-def model_config_from_diffusers_unet(state_dict, use_fp16):
-    unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
+def model_config_from_diffusers_unet(state_dict, dtype):
+    unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
     if unet_config is not None:
         return model_config_from_unet_config(unet_config)
     return None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 3c390d9ca..1161c2447 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -448,6 +448,11 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev
 
+def unet_dtype(device=None, model_params=0):
+    if should_use_fp16(device=device, model_params=model_params):
+        return torch.float16
+    return torch.float32
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
diff --git a/comfy/sd.py b/comfy/sd.py
index cfd6fb3cb..fd8b94df8 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -327,7 +327,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         if "params" in model_config_params["unet_config"]:
             unet_config = model_config_params["unet_config"]["params"]
             if "use_fp16" in unet_config:
-                fp16 = unet_config["use_fp16"]
+                fp16 = unet_config.pop("use_fp16")
+                if fp16:
+                    unet_config["dtype"] = torch.float16
 
     noise_aug_config = None
     if "noise_aug_config" in model_config_params:
@@ -405,12 +407,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     clip_target = None
 
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
-    fp16 = model_management.should_use_fp16(model_params=parameters)
+    unet_dtype = model_management.unet_dtype(model_params=parameters)
 
     class WeightsLoader(torch.nn.Module):
         pass
 
-    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
+    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
 
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -418,12 +420,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_clipvision:
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
 
-    dtype = torch.float32
-    if fp16:
-        dtype = torch.float16
-
     if output_model:
-        inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
+        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
         offload_device = model_management.unet_offload_device()
         model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
         model.load_model_weights(sd, "model.diffusion_model.")
@@ -458,15 +456,15 @@ def load_unet(unet_path): #load unet in diffusers format
     sd = comfy.utils.load_torch_file(unet_path)
     parameters = comfy.utils.calculate_parameters(sd)
-    fp16 = model_management.should_use_fp16(model_params=parameters)
+    unet_dtype = model_management.unet_dtype(model_params=parameters)
     if "input_blocks.0.0.weight" in sd: #ldm
-        model_config = model_detection.model_config_from_unet(sd, "", fp16)
+        model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
         if model_config is None:
             raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
         new_sd = sd
     else: #diffusers
-        model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
+        model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
        if model_config is None:
            print("ERROR UNSUPPORTED UNET", unet_path)
            return None
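
The patch replaces the use_fp16/use_bf16 booleans with a single torch.dtype that the new comfy.model_management.unet_dtype() helper resolves once and that is then threaded through model detection into the UNetModel/ControlNet constructors. A minimal sketch of the new calling convention, assuming comfy is importable; the parameter count below is an illustrative value (roughly an SD1.5-sized UNet), not from the patch:

import torch
import comfy.model_management

# Resolve the UNet compute dtype once, instead of passing use_fp16/use_bf16
# flags around: float16 when should_use_fp16() approves for this device and
# model size, float32 otherwise.
unet_dtype = comfy.model_management.unet_dtype(model_params=860_000_000)

# Detection now stores this torch.dtype under the "dtype" key of the unet
# config, UNetModel/ControlNet assign it directly to self.dtype, and loaders
# cast weights with .to(unet_dtype) rather than .half().
assert unet_dtype in (torch.float16, torch.float32)

Passing a dtype object rather than booleans also leaves room for torch.bfloat16 later without adding yet another constructor flag.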