Refactor code so the model can use a dtype other than fp32 or fp16.

comfyanonymous 2023-10-13 14:35:21 -04:00
parent fee3b0c070
commit 9a55dadb4c
6 changed files with 39 additions and 41 deletions

comfy/cldm/cldm.py

@@ -34,8 +34,7 @@ class ControlNet(nn.Module):
         dims=2,
         num_classes=None,
         use_checkpoint=False,
-        use_fp16=False,
-        use_bf16=False,
+        dtype=torch.float32,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@@ -108,8 +107,7 @@ class ControlNet(nn.Module):
         self.conv_resample = conv_resample
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
-        self.dtype = th.bfloat16 if use_bf16 else self.dtype
+        self.dtype = dtype
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
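
Note: the net effect in this file is that ControlNet callers hand in a torch dtype directly instead of toggling use_fp16/use_bf16 booleans. A minimal sketch of the pattern on a toy module (TinyModel is an illustration, not part of this diff):

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        # Toy mirror of the refactor: accept a dtype directly instead of
        # use_fp16/use_bf16 flags, and store it like ControlNet now does.
        def __init__(self, dtype=torch.float32):
            super().__init__()
            self.dtype = dtype
            self.proj = nn.Linear(8, 8)

    m = TinyModel(dtype=torch.bfloat16)  # fp16, bf16, fp32, ... all expressible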

comfy/controlnet.py

@@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):
     controlnet_config = None
     if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
-        use_fp16 = comfy.model_management.should_use_fp16()
-        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
+        unet_dtype = comfy.model_management.unet_dtype()
+        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
         diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
         diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
         diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
             return net
     if controlnet_config is None:
-        use_fp16 = comfy.model_management.should_use_fp16()
-        controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
+        unet_dtype = comfy.model_management.unet_dtype()
+        controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
     control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
     missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
     print(missing, unexpected)
-    if use_fp16:
-        control_model = control_model.half()
+    control_model = control_model.to(unet_dtype)

     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
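
Note: the .half() call at the end of this file could only ever produce fp16, while .to(unet_dtype) casts to whatever dtype model management picked. The equivalence, sketched on a stand-in module:

    import torch

    model = torch.nn.Linear(4, 4)          # stand-in for the loaded control model
    fp16_model = model.half()              # old path: same as model.to(torch.float16)
    bf16_model = model.to(torch.bfloat16)  # new path: any floating dtype works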

comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -296,8 +296,7 @@ class UNetModel(nn.Module):
         dims=2,
         num_classes=None,
         use_checkpoint=False,
-        use_fp16=False,
-        use_bf16=False,
+        dtype=th.float32,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@@ -370,8 +369,7 @@ class UNetModel(nn.Module):
         self.conv_resample = conv_resample
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
-        self.dtype = th.bfloat16 if use_bf16 else self.dtype
+        self.dtype = dtype
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
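
Note: UNetModel gets the identical treatment as ControlNet. The stored self.dtype is what the LDM UNet uses to cast activations in forward (roughly h = x.type(self.dtype)), so accepting an arbitrary dtype here is what actually lets sampling run in, say, bf16. A small sketch of that cast:

    import torch as th

    x = th.randn(1, 4, 64, 64)    # stand-in latent input, fp32
    dtype = th.bfloat16           # now reachable; previously only fp16/fp32
    h = x.type(dtype)             # the kind of cast forward applies via self.dtype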

comfy/model_detection.py

@@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
         count += 1
     return count

-def detect_unet_config(state_dict, key_prefix, use_fp16):
+def detect_unet_config(state_dict, key_prefix, dtype):
     state_dict_keys = list(state_dict.keys())

     unet_config = {
@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
else:
unet_config["adm_in_channels"] = None
unet_config["use_fp16"] = use_fp16
unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
     print("no match", unet_config)
     return None

-def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
     model_config = model_config_from_unet_config(unet_config)
     if model_config is None and use_base_if_no_match:
         return comfy.supported_models_base.BASE(unet_config)
     else:
         return model_config

-def unet_config_from_diffusers_unet(state_dict, use_fp16):
+def unet_config_from_diffusers_unet(state_dict, dtype):
     match = {}
     attention_resolutions = []
@@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
         match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]

     SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
             'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
             'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}

     SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
                     'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
                     'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}

     SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
             'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
             'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}

     SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
                     'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}

     SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
                     'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}

     SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
             'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
             'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}

     SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                     'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                     'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                      'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
                      'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}

     SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                       'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                       'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                        'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
                        'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}

     SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                              'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
+                              'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
                               'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
                               'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
@@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
             return unet_config
     return None

-def model_config_from_diffusers_unet(state_dict, use_fp16):
-    unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
+def model_config_from_diffusers_unet(state_dict, dtype):
+    unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
     if unet_config is not None:
         return model_config_from_unet_config(unet_config)
     return None
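
Note: the detection helpers now take the dtype and simply record it in the generated unet_config, from where it flows into the model constructors changed above. A hedged usage sketch; the checkpoint path is a placeholder, not from this diff:

    import torch
    import comfy.utils
    import comfy.model_detection

    sd = comfy.utils.load_torch_file("checkpoint.safetensors")  # placeholder path
    model_config = comfy.model_detection.model_config_from_unet(
        sd, "model.diffusion_model.", torch.bfloat16)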

comfy/model_management.py

@@ -448,6 +448,11 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev

+def unet_dtype(device=None, model_params=0):
+    if should_use_fp16(device=device, model_params=model_params):
+        return torch.float16
+    return torch.float32
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
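
Note: unet_dtype() is the new single decision point for UNet precision, which is the point of the refactor: supporting another dtype means extending this one function rather than every call site. As of this commit it still only returns fp16 or fp32; a hypothetical bf16 branch (should_use_bf16 does not exist in this commit and is purely an assumption) would slot in like this:

    def unet_dtype(device=None, model_params=0):
        if should_use_fp16(device=device, model_params=model_params):
            return torch.float16
        # Hypothetical future branch -- not part of this commit:
        # if should_use_bf16(device=device):
        #     return torch.bfloat16
        return torch.float32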

comfy/sd.py

@@ -327,7 +327,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
             if "params" in model_config_params["unet_config"]:
                 unet_config = model_config_params["unet_config"]["params"]
                 if "use_fp16" in unet_config:
-                    fp16 = unet_config["use_fp16"]
+                    fp16 = unet_config.pop("use_fp16")
+                    if fp16:
+                        unet_config["dtype"] = torch.float16

     noise_aug_config = None
     if "noise_aug_config" in model_config_params:
@@ -405,12 +407,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     clip_target = None

     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
-    fp16 = model_management.should_use_fp16(model_params=parameters)
+    unet_dtype = model_management.unet_dtype(model_params=parameters)

     class WeightsLoader(torch.nn.Module):
         pass

-    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
+    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -418,12 +420,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_clipvision:
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

-    dtype = torch.float32
-    if fp16:
-        dtype = torch.float16
-
     if output_model:
-        inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
+        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
         offload_device = model_management.unet_offload_device()
         model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
         model.load_model_weights(sd, "model.diffusion_model.")
@@ -458,15 +456,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 def load_unet(unet_path): #load unet in diffusers format
     sd = comfy.utils.load_torch_file(unet_path)
     parameters = comfy.utils.calculate_parameters(sd)
-    fp16 = model_management.should_use_fp16(model_params=parameters)
+    unet_dtype = model_management.unet_dtype(model_params=parameters)
     if "input_blocks.0.0.weight" in sd: #ldm
-        model_config = model_detection.model_config_from_unet(sd, "", fp16)
+        model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
         if model_config is None:
             raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
         new_sd = sd
     else: #diffusers
-        model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
+        model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
         if model_config is None:
             print("ERROR UNSUPPORTED UNET", unet_path)
             return None