mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Refactor code so model can be a dtype other than fp32 or fp16.
This commit is contained in:
parent
fee3b0c070
commit
9a55dadb4c
@ -34,8 +34,7 @@ class ControlNet(nn.Module):
|
|||||||
dims=2,
|
dims=2,
|
||||||
num_classes=None,
|
num_classes=None,
|
||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
use_fp16=False,
|
dtype=torch.float32,
|
||||||
use_bf16=False,
|
|
||||||
num_heads=-1,
|
num_heads=-1,
|
||||||
num_head_channels=-1,
|
num_head_channels=-1,
|
||||||
num_heads_upsample=-1,
|
num_heads_upsample=-1,
|
||||||
@ -108,8 +107,7 @@ class ControlNet(nn.Module):
|
|||||||
self.conv_resample = conv_resample
|
self.conv_resample = conv_resample
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
self.dtype = th.float16 if use_fp16 else th.float32
|
self.dtype = dtype
|
||||||
self.dtype = th.bfloat16 if use_bf16 else self.dtype
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.num_head_channels = num_head_channels
|
self.num_head_channels = num_head_channels
|
||||||
self.num_heads_upsample = num_heads_upsample
|
self.num_heads_upsample = num_heads_upsample
|
||||||
|
@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
|
|
||||||
controlnet_config = None
|
controlnet_config = None
|
||||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||||
use_fp16 = comfy.model_management.should_use_fp16()
|
unet_dtype = comfy.model_management.unet_dtype()
|
||||||
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
|
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
||||||
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
||||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||||
@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
return net
|
return net
|
||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
use_fp16 = comfy.model_management.should_use_fp16()
|
unet_dtype = comfy.model_management.unet_dtype()
|
||||||
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
|
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||||
@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||||
print(missing, unexpected)
|
print(missing, unexpected)
|
||||||
|
|
||||||
if use_fp16:
|
control_model = control_model.to(unet_dtype)
|
||||||
control_model = control_model.half()
|
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = False
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
|
@ -296,8 +296,7 @@ class UNetModel(nn.Module):
|
|||||||
dims=2,
|
dims=2,
|
||||||
num_classes=None,
|
num_classes=None,
|
||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
use_fp16=False,
|
dtype=th.float32,
|
||||||
use_bf16=False,
|
|
||||||
num_heads=-1,
|
num_heads=-1,
|
||||||
num_head_channels=-1,
|
num_head_channels=-1,
|
||||||
num_heads_upsample=-1,
|
num_heads_upsample=-1,
|
||||||
@ -370,8 +369,7 @@ class UNetModel(nn.Module):
|
|||||||
self.conv_resample = conv_resample
|
self.conv_resample = conv_resample
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
self.dtype = th.float16 if use_fp16 else th.float32
|
self.dtype = dtype
|
||||||
self.dtype = th.bfloat16 if use_bf16 else self.dtype
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.num_head_channels = num_head_channels
|
self.num_head_channels = num_head_channels
|
||||||
self.num_heads_upsample = num_heads_upsample
|
self.num_heads_upsample = num_heads_upsample
|
||||||
|
@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
|
|||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def detect_unet_config(state_dict, key_prefix, use_fp16):
|
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||||
state_dict_keys = list(state_dict.keys())
|
state_dict_keys = list(state_dict.keys())
|
||||||
|
|
||||||
unet_config = {
|
unet_config = {
|
||||||
@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
|
|||||||
else:
|
else:
|
||||||
unet_config["adm_in_channels"] = None
|
unet_config["adm_in_channels"] = None
|
||||||
|
|
||||||
unet_config["use_fp16"] = use_fp16
|
unet_config["dtype"] = dtype
|
||||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||||
|
|
||||||
@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
|
|||||||
print("no match", unet_config)
|
print("no match", unet_config)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
|
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
|
||||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
|
||||||
model_config = model_config_from_unet_config(unet_config)
|
model_config = model_config_from_unet_config(unet_config)
|
||||||
if model_config is None and use_base_if_no_match:
|
if model_config is None and use_base_if_no_match:
|
||||||
return comfy.supported_models_base.BASE(unet_config)
|
return comfy.supported_models_base.BASE(unet_config)
|
||||||
else:
|
else:
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||||
match = {}
|
match = {}
|
||||||
attention_resolutions = []
|
attention_resolutions = []
|
||||||
|
|
||||||
@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
|||||||
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
||||||
|
|
||||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||||
|
|
||||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
|
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
|
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
|
||||||
|
|
||||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||||
|
|
||||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||||
|
|
||||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||||
|
|
||||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
|
||||||
|
|
||||||
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
|
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
|
||||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||||
|
|
||||||
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
|
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
|
||||||
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
|
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
|
||||||
|
|
||||||
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
|
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
|
||||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||||
|
|
||||||
@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
|||||||
return unet_config
|
return unet_config
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_config_from_diffusers_unet(state_dict, use_fp16):
|
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||||
unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
|
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
|
||||||
if unet_config is not None:
|
if unet_config is not None:
|
||||||
return model_config_from_unet_config(unet_config)
|
return model_config_from_unet_config(unet_config)
|
||||||
return None
|
return None
|
||||||
|
@ -448,6 +448,11 @@ def unet_inital_load_device(parameters, dtype):
|
|||||||
else:
|
else:
|
||||||
return cpu_dev
|
return cpu_dev
|
||||||
|
|
||||||
|
def unet_dtype(device=None, model_params=0):
|
||||||
|
if should_use_fp16(device=device, model_params=model_params):
|
||||||
|
return torch.float16
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
def text_encoder_offload_device():
|
def text_encoder_offload_device():
|
||||||
if args.gpu_only:
|
if args.gpu_only:
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
|
20
comfy/sd.py
20
comfy/sd.py
@ -327,7 +327,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
if "params" in model_config_params["unet_config"]:
|
if "params" in model_config_params["unet_config"]:
|
||||||
unet_config = model_config_params["unet_config"]["params"]
|
unet_config = model_config_params["unet_config"]["params"]
|
||||||
if "use_fp16" in unet_config:
|
if "use_fp16" in unet_config:
|
||||||
fp16 = unet_config["use_fp16"]
|
fp16 = unet_config.pop("use_fp16")
|
||||||
|
if fp16:
|
||||||
|
unet_config["dtype"] = torch.float16
|
||||||
|
|
||||||
noise_aug_config = None
|
noise_aug_config = None
|
||||||
if "noise_aug_config" in model_config_params:
|
if "noise_aug_config" in model_config_params:
|
||||||
@ -405,12 +407,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
clip_target = None
|
clip_target = None
|
||||||
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||||
|
|
||||||
class WeightsLoader(torch.nn.Module):
|
class WeightsLoader(torch.nn.Module):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
|
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||||
|
|
||||||
@ -418,12 +420,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
if output_clipvision:
|
if output_clipvision:
|
||||||
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
||||||
|
|
||||||
dtype = torch.float32
|
|
||||||
if fp16:
|
|
||||||
dtype = torch.float16
|
|
||||||
|
|
||||||
if output_model:
|
if output_model:
|
||||||
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||||
model.load_model_weights(sd, "model.diffusion_model.")
|
model.load_model_weights(sd, "model.diffusion_model.")
|
||||||
@ -458,15 +456,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
def load_unet(unet_path): #load unet in diffusers format
|
def load_unet(unet_path): #load unet in diffusers format
|
||||||
sd = comfy.utils.load_torch_file(unet_path)
|
sd = comfy.utils.load_torch_file(unet_path)
|
||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||||
if "input_blocks.0.0.weight" in sd: #ldm
|
if "input_blocks.0.0.weight" in sd: #ldm
|
||||||
model_config = model_detection.model_config_from_unet(sd, "", fp16)
|
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
new_sd = sd
|
new_sd = sd
|
||||||
|
|
||||||
else: #diffusers
|
else: #diffusers
|
||||||
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
|
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||||
return None
|
return None
|
||||||
|
Loading…
Reference in New Issue
Block a user