mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
Controlnet code refactor.
This commit is contained in:
parent
1c08bf35b4
commit
c19dcd362f
@ -191,13 +191,16 @@ class ControlNet(ControlBase):
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(dtype)
|
||||
extra = self.extra_args.copy()
|
||||
for c in ["y", "guidance"]: #TODO
|
||||
temp = cond.get(c, None)
|
||||
if temp is not None:
|
||||
extra[c] = temp.to(dtype)
|
||||
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||
return self.control_merge(control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
@ -338,12 +341,8 @@ class ControlLora(ControlNet):
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
def controlnet_config(sd):
|
||||
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||
|
||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||
|
||||
@ -356,14 +355,27 @@ def load_controlnet_mmdit(sd):
|
||||
else:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
|
||||
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
||||
|
||||
def controlnet_load_state_dict(control_model, sd):
    """Load ``sd`` into ``control_model`` with ``strict=False`` and log key mismatches.

    Args:
        control_model: Object exposing ``load_state_dict(sd, strict=...)`` that
            returns a ``(missing, unexpected)`` pair.
        sd: State dict handed to ``control_model.load_state_dict``.

    Returns:
        The same ``control_model`` object that was passed in.
    """
    load_result = control_model.load_state_dict(sd, strict=False)
    missing_keys, unexpected_keys = load_result

    # Missing keys are reported at warning level, unexpected keys only at
    # debug level; neither aborts the load.
    reports = (
        (logging.warning, "missing controlnet keys: {}", missing_keys),
        (logging.debug, "unexpected controlnet keys: {}", unexpected_keys),
    )
    for emit, template, keys in reports:
        if len(keys) > 0:
            emit(template.format(keys))

    return control_model
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||
|
@ -137,8 +137,8 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["depth"] = 19
|
||||
dit_config["depth_single_blocks"] = 38
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["qkv_bias"] = True
|
||||
|
Loading…
Reference in New Issue
Block a user