Make sure the pooled output stays at the EOS token with added embeddings.

This commit is contained in:
comfyanonymous 2023-08-03 20:27:50 -04:00
parent 9534f0f8a5
commit c99d8002f8
2 changed files with 13 additions and 6 deletions

View File

@@ -91,13 +91,15 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def set_up_textual_embeddings(self, tokens, current_embeds):
         out_tokens = []
-        next_new_token = token_dict_size = current_embeds.weight.shape[0]
+        next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
        embedding_weights = []
         for x in tokens:
             tokens_temp = []
             for y in x:
                 if isinstance(y, int):
+                    if y == token_dict_size: #EOS token
+                        y = -1
                     tokens_temp += [y]
                 else:
                     if y.shape[0] == current_embeds.weight.shape[1]:
@@ -110,15 +112,21 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                     tokens_temp += [self.empty_tokens[0][-1]]
             out_tokens += [tokens_temp]
+        n = token_dict_size
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
+            new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
-            new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
+            new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
-            n = token_dict_size
             for x in embedding_weights:
                 new_embedding.weight[n] = x
                 n += 1
+            new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
             self.transformer.set_input_embeddings(new_embedding)
-        return out_tokens
+        processed_tokens = []
+        for x in out_tokens:
+            processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
+        return processed_tokens

     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()

View File

@@ -267,7 +267,6 @@ export const ComfyWidgets = {
         return { widget: node.addWidget(widgetType, inputName, val, () => {}, config) };
     },
     INT(node, inputName, inputData, app) {
-        console.log(app);
         let widgetType = isSlider(inputData[1]["display"], app);
         const { val, config } = getNumberDefaults(inputData, 1);
         Object.assign(config, { precision: 0 });