mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-16 16:43:36 +00:00
Mistoline flux controlnet support.
This commit is contained in:
parent
c7427375ee
commit
5cbaa9e07c
@ -430,9 +430,9 @@ def load_controlnet_hunyuandit(controlnet_data):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
def load_controlnet_flux_xlabs(sd):
|
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
@ -457,6 +457,10 @@ def load_controlnet_flux_instantx(sd):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
def convert_mistoline(sd):
|
||||||
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
def load_controlnet(ckpt_path, model=None):
|
||||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
||||||
@ -518,13 +522,15 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(leftover_keys) > 0:
|
if len(leftover_keys) > 0:
|
||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||||
return load_controlnet_flux_xlabs(controlnet_data)
|
return load_controlnet_flux_xlabs_mistoline(controlnet_data)
|
||||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||||
return load_controlnet_mmdit(controlnet_data)
|
return load_controlnet_mmdit(controlnet_data) #SD3 diffusers controlnet
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data)
|
return load_controlnet_flux_instantx(controlnet_data)
|
||||||
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True)
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||||
|
#modified to support different types of flux controlnets
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
@ -12,22 +13,65 @@ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|||||||
from .model import Flux
|
from .model import Flux
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
|
class MistolineCondDownsamplBlock(nn.Module):
|
||||||
|
def __init__(self, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
class MistolineControlnetBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.act(self.linear(x))
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
class ControlNetFlux(Flux):
|
||||||
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
|
||||||
self.main_model_double = 19
|
self.main_model_double = 19
|
||||||
self.main_model_single = 38
|
self.main_model_single = 38
|
||||||
|
|
||||||
|
self.mistoline = mistoline
|
||||||
# add ControlNet blocks
|
# add ControlNet blocks
|
||||||
|
if self.mistoline:
|
||||||
|
control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
for _ in range(self.params.depth):
|
for _ in range(self.params.depth):
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
self.controlnet_blocks.append(control_block())
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
|
|
||||||
self.controlnet_single_blocks = nn.ModuleList([])
|
self.controlnet_single_blocks = nn.ModuleList([])
|
||||||
for _ in range(self.params.depth_single_blocks):
|
for _ in range(self.params.depth_single_blocks):
|
||||||
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
self.controlnet_single_blocks.append(control_block())
|
||||||
|
|
||||||
self.num_union_modes = num_union_modes
|
self.num_union_modes = num_union_modes
|
||||||
self.controlnet_mode_embedder = None
|
self.controlnet_mode_embedder = None
|
||||||
@ -38,23 +82,26 @@ class ControlNetFlux(Flux):
|
|||||||
self.latent_input = latent_input
|
self.latent_input = latent_input
|
||||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
if not self.latent_input:
|
if not self.latent_input:
|
||||||
self.input_hint_block = nn.Sequential(
|
if self.mistoline:
|
||||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||||
nn.SiLU(),
|
else:
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
self.input_hint_block = nn.Sequential(
|
||||||
nn.SiLU(),
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
nn.SiLU(),
|
||||||
)
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
@ -73,9 +120,6 @@ class ControlNetFlux(Flux):
|
|||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
if not self.latent_input:
|
|
||||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
|
||||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
|
|
||||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||||
img = img + controlnet_cond
|
img = img + controlnet_cond
|
||||||
@ -131,9 +175,14 @@ class ControlNetFlux(Flux):
|
|||||||
patch_size = 2
|
patch_size = 2
|
||||||
if self.latent_input:
|
if self.latent_input:
|
||||||
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
elif self.mistoline:
|
||||||
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_cond_block(hint)
|
||||||
else:
|
else:
|
||||||
hint = hint * 2.0 - 1.0
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_hint_block(hint)
|
||||||
|
|
||||||
|
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
Loading…
Reference in New Issue
Block a user