mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Controlnet code refactor.
This commit is contained in:
parent
1c08bf35b4
commit
c19dcd362f
@ -191,13 +191,16 @@ class ControlNet(ControlBase):
|
|||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
|
||||||
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
||||||
y = cond.get('y', None)
|
extra = self.extra_args.copy()
|
||||||
if y is not None:
|
for c in ["y", "guidance"]: #TODO
|
||||||
y = y.to(dtype)
|
temp = cond.get(c, None)
|
||||||
|
if temp is not None:
|
||||||
|
extra[c] = temp.to(dtype)
|
||||||
|
|
||||||
timestep = self.model_sampling_current.timestep(t)
|
timestep = self.model_sampling_current.timestep(t)
|
||||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||||
|
|
||||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
||||||
return self.control_merge(control, control_prev, output_dtype)
|
return self.control_merge(control, control_prev, output_dtype)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
@ -338,12 +341,8 @@ class ControlLora(ControlNet):
|
|||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||||
|
|
||||||
def load_controlnet_mmdit(sd):
|
def controlnet_config(sd):
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
||||||
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
|
|
||||||
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
|
||||||
for k in sd:
|
|
||||||
new_sd[k] = sd[k]
|
|
||||||
|
|
||||||
supported_inference_dtypes = model_config.supported_inference_dtypes
|
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||||
|
|
||||||
@ -356,14 +355,27 @@ def load_controlnet_mmdit(sd):
|
|||||||
else:
|
else:
|
||||||
operations = comfy.ops.disable_weight_init
|
operations = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
||||||
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
|
|
||||||
|
def controlnet_load_state_dict(control_model, sd):
|
||||||
|
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
if len(missing) > 0:
|
if len(missing) > 0:
|
||||||
logging.warning("missing controlnet keys: {}".format(missing))
|
logging.warning("missing controlnet keys: {}".format(missing))
|
||||||
|
|
||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
return control_model
|
||||||
|
|
||||||
|
def load_controlnet_mmdit(sd):
|
||||||
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
||||||
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||||
|
for k in sd:
|
||||||
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = comfy.latent_formats.SD3()
|
latent_format = comfy.latent_formats.SD3()
|
||||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||||
|
@ -137,8 +137,8 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
dit_config["hidden_size"] = 3072
|
dit_config["hidden_size"] = 3072
|
||||||
dit_config["mlp_ratio"] = 4.0
|
dit_config["mlp_ratio"] = 4.0
|
||||||
dit_config["num_heads"] = 24
|
dit_config["num_heads"] = 24
|
||||||
dit_config["depth"] = 19
|
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||||
dit_config["depth_single_blocks"] = 38
|
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||||
dit_config["axes_dim"] = [16, 56, 56]
|
dit_config["axes_dim"] = [16, 56, 56]
|
||||||
dit_config["theta"] = 10000
|
dit_config["theta"] = 10000
|
||||||
dit_config["qkv_bias"] = True
|
dit_config["qkv_bias"] = True
|
||||||
|
Loading…
Reference in New Issue
Block a user