Support llava clip vision model.

This commit is contained in:
comfyanonymous 2025-03-06 00:24:43 -05:00
parent 85ef295069
commit 0bef826a98
4 changed files with 61 additions and 3 deletions

View File

@ -211,6 +211,15 @@ class CLIPVision(torch.nn.Module):
pooled_output = self.post_layernorm(x[:, 0, :]) pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output return x, i, pooled_output
class LlavaProjector(torch.nn.Module):
def __init__(self, in_dim, out_dim, dtype, device, operations):
super().__init__()
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
class CLIPVisionModelProjection(torch.nn.Module): class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
@ -220,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
else: else:
self.visual_projection = lambda a: a self.visual_projection = lambda a: a
if "llava3" == config_dict.get("projector_type", None):
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
else:
self.multi_modal_projector = None
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs) x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2]) out = self.visual_projection(x[2])
return (x[0], x[1], out) projected = None
if self.multi_modal_projector is not None:
projected = self.multi_modal_projector(x[1])
return (x[0], x[1], out, projected)

View File

@ -65,6 +65,7 @@ class ClipVisionModel():
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["mm_projected"] = out[3]
return outputs return outputs
def convert_to_transformers(sd, prefix): def convert_to_transformers(sd, prefix):
@ -104,7 +105,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json") if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else: else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
else: else:

View File

@ -0,0 +1,19 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-5,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"projector_type": "llava3",
"torch_dtype": "float32"
}

View File

@ -196,8 +196,25 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
index = 0 index = 0
pad_extra = 0 pad_extra = 0
for o in other_embeds: for o in other_embeds:
emb = o[1]
if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb}
emb_type = emb.get("type", None)
if emb_type == "embedding":
emb = emb.get("data", None)
else:
if hasattr(self.transformer, "preprocess_embed"):
emb = self.transformer.preprocess_embed(emb, device=device)
else:
emb = None
if emb is None:
index += -1
continue
ind = index + o[0] ind = index + o[0]
emb = o[1].view(1, -1, o[1].shape[-1]).to(device=device, dtype=torch.float32) emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
emb_shape = emb.shape[1] emb_shape = emb.shape[1]
if emb.shape[-1] == tokens_embed.shape[-1]: if emb.shape[-1] == tokens_embed.shape[-1]:
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)