mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00

Make sure the pooled output stays at the EOS token with added embeddings.

parent 9534f0f8a5
commit c99d8002f8
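
Why this is needed: the CLIP text encoder that SD1ClipModel wraps takes its pooled output from the hidden state at the position of the largest token id in the sequence (input_ids.argmax(dim=-1) in transformers' CLIPTextTransformer), which normally lands on the EOS token, id 49407, the top of CLIP's stock vocabulary. Textual-inversion embeddings appended to the end of the embedding table receive ids above 49407, so the argmax can land on one of them instead of EOS and the pooled output is read from the wrong position. This commit keeps EOS as the last (largest) row of the table no matter how many embeddings are added.

A small illustration of the failure mode (not part of the commit; BOS/EOS ids are real CLIP values, the other ids are placeholders):

    import torch

    eos = 49407  # EOS is the largest id in CLIP's stock vocabulary
    ids = torch.tensor([[49406, 320, 1125, eos]])  # BOS, two word tokens, EOS
    print(ids.argmax(dim=-1))     # tensor([3]): EOS position, as intended

    # a textual-inversion token naively appended to the table gets id 49408 > EOS
    ids_ti = torch.tensor([[49406, 49408, 1125, eos]])
    print(ids_ti.argmax(dim=-1))  # tensor([1]): pooling now reads the custom token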
@@ -91,13 +91,15 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
     def set_up_textual_embeddings(self, tokens, current_embeds):
         out_tokens = []
-        next_new_token = token_dict_size = current_embeds.weight.shape[0]
+        next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
         embedding_weights = []
 
         for x in tokens:
             tokens_temp = []
             for y in x:
                 if isinstance(y, int):
+                    if y == token_dict_size: #EOS token
+                        y = -1
                     tokens_temp += [y]
                 else:
                     if y.shape[0] == current_embeds.weight.shape[1]:
@@ -110,15 +112,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                         tokens_temp += [self.empty_tokens[0][-1]]
             out_tokens += [tokens_temp]
 
+        n = token_dict_size
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
-            new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
-            n = token_dict_size
+            new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
+            new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
             for x in embedding_weights:
                 new_embedding.weight[n] = x
                 n += 1
+            new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
             self.transformer.set_input_embeddings(new_embedding)
-        return out_tokens
+
+        processed_tokens = []
+        for x in out_tokens:
+            processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
+
+        return processed_tokens
 
     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
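
Taken together, the diff reserves the final row of the embedding table for EOS: EOS ids are marked with a -1 placeholder while tokens are collected, custom rows are written in front of a re-appended EOS row, and the placeholder is then resolved to the new largest id. A minimal standalone sketch of the same remapping, using a toy vocabulary and hypothetical names rather than ComfyUI's actual API:

    import torch

    def rebuild_with_eos_last(current_embeds, custom_vectors, token_ids):
        # The last row of the existing table is treated as the EOS embedding.
        old_weight = current_embeds.weight.detach()
        eos_id = old_weight.shape[0] - 1

        # 1) Mark EOS occurrences with -1 while ids are still in flux.
        marked = [[-1 if t == eos_id else t for t in seq] for seq in token_ids]

        # 2) New table: old rows minus EOS, then custom rows, then EOS last again.
        rows = [old_weight[:-1]] + [v[None] for v in custom_vectors] + [old_weight[-1:]]
        new_embeds = torch.nn.Embedding.from_pretrained(torch.cat(rows, dim=0))

        # 3) EOS is the largest id again, so argmax-based pooling stays correct.
        new_eos = new_embeds.weight.shape[0] - 1
        return new_embeds, [[new_eos if t == -1 else t for t in seq] for seq in marked]

    # Toy usage: vocab of 5 (EOS = 4) plus one custom vector; EOS 4 becomes 5.
    emb = torch.nn.Embedding(5, 8)
    new_emb, toks = rebuild_with_eos_last(emb, [torch.randn(8)], [[0, 3, 4]])
    print(new_emb.weight.shape, toks)  # torch.Size([6, 8]) [[0, 3, 5]]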
@@ -267,7 +267,6 @@ export const ComfyWidgets = {
         return { widget: node.addWidget(widgetType, inputName, val, () => {}, config) };
     },
     INT(node, inputName, inputData, app) {
-        console.log(app);
         let widgetType = isSlider(inputData[1]["display"], app);
         const { val, config } = getNumberDefaults(inputData, 1);
         Object.assign(config, { precision: 0 });