Support llava clip vision model.

Repository: ComfyUI (mirror of https://github.com/comfyanonymous/ComfyUI.git)
Commit: 0bef826a98 (parent: 85ef295069)

comfy/clip_model.py

@@ -211,6 +211,15 @@ class CLIPVision(torch.nn.Module):
         pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output
 
+class LlavaProjector(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, dtype, device, operations):
+        super().__init__()
+        self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
+        self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
+
 class CLIPVisionModelProjection(torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
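For illustration: the projector drops the CLS token and maps the remaining 576 patch tokens from the CLIP hidden size (1024) to LLaVA's language-model embedding size (4096) through two linear layers with a GELU in between. A minimal shape-flow sketch, using plain torch.nn.Linear with random weights in place of ComfyUI's operations wrapper (illustrative only, not the commit's code path):

import torch

linear_1 = torch.nn.Linear(1024, 4096, bias=True)   # stands in for operations.Linear
linear_2 = torch.nn.Linear(4096, 4096, bias=True)

hidden = torch.randn(1, 577, 1024)     # 1 CLS token + (336 / 14) ** 2 = 576 patch tokens
patches = hidden[:, 1:]                # drop the CLS token, as in LlavaProjector.forward
projected = linear_2(torch.nn.functional.gelu(linear_1(patches)))
print(projected.shape)                 # torch.Size([1, 576, 4096])
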
comfy/clip_model.py

@@ -220,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
         else:
             self.visual_projection = lambda a: a
 
+        if "llava3" == config_dict.get("projector_type", None):
+            self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
+        else:
+            self.multi_modal_projector = None
+
     def forward(self, *args, **kwargs):
         x = self.vision_model(*args, **kwargs)
         out = self.visual_projection(x[2])
-        return (x[0], x[1], out)
+        projected = None
+        if self.multi_modal_projector is not None:
+            projected = self.multi_modal_projector(x[1])
+
+        return (x[0], x[1], out, projected)
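Note that forward now returns a 4-tuple: the first three slots keep their old meaning, so existing callers are untouched, and the new fourth slot is None unless the config asked for the "llava3" projector. A hedged toy sketch of that contract (shapes and names are illustrative, not the real module):

import torch

def toy_forward(has_llava_projector: bool):
    last_hidden = torch.zeros(1, 577, 1024)       # out[0], unchanged
    penultimate = torch.zeros(1, 577, 1024)       # out[1], unchanged; also feeds the projector
    image_embeds = torch.zeros(1, 768)            # out[2], unchanged pooled projection
    projected = torch.zeros(1, 576, 4096) if has_llava_projector else None
    return (last_hidden, penultimate, image_embeds, projected)

assert toy_forward(False)[3] is None              # non-LLaVA checkpoints are unaffected
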
comfy/clip_vision.py

@@ -65,6 +65,7 @@ class ClipVisionModel():
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
         outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+        outputs["mm_projected"] = out[3]
         return outputs
 
 def convert_to_transformers(sd, prefix):
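Downstream, the projected tokens ride along in the ClipVisionModel outputs next to the existing fields. A hedged usage sketch; the checkpoint path is hypothetical and the shapes assume a LLaVA-style ViT-L/14-336 tower:

import torch
import comfy.clip_vision

clip_vision = comfy.clip_vision.load("models/clip_vision/llava_clip_vit_l_336.safetensors")  # hypothetical path
image = torch.rand(1, 336, 336, 3)                  # ComfyUI-style image tensor in [0, 1]
out = clip_vision.encode_image(image)
penultimate = out["penultimate_hidden_states"]      # [1, 577, 1024], as before
llava_tokens = out["mm_projected"]                  # new: [1, 576, 4096], or None for non-LLaVA models
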
comfy/clip_vision.py

@@ -104,7 +105,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
         if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
         elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
-            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
+            if "multi_modal_projector.linear_1.bias" in sd:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
+            else:
+                json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
     else:
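The loader decides purely from the state dict: 577 position embeddings identify a ViT-L/14 tower at 336px, and the presence of the projector bias distinguishes a LLaVA export from plain CLIP. A hedged sketch of that check in isolation; the path is hypothetical and the keys are assumed to already be prefix-stripped, as they are inside load_clipvision_from_sd:

import comfy.utils

sd = comfy.utils.load_torch_file("models/clip_vision/llava_clip_vit_l_336.safetensors")  # hypothetical path
is_vitl_336 = sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577
wants_llava_config = is_vitl_336 and "multi_modal_projector.linear_1.bias" in sd
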
comfy/clip_vision_config_vitl_336_llava.json (new file, 19 lines)

@@ -0,0 +1,19 @@
+{
+    "attention_dropout": 0.0,
+    "dropout": 0.0,
+    "hidden_act": "quick_gelu",
+    "hidden_size": 1024,
+    "image_size": 336,
+    "initializer_factor": 1.0,
+    "initializer_range": 0.02,
+    "intermediate_size": 4096,
+    "layer_norm_eps": 1e-5,
+    "model_type": "clip_vision_model",
+    "num_attention_heads": 16,
+    "num_channels": 3,
+    "num_hidden_layers": 24,
+    "patch_size": 14,
+    "projection_dim": 768,
+    "projector_type": "llava3",
+    "torch_dtype": "float32"
+}
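The "projector_type": "llava3" entry is what CLIPVisionModelProjection keys on when deciding whether to build a LlavaProjector; the remaining fields describe the usual CLIP ViT-L/14 at 336px. A minimal sketch of reading the config back (path relative to the repository root):

import json

with open("comfy/clip_vision_config_vitl_336_llava.json") as f:
    config_dict = json.load(f)

needs_projector = config_dict.get("projector_type", None) == "llava3"   # True: mirrors the __init__ check
in_dim = config_dict["hidden_size"]                                      # 1024, fed to LlavaProjector(1024, 4096, ...)
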
comfy/sd1_clip.py

@@ -196,8 +196,25 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         index = 0
         pad_extra = 0
         for o in other_embeds:
+            emb = o[1]
+            if torch.is_tensor(emb):
+                emb = {"type": "embedding", "data": emb}
+
+            emb_type = emb.get("type", None)
+            if emb_type == "embedding":
+                emb = emb.get("data", None)
+            else:
+                if hasattr(self.transformer, "preprocess_embed"):
+                    emb = self.transformer.preprocess_embed(emb, device=device)
+                else:
+                    emb = None
+
+            if emb is None:
+                index += -1
+                continue
+
             ind = index + o[0]
-            emb = o[1].view(1, -1, o[1].shape[-1]).to(device=device, dtype=torch.float32)
+            emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
             emb_shape = emb.shape[1]
             if emb.shape[-1] == tokens_embed.shape[-1]:
                 tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
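With this change an entry in other_embeds may be either a raw tensor, which is wrapped as an "embedding" dict and spliced in directly, or a dict whose type is handed to the text encoder's preprocess_embed hook when that hook exists; anything the hook cannot turn into an embedding is skipped and the insertion index is adjusted. A hedged sketch of the two accepted forms; the "image" type, shapes, and token positions are illustrative only:

import torch

# Form 1: a raw tensor is wrapped as {"type": "embedding", "data": emb} and spliced in as-is.
raw_embed = torch.randn(576, 4096)

# Form 2: any other dict is passed to transformer.preprocess_embed(emb, device=device) when the
# hook exists; otherwise emb stays None and the entry is skipped (index += -1; continue).
dict_embed = {"type": "image", "data": torch.randn(576, 4096)}

other_embeds = [(5, raw_embed), (9, dict_embed)]    # (token position, embed) pairs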