diff --git a/comfy/model_base.py b/comfy/model_base.py index 5bfcc391d..bab7b9b34 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import comfy.model_management import comfy.conds +import comfy.ops from enum import Enum import contextlib from . import utils @@ -41,9 +42,14 @@ class BaseModel(torch.nn.Module): unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config + self.manual_cast_dtype = model_config.manual_cast_dtype if not unet_config.get("disable_unet_model_creation", False): - self.diffusion_model = UNetModel(**unet_config, device=device) + if self.manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops + self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) @@ -63,11 +69,8 @@ class BaseModel(torch.nn.Module): context = c_crossattn dtype = self.get_dtype() - if comfy.model_management.supports_dtype(xc.device, dtype): - precision_scope = lambda a: contextlib.nullcontext(a) - else: - precision_scope = torch.autocast - dtype = torch.float32 + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype xc = xc.to(dtype) t = self.model_sampling.timestep(t).float() @@ -79,9 +82,7 @@ class BaseModel(torch.nn.Module): extra = extra.to(dtype) extra_conds[o] = extra - with precision_scope(comfy.model_management.get_autocast_device(xc.device)): - model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() - + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() return self.model_sampling.calculate_denoised(sigma, model_output, x) def get_dtype(self): diff --git a/comfy/model_management.py b/comfy/model_management.py index a6c8fb352..fe0374a8b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -474,6 +474,20 @@ def unet_dtype(device=None, model_params=0): return torch.float16 return torch.float32 +# None means no manual cast +def unet_manual_cast(weight_dtype, inference_device): + if weight_dtype == torch.float32: + return None + + fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False) + if fp16_supported and weight_dtype == torch.float16: + return None + + if fp16_supported: + return torch.float16 + else: + return torch.float32 + def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() @@ -538,7 +552,7 @@ def get_autocast_device(dev): def supports_dtype(device, dtype): #TODO if dtype == torch.float32: return True - if torch.device("cpu") == device: + if is_device_cpu(device): return False if dtype == torch.float16: return True diff --git a/comfy/ops.py b/comfy/ops.py index e48568409..a67bc809f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -62,6 +62,15 @@ class manual_cast: weight, bias = cast_bias_weight(self, input) return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + @classmethod + def conv_nd(s, dims, *args, **kwargs): + if dims == 2: + return s.Conv2d(*args, **kwargs) + elif dims == 3: + return s.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + @contextmanager def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way old_torch_nn_linear = torch.nn.Linear diff --git a/comfy/sd.py b/comfy/sd.py index 43e201d36..8c056e4ea 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -433,11 +433,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = model_management.get_torch_device() + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) class WeightsLoader(torch.nn.Module): pass model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) + model_config.set_manual_cast(manual_cast_dtype) + if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) @@ -470,7 +474,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o print("left over keys:", left_over) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) if inital_load_device != torch.device("cpu"): print("loaded straight to GPU") model_management.load_model_gpu(model_patcher) @@ -481,6 +485,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet_state_dict(sd): #load unet in diffusers format parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = model_management.get_torch_device() + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) + if "input_blocks.0.0.weight" in sd: #ldm model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) if model_config is None: @@ -501,13 +508,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format else: print(diffusers_keys[k], k) offload_device = model_management.unet_offload_device() + model_config.set_manual_cast(manual_cast_dtype) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() if len(left_over) > 0: print("left over keys in unet:", left_over) - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) + return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) def load_unet(unet_path): sd = comfy.utils.load_torch_file(unet_path) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 3412cfea0..49087d23e 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -22,6 +22,8 @@ class BASE: sampling_settings = {} latent_format = latent_formats.LatentFormat + manual_cast_dtype = None + @classmethod def matches(s, unet_config): for k in s.unet_config: @@ -71,3 +73,5 @@ class BASE: replace_prefix = {"": "first_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def set_manual_cast(self, manual_cast_dtype): + self.manual_cast_dtype = manual_cast_dtype