Support new flux model variants.

This commit is contained in:
comfyanonymous 2024-11-21 08:38:23 -05:00
parent 41444b5236
commit 8f0009aad0
9 changed files with 147 additions and 19 deletions

View File

@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu, "gelu": torch.nn.functional.gelu,
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
} }
class CLIPMLP(torch.nn.Module): class CLIPMLP(torch.nn.Module):
@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
class CLIPVisionEmbeddings(torch.nn.Module): class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
super().__init__() super().__init__()
num_patches = (image_size // patch_size) ** 2
if model_type == "siglip_vision_model":
self.class_embedding = None
patch_bias = True
else:
num_patches = num_patches + 1
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device)) self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
patch_bias = False
self.patch_embedding = operations.Conv2d( self.patch_embedding = operations.Conv2d(
in_channels=num_channels, in_channels=num_channels,
out_channels=embed_dim, out_channels=embed_dim,
kernel_size=patch_size, kernel_size=patch_size,
stride=patch_size, stride=patch_size,
bias=False, bias=patch_bias,
dtype=dtype, dtype=dtype,
device=device device=device
) )
num_patches = (image_size // patch_size) ** 2 self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
num_positions = num_patches + 1
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values): def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds) if self.class_embedding is not None:
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
class CLIPVision(torch.nn.Module): class CLIPVision(torch.nn.Module):
@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
heads = config_dict["num_attention_heads"] heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"] intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"] intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations) self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:
self.pre_layrnorm = operations.LayerNorm(embed_dim) self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.output_layernorm = False
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim) self.post_layernorm = operations.LayerNorm(embed_dim)
@ -181,6 +196,10 @@ class CLIPVision(torch.nn.Module):
x = self.pre_layrnorm(x) x = self.pre_layrnorm(x)
#TODO: attention_mask? #TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output) x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
if self.output_layernorm:
x = self.post_layernorm(x)
pooled_output = x
else:
pooled_output = self.post_layernorm(x[:, 0, :]) pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output return x, i, pooled_output
@ -188,7 +207,10 @@ class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations) self.vision_model = CLIPVision(config_dict, dtype, device, operations)
if "projection_dim" in config_dict:
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False) self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
else:
self.visual_projection = lambda a: a
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs) x = self.vision_model(*args, **kwargs)

View File

@ -16,9 +16,9 @@ class Output:
def __setitem__(self, key, item): def __setitem__(self, key, item):
setattr(self, key, item) setattr(self, key, item)
def clip_preprocess(image, size=224): def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1) image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size): if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3])) scale = (size / min(image.shape[2], image.shape[3]))
@ -35,6 +35,8 @@ class ClipVisionModel():
config = json.load(f) config = json.load(f)
self.image_size = config.get("image_size", 224) self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
self.load_device = comfy.model_management.text_encoder_device() self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device() offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@ -51,7 +53,7 @@ class ClipVisionModel():
def encode_image(self, image): def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float() pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2) out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output() outputs = Output()
@ -94,7 +96,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else: else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")

View File

@ -0,0 +1,13 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@ -20,6 +20,7 @@ import comfy.ldm.common_dit
@dataclass @dataclass
class FluxParams: class FluxParams:
in_channels: int in_channels: int
out_channels: int
vec_in_dim: int vec_in_dim: int
context_in_dim: int context_in_dim: int
hidden_size: int hidden_size: int
@ -29,6 +30,7 @@ class FluxParams:
depth_single_blocks: int depth_single_blocks: int
axes_dim: list axes_dim: list
theta: int theta: int
patch_size: int
qkv_bias: bool qkv_bias: bool
guidance_embed: bool guidance_embed: bool
@ -43,8 +45,9 @@ class Flux(nn.Module):
self.dtype = dtype self.dtype = dtype
params = FluxParams(**kwargs) params = FluxParams(**kwargs)
self.params = params self.params = params
self.in_channels = params.in_channels * 2 * 2 self.patch_size = params.patch_size
self.out_channels = self.in_channels self.in_channels = params.in_channels * params.patch_size * params.patch_size
self.out_channels = params.out_channels * params.patch_size * params.patch_size
if params.hidden_size % params.num_heads != 0: if params.hidden_size % params.num_heads != 0:
raise ValueError( raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@ -165,7 +168,7 @@ class Flux(nn.Module):
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs): def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape bs, c, h, w = x.shape
patch_size = 2 patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

25
comfy/ldm/flux/redux.py Normal file
View File

@ -0,0 +1,25 @@
import torch
import comfy.ops
ops = comfy.ops.manual_cast
class ReduxImageEncoder(torch.nn.Module):
def __init__(
self,
redux_dim: int = 1152,
txt_in_features: int = 4096,
device=None,
dtype=None,
) -> None:
super().__init__()
self.redux_dim = redux_dim
self.device = device
self.dtype = dtype
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
def forward(self, sigclip_embeds) -> torch.Tensor:
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
return projected_x

17
comfy/lora_convert.py Normal file
View File

@ -0,0 +1,17 @@
import torch
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
return sd_out
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
return sd

View File

@ -710,6 +710,38 @@ class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None): def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
def concat_cond(self, **kwargs):
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
out_channels = self.model_config.unet_config["out_channels"]
if num_channels <= out_channels:
return None
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
image = self.process_latent_in(image)
if num_channels <= out_channels * 2:
return image
#inpaint model
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.ones_like(noise)[:, :1]
mask = torch.mean(mask, dim=1, keepdim=True)
print(mask.shape)
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
return kwargs["pooled_output"] return kwargs["pooled_output"]

View File

@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
dit_config = {} dit_config = {}
dit_config["image_model"] = "flux" dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16 dit_config["in_channels"] = 16
patch_size = 2
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768 dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096 dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072 dit_config["hidden_size"] = 3072

View File

@ -30,9 +30,12 @@ import comfy.text_encoders.genmo
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
import comfy.lora_convert
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
import comfy.taesd.taesd import comfy.taesd.taesd
import comfy.ldm.flux.redux
def load_lora_for_models(model, clip, lora, strength_model, strength_clip): def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = {} key_map = {}
if model is not None: if model is not None:
@ -40,6 +43,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if clip is not None: if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
lora = comfy.lora_convert.convert_lora(lora)
loaded = comfy.lora.load_lora(lora, key_map) loaded = comfy.lora.load_lora(lora, key_map)
if model is not None: if model is not None:
new_modelpatcher = model.clone() new_modelpatcher = model.clone()
@ -433,6 +437,8 @@ def load_style_model(ckpt_path):
keys = model_data.keys() keys = model_data.keys()
if "style_embedding" in keys: if "style_embedding" in keys:
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8) model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
elif "redux_down.weight" in keys:
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else: else:
raise Exception("invalid style model {}".format(ckpt_path)) raise Exception("invalid style model {}".format(ckpt_path))
model.load_state_dict(model_data) model.load_state_dict(model_data)