Fix a few things in text enc code for models with no eos token.

comfyanonymous 2024-12-10 23:07:26 -05:00
parent 1c8d11e48a
commit 44db978531

@@ -199,11 +199,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         attention_mask = None
         if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
             attention_mask = torch.zeros_like(tokens)
-            end_token = self.special_tokens.get("end", -1)
+            end_token = self.special_tokens.get("end", None)
+            if end_token is None:
+                cmp_token = self.special_tokens.get("pad", -1)
+            else:
+                cmp_token = end_token
+
             for x in range(attention_mask.shape[0]):
                 for y in range(attention_mask.shape[1]):
                     attention_mask[x, y] = 1
-                    if tokens[x, y] == end_token:
+                    if tokens[x, y] == cmp_token:
+                        if end_token is None:
+                            attention_mask[x, y] = 0
                         break

         attention_mask_model = None
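
For a rough sense of what the new mask logic does when a text encoder has no "end" token, here is a minimal standalone sketch (the special_tokens dict and token ids are made up for illustration): the comparison token falls back to the "pad" token, and the first matched pad position is zeroed out instead of kept in the mask.

    import torch

    # Hypothetical special-token table for a model with no eos/"end" token.
    special_tokens = {"pad": 0}
    tokens = torch.tensor([[5, 7, 9, 0, 0]])

    end_token = special_tokens.get("end", None)
    if end_token is None:
        cmp_token = special_tokens.get("pad", -1)
    else:
        cmp_token = end_token

    attention_mask = torch.zeros_like(tokens)
    for x in range(attention_mask.shape[0]):
        for y in range(attention_mask.shape[1]):
            attention_mask[x, y] = 1
            if tokens[x, y] == cmp_token:
                if end_token is None:
                    # A pad token is not a real end token, so don't attend to it.
                    attention_mask[x, y] = 0
                break

    print(attention_mask)  # expected: tensor([[1, 1, 1, 0, 0]])

With an actual end token, the previous behavior is preserved: the end position itself stays attended (mask 1) before the loop breaks.
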
@@ -522,10 +529,14 @@ class SDTokenizer:
         for i, t_group in enumerate(tokens):
             #determine if we're going to try and keep the tokens in a single batch
             is_large = len(t_group) >= self.max_word_length
+            if self.end_token is not None:
+                has_end_token = 1
+            else:
+                has_end_token = 0

             while len(t_group) > 0:
-                if len(t_group) + len(batch) > self.max_length - 1:
-                    remaining_length = self.max_length - len(batch) - 1
+                if len(t_group) + len(batch) > self.max_length - has_end_token:
+                    remaining_length = self.max_length - len(batch) - has_end_token
                     #break word in two and add end token
                     if is_large:
                         batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
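
And a similarly hedged sketch of the tokenizer change (the lengths below are chosen only for illustration, not taken from any real tokenizer): when there is no end token, no slot needs to be reserved for it, so one more token fits in a batch before a word gets split.

    max_length = 77
    end_token = None                     # a tokenizer without an eos token
    has_end_token = 1 if end_token is not None else 0

    batch = [("tok", 1.0, 0)] * 70       # tokens already placed in the current batch
    t_group = [("tok", 1.0)] * 10        # the next word's tokens

    if len(t_group) + len(batch) > max_length - has_end_token:
        remaining_length = max_length - len(batch) - has_end_token
        print(remaining_length)  # 7 free slots here; the old hard-coded "- 1" would give 6
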