mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
0075c6d096
This change allows supports for diffusion models where all the linears are scaled fp8 while the other weights are the original precision.
120 lines
4.2 KiB
Python
120 lines
4.2 KiB
Python
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
|
|
|
|
import torch
|
|
from . import model_base
|
|
from . import utils
|
|
from . import latent_formats
|
|
|
|
class ClipTarget:
    """Simple holder pairing a tokenizer with a CLIP object.

    Attributes:
        clip: The ``clip`` argument, stored unchanged.
        tokenizer: The ``tokenizer`` argument, stored unchanged.
        params: A fresh, empty dict created for every instance.
    """

    def __init__(self, tokenizer, clip):
        """Store the given tokenizer and clip and start with empty params.

        Args:
            tokenizer: Object kept as ``self.tokenizer``.
            clip: Object kept as ``self.clip``.
        """
        self.clip = clip
        self.tokenizer = tokenizer
        # Created per instance so instances never share the same dict.
        self.params = {}
|
|
|
|
class BASE:
    """Base description of a supported model configuration.

    Class attributes hold the defaults; ``__init__`` copies the mutable ones
    it modifies or exposes (``unet_config``, ``sampling_settings``,
    ``optimizations``) onto the instance so per-instance changes do not leak
    back into the class.
    """

    # Key/value pairs a detected unet config must contain for ``matches``.
    unet_config = {}
    # Entries written into the instance ``unet_config`` by ``__init__``
    # (they overwrite any same-named keys of the config passed in).
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    # Keys that must be present in the state dict for ``matches`` to succeed.
    required_keys = {}

    clip_prefix = []
    # Prefix added to every key by process_clip_vision_state_dict_for_saving.
    clip_vision_prefix = None
    # When not None, ``get_model`` builds a ``model_base.SD21UNCLIP``.
    noise_aug_config = None
    sampling_settings = {}
    # A class here; ``__init__`` replaces it with an instance.
    latent_format = latent_formats.LatentFormat
    vae_key_prefix = ["first_stage_model."]
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    memory_usage_factor = 2.0

    manual_cast_dtype = None
    custom_operations = None
    scaled_fp8 = None
    optimizations = {"fp8": False}

    @classmethod
    def matches(s, unet_config, state_dict=None):
        """Return True when ``unet_config`` (and ``state_dict``) fit this class.

        Args:
            unet_config: Mapping that must contain every key of the class
                level ``unet_config`` with an equal value.
            state_dict: Optional mapping; when given, every entry of
                ``required_keys`` must be present in it.

        Returns:
            bool: True if all checks pass, False otherwise.
        """
        # NOTE: the first parameter is the class itself (named ``s`` upstream).
        for key in s.unet_config:
            if key not in unet_config or s.unet_config[key] != unet_config[key]:
                return False
        if state_dict is not None:
            return all(key in state_dict for key in s.required_keys)
        return True

    def model_type(self, state_dict, prefix=""):
        """Return the model type; this base implementation always returns EPS.

        ``state_dict`` and ``prefix`` are unused here.
        """
        return model_base.ModelType.EPS

    def inpaint_model(self):
        """Return True when the unet config declares more than 4 input channels.

        Raises:
            KeyError: If ``unet_config`` has no ``"in_channels"`` entry.
        """
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config):
        """Create an instance from a unet config.

        Args:
            unet_config: Mapping that is shallow-copied, then updated with
                ``unet_extra_config``.
        """
        self.unet_config = unet_config.copy()
        # Shadow the class-level mutable defaults with per-instance copies.
        self.sampling_settings = self.sampling_settings.copy()
        # Instantiate the latent format class configured on the class.
        self.latent_format = self.latent_format()
        self.optimizations = self.optimizations.copy()
        self.unet_config.update(self.unet_extra_config)

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base`` model object for this configuration.

        Args:
            state_dict: Passed to ``model_type`` together with ``prefix``.
            prefix: Passed to ``model_type``.
            device: Forwarded to the model constructor.

        Returns:
            A ``model_base.SD21UNCLIP`` when ``noise_aug_config`` is set,
            otherwise a ``model_base.BaseModel``. ``set_inpaint()`` is called
            on it when ``inpaint_model()`` is true.
        """
        model_type = self.model_type(state_dict, prefix)
        if self.noise_aug_config is not None:
            out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=model_type, device=device)
        else:
            out = model_base.BaseModel(self, model_type=model_type, device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        """Strip each text encoder prefix from the keys of ``state_dict``.

        Delegates to ``utils.state_dict_prefix_replace`` with
        ``filter_keys=True`` and returns its result.
        """
        replace_prefix = {k: "" for k in self.text_encoder_key_prefix}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)

    def process_unet_state_dict(self, state_dict):
        """Return ``state_dict`` unchanged in this base implementation."""
        return state_dict

    def process_vae_state_dict(self, state_dict):
        """Return ``state_dict`` unchanged in this base implementation."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Prefix keys with the first ``text_encoder_key_prefix`` entry."""
        replace_prefix = {"": self.text_encoder_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        """Prefix keys with ``clip_vision_prefix`` when one is configured.

        With no prefix configured an empty replacement mapping is passed on.
        """
        replace_prefix = {}
        if self.clip_vision_prefix is not None:
            replace_prefix[""] = self.clip_vision_prefix
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Prefix keys with ``"model.diffusion_model."``."""
        replace_prefix = {"": "model.diffusion_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_vae_state_dict_for_saving(self, state_dict):
        """Prefix keys with the first ``vae_key_prefix`` entry."""
        replace_prefix = {"": self.vae_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        """Record the inference dtype and the manual cast dtype.

        Args:
            dtype: Stored under ``unet_config['dtype']``.
            manual_cast_dtype: Stored as ``self.manual_cast_dtype``.
        """
        self.unet_config['dtype'] = dtype
        self.manual_cast_dtype = manual_cast_dtype
|