diff --git a/comfy/model_base.py b/comfy/model_base.py
index cb694964..830bcc68 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1,3 +1,21 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Comfy
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -77,10 +95,13 @@ class BaseModel(torch.nn.Module):
self.device = device
if not unet_config.get("disable_unet_model_creation", False):
- if self.manual_cast_dtype is not None:
- operations = comfy.ops.manual_cast
+ if model_config.custom_operations is None:
+ if self.manual_cast_dtype is not None:
+ operations = comfy.ops.manual_cast
+ else:
+ operations = comfy.ops.disable_weight_init
else:
- operations = comfy.ops.disable_weight_init
+ operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
diff --git a/comfy/sd.py b/comfy/sd.py
index 68917324..ee91ad53 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -498,14 +498,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
sd = comfy.utils.load_torch_file(ckpt_path)
- out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model)
+ out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out
-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
clip = None
clipvision = None
vae = None
@@ -525,7 +525,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if weight_dtype is not None:
unet_weight_dtype.append(weight_dtype)
- unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+ model_config.custom_operations = model_options.get("custom_operations", None)
+ unet_dtype = model_options.get("weight_dtype", None)
+
+ if unet_dtype is None:
+ unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index bc0a7e31..7a2152f9 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -1,3 +1,21 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Comfy
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
import torch
from . import model_base
from . import utils
@@ -30,6 +48,7 @@ class BASE:
memory_usage_factor = 2.0
manual_cast_dtype = None
+ custom_operations = None
@classmethod
def matches(s, unet_config, state_dict=None):