mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Support official SD3.5 Controlnets.
This commit is contained in:
parent
15c39ea757
commit
4c82741b54
122
comfy/cldm/dit_embedder.py
Normal file
122
comfy/cldm/dit_embedder.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import math
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetEmbedder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
img_size: int,
|
||||||
|
patch_size: int,
|
||||||
|
in_chans: int,
|
||||||
|
attention_head_dim: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
adm_in_channels: int,
|
||||||
|
num_layers: int,
|
||||||
|
main_model_double: int,
|
||||||
|
double_y_emb: bool,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
pos_embed_max_size: Optional[int] = None,
|
||||||
|
operations = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.main_model_double = main_model_double
|
||||||
|
self.dtype = dtype
|
||||||
|
self.hidden_size = num_attention_heads * attention_head_dim
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.x_embedder = PatchEmbed(
|
||||||
|
img_size=img_size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
in_chans=in_chans,
|
||||||
|
embed_dim=self.hidden_size,
|
||||||
|
strict_img_size=pos_embed_max_size is None,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.double_y_emb = double_y_emb
|
||||||
|
if self.double_y_emb:
|
||||||
|
self.orig_y_embedder = VectorEmbedder(
|
||||||
|
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
||||||
|
)
|
||||||
|
self.y_embedder = VectorEmbedder(
|
||||||
|
self.hidden_size, self.hidden_size, dtype, device, operations=operations
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.y_embedder = VectorEmbedder(
|
||||||
|
adm_in_channels, self.hidden_size, dtype, device, operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
DismantledBlock(
|
||||||
|
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
|
||||||
|
dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
|
||||||
|
# TODO double check this logic when 8b
|
||||||
|
self.use_y_embedder = True
|
||||||
|
|
||||||
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
|
for _ in range(len(self.transformer_blocks)):
|
||||||
|
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
self.controlnet_blocks.append(controlnet_block)
|
||||||
|
|
||||||
|
self.pos_embed_input = PatchEmbed(
|
||||||
|
img_size=img_size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
in_chans=in_chans,
|
||||||
|
embed_dim=self.hidden_size,
|
||||||
|
strict_img_size=False,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
hint = None,
|
||||||
|
) -> Tuple[Tensor, List[Tensor]]:
|
||||||
|
x_shape = list(x.shape)
|
||||||
|
x = self.x_embedder(x)
|
||||||
|
if not self.double_y_emb:
|
||||||
|
h = (x_shape[-2] + 1) // self.patch_size
|
||||||
|
w = (x_shape[-1] + 1) // self.patch_size
|
||||||
|
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
|
||||||
|
c = self.t_embedder(timesteps, dtype=x.dtype)
|
||||||
|
if y is not None and self.y_embedder is not None:
|
||||||
|
if self.double_y_emb:
|
||||||
|
y = self.orig_y_embedder(y)
|
||||||
|
y = self.y_embedder(y)
|
||||||
|
c = c + y
|
||||||
|
|
||||||
|
x = x + self.pos_embed_input(hint)
|
||||||
|
|
||||||
|
block_out = ()
|
||||||
|
|
||||||
|
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
|
||||||
|
for i in range(len(self.transformer_blocks)):
|
||||||
|
out = self.transformer_blocks[i](x, c)
|
||||||
|
if not self.double_y_emb:
|
||||||
|
x = out
|
||||||
|
block_out += (self.controlnet_blocks[i](out),) * repeat
|
||||||
|
|
||||||
|
return {"output": block_out}
|
@ -35,7 +35,7 @@ import comfy.ldm.cascade.controlnet
|
|||||||
import comfy.cldm.mmdit
|
import comfy.cldm.mmdit
|
||||||
import comfy.ldm.hydit.controlnet
|
import comfy.ldm.hydit.controlnet
|
||||||
import comfy.ldm.flux.controlnet
|
import comfy.ldm.flux.controlnet
|
||||||
|
import comfy.cldm.dit_embedder
|
||||||
|
|
||||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
current_batch_size = tensor.shape[0]
|
current_batch_size = tensor.shape[0]
|
||||||
@ -78,6 +78,7 @@ class ControlBase:
|
|||||||
self.concat_mask = False
|
self.concat_mask = False
|
||||||
self.extra_concat_orig = []
|
self.extra_concat_orig = []
|
||||||
self.extra_concat = None
|
self.extra_concat = None
|
||||||
|
self.preprocess_image = lambda a: a
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
@ -129,6 +130,7 @@ class ControlBase:
|
|||||||
c.strength_type = self.strength_type
|
c.strength_type = self.strength_type
|
||||||
c.concat_mask = self.concat_mask
|
c.concat_mask = self.concat_mask
|
||||||
c.extra_concat_orig = self.extra_concat_orig.copy()
|
c.extra_concat_orig = self.extra_concat_orig.copy()
|
||||||
|
c.preprocess_image = self.preprocess_image
|
||||||
|
|
||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
@ -181,7 +183,7 @@ class ControlBase:
|
|||||||
|
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
@ -196,6 +198,7 @@ class ControlNet(ControlBase):
|
|||||||
self.extra_conds += extra_conds
|
self.extra_conds += extra_conds
|
||||||
self.strength_type = strength_type
|
self.strength_type = strength_type
|
||||||
self.concat_mask = concat_mask
|
self.concat_mask = concat_mask
|
||||||
|
self.preprocess_image = preprocess_image
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@ -224,6 +227,7 @@ class ControlNet(ControlBase):
|
|||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||||
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
||||||
|
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
||||||
@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
|
|||||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
return control_model
|
return control_model
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_mmdit(sd, model_options={}):
|
def load_controlnet_mmdit(sd, model_options={}):
|
||||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
|
||||||
@ -448,6 +453,83 @@ def load_controlnet_mmdit(sd, model_options={}):
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetSD35(ControlNet):
|
||||||
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
|
if self.control_model.double_y_emb:
|
||||||
|
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
|
||||||
|
else:
|
||||||
|
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
|
||||||
|
super().pre_run(model, percent_to_timestep_function)
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
|
c.control_model = self.control_model
|
||||||
|
c.control_model_wrapped = self.control_model_wrapped
|
||||||
|
self.copy_to(c)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def load_controlnet_sd35(sd, model_options={}):
|
||||||
|
control_type = -1
|
||||||
|
if "control_type" in sd:
|
||||||
|
control_type = round(sd.pop("control_type").item())
|
||||||
|
|
||||||
|
# blur_cnet = control_type == 0
|
||||||
|
canny_cnet = control_type == 1
|
||||||
|
depth_cnet = control_type == 2
|
||||||
|
|
||||||
|
print(control_type, canny_cnet, depth_cnet)
|
||||||
|
new_sd = {}
|
||||||
|
for k in comfy.utils.MMDIT_MAP_BASIC:
|
||||||
|
if k[1] in sd:
|
||||||
|
new_sd[k[0]] = sd.pop(k[1])
|
||||||
|
for k in sd:
|
||||||
|
new_sd[k] = sd[k]
|
||||||
|
sd = new_sd
|
||||||
|
|
||||||
|
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
|
||||||
|
depth = y_emb_shape[0] // 64
|
||||||
|
hidden_size = 64 * depth
|
||||||
|
num_heads = depth
|
||||||
|
head_dim = hidden_size // num_heads
|
||||||
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
|
||||||
|
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
offload_device = comfy.model_management.unet_offload_device()
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
|
||||||
|
|
||||||
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
|
|
||||||
|
operations = model_options.get("custom_operations", None)
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||||
|
|
||||||
|
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
|
||||||
|
patch_size=2,
|
||||||
|
in_chans=16,
|
||||||
|
num_layers=num_blocks,
|
||||||
|
main_model_double=depth,
|
||||||
|
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
|
||||||
|
attention_head_dim=head_dim,
|
||||||
|
num_attention_heads=num_heads,
|
||||||
|
adm_in_channels=2048,
|
||||||
|
device=offload_device,
|
||||||
|
dtype=unet_dtype,
|
||||||
|
operations=operations)
|
||||||
|
|
||||||
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.SD3()
|
||||||
|
preprocess_image = lambda a: a
|
||||||
|
if canny_cnet:
|
||||||
|
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
|
||||||
|
elif depth_cnet:
|
||||||
|
preprocess_image = lambda a: 1.0 - a
|
||||||
|
|
||||||
|
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
@ -560,7 +642,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
|||||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||||
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
|
||||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
|
||||||
|
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
|
||||||
|
else:
|
||||||
|
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
|
Loading…
Reference in New Issue
Block a user