mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Fix potential issue with non clip text embeddings.
This commit is contained in:
parent
25853d0be8
commit
82cae45d44
@ -5,7 +5,7 @@
|
|||||||
"attention_dropout": 0.0,
|
"attention_dropout": 0.0,
|
||||||
"bos_token_id": 0,
|
"bos_token_id": 0,
|
||||||
"dropout": 0.0,
|
"dropout": 0.0,
|
||||||
"eos_token_id": 2,
|
"eos_token_id": 49407,
|
||||||
"hidden_act": "gelu",
|
"hidden_act": "gelu",
|
||||||
"hidden_size": 1280,
|
"hidden_size": 1280,
|
||||||
"initializer_factor": 1.0,
|
"initializer_factor": 1.0,
|
||||||
|
@ -87,6 +87,7 @@ class CLIPTextModel_(torch.nn.Module):
|
|||||||
heads = config_dict["num_attention_heads"]
|
heads = config_dict["num_attention_heads"]
|
||||||
intermediate_size = config_dict["intermediate_size"]
|
intermediate_size = config_dict["intermediate_size"]
|
||||||
intermediate_activation = config_dict["hidden_act"]
|
intermediate_activation = config_dict["hidden_act"]
|
||||||
|
self.eos_token_id = config_dict["eos_token_id"]
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
||||||
@ -111,7 +112,7 @@ class CLIPTextModel_(torch.nn.Module):
|
|||||||
if i is not None and final_layer_norm_intermediate:
|
if i is not None and final_layer_norm_intermediate:
|
||||||
i = self.final_layer_norm(i)
|
i = self.final_layer_norm(i)
|
||||||
|
|
||||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
|
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
|
||||||
return x, i, pooled_output
|
return x, i, pooled_output
|
||||||
|
|
||||||
class CLIPTextModel(torch.nn.Module):
|
class CLIPTextModel(torch.nn.Module):
|
||||||
|
@ -140,15 +140,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
|
|
||||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||||
out_tokens = []
|
out_tokens = []
|
||||||
next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
|
next_new_token = token_dict_size = current_embeds.weight.shape[0]
|
||||||
embedding_weights = []
|
embedding_weights = []
|
||||||
|
|
||||||
for x in tokens:
|
for x in tokens:
|
||||||
tokens_temp = []
|
tokens_temp = []
|
||||||
for y in x:
|
for y in x:
|
||||||
if isinstance(y, numbers.Integral):
|
if isinstance(y, numbers.Integral):
|
||||||
if y == token_dict_size: #EOS token
|
|
||||||
y = -1
|
|
||||||
tokens_temp += [int(y)]
|
tokens_temp += [int(y)]
|
||||||
else:
|
else:
|
||||||
if y.shape[0] == current_embeds.weight.shape[1]:
|
if y.shape[0] == current_embeds.weight.shape[1]:
|
||||||
@ -164,11 +162,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
n = token_dict_size
|
n = token_dict_size
|
||||||
if len(embedding_weights) > 0:
|
if len(embedding_weights) > 0:
|
||||||
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
|
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
|
||||||
new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
|
new_embedding.weight[:token_dict_size] = current_embeds.weight
|
||||||
for x in embedding_weights:
|
for x in embedding_weights:
|
||||||
new_embedding.weight[n] = x
|
new_embedding.weight[n] = x
|
||||||
n += 1
|
n += 1
|
||||||
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
|
|
||||||
self.transformer.set_input_embeddings(new_embedding)
|
self.transformer.set_input_embeddings(new_embedding)
|
||||||
|
|
||||||
processed_tokens = []
|
processed_tokens = []
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
"attention_dropout": 0.0,
|
"attention_dropout": 0.0,
|
||||||
"bos_token_id": 0,
|
"bos_token_id": 0,
|
||||||
"dropout": 0.0,
|
"dropout": 0.0,
|
||||||
"eos_token_id": 2,
|
"eos_token_id": 49407,
|
||||||
"hidden_act": "quick_gelu",
|
"hidden_act": "quick_gelu",
|
||||||
"hidden_size": 768,
|
"hidden_size": 768,
|
||||||
"initializer_factor": 1.0,
|
"initializer_factor": 1.0,
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
"attention_dropout": 0.0,
|
"attention_dropout": 0.0,
|
||||||
"bos_token_id": 0,
|
"bos_token_id": 0,
|
||||||
"dropout": 0.0,
|
"dropout": 0.0,
|
||||||
"eos_token_id": 2,
|
"eos_token_id": 49407,
|
||||||
"hidden_act": "gelu",
|
"hidden_act": "gelu",
|
||||||
"hidden_size": 1024,
|
"hidden_size": 1024,
|
||||||
"initializer_factor": 1.0,
|
"initializer_factor": 1.0,
|
||||||
|
Loading…
Reference in New Issue
Block a user