diff --git a/comfy/model_base.py b/comfy/model_base.py index cb694964..830bcc68 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch import logging from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -77,10 +95,13 @@ class BaseModel(torch.nn.Module): self.device = device if not unet_config.get("disable_unet_model_creation", False): - if self.manual_cast_dtype is not None: - operations = comfy.ops.manual_cast + if model_config.custom_operations is None: + if self.manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops.disable_weight_init else: - operations = comfy.ops.disable_weight_init + operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) if comfy.model_management.force_channels_last(): self.diffusion_model.to(memory_format=torch.channels_last) diff --git a/comfy/sd.py b/comfy/sd.py index 68917324..ee91ad53 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -498,14 +498,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): sd = comfy.utils.load_torch_file(ckpt_path) - out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model) + out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) return out -def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): clip = None clipvision = None vae = None @@ -525,7 +525,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if weight_dtype is not None: unet_weight_dtype.append(weight_dtype) - unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) + model_config.custom_operations = model_options.get("custom_operations", None) + unet_dtype = model_options.get("weight_dtype", None) + + if unet_dtype is None: + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index bc0a7e31..7a2152f9 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch from . import model_base from . import utils @@ -30,6 +48,7 @@ class BASE: memory_usage_factor = 2.0 manual_cast_dtype = None + custom_operations = None @classmethod def matches(s, unet_config, state_dict=None):