Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-11 02:15:17 +00:00
Fix a few things in text enc code for models with no eos token.
commit 44db978531
parent 1c8d11e48a
@@ -199,11 +199,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         attention_mask = None
         if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
             attention_mask = torch.zeros_like(tokens)
-            end_token = self.special_tokens.get("end", -1)
+            end_token = self.special_tokens.get("end", None)
+            if end_token is None:
+                cmp_token = self.special_tokens.get("pad", -1)
+            else:
+                cmp_token = end_token
+
             for x in range(attention_mask.shape[0]):
                 for y in range(attention_mask.shape[1]):
                     attention_mask[x, y] = 1
-                    if tokens[x, y] == end_token:
+                    if tokens[x, y] == cmp_token:
+                        if end_token is None:
+                            attention_mask[x, y] = 0
                         break

         attention_mask_model = None
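
In short: when the tokenizer has no EOS ("end") token, the masking loop now compares against the pad token instead, and masks that matched position out rather than keeping it. The standalone sketch below illustrates the behaviour under that reading; build_attention_mask, the token ids and the special_tokens dicts are made-up examples, not part of the commit.

# A minimal sketch (not the actual SDClipModel code path); the helper name
# and the "end"/"pad" ids below are illustrative assumptions.
import torch

def build_attention_mask(tokens, special_tokens):
    attention_mask = torch.zeros_like(tokens)
    end_token = special_tokens.get("end", None)
    if end_token is None:
        # No EOS token: fall back to comparing against the pad token.
        cmp_token = special_tokens.get("pad", -1)
    else:
        cmp_token = end_token

    for x in range(attention_mask.shape[0]):
        for y in range(attention_mask.shape[1]):
            attention_mask[x, y] = 1
            if tokens[x, y] == cmp_token:
                if end_token is None:
                    # A pad token carries no content, so its position is
                    # masked out; a real EOS token would stay unmasked.
                    attention_mask[x, y] = 0
                break
    return attention_mask

# With an EOS token (example id 2): the EOS position stays in the mask.
print(build_attention_mask(torch.tensor([[5, 6, 2, 0, 0]]), {"end": 2, "pad": 0}))
# tensor([[1, 1, 1, 0, 0]])

# Without an EOS token: the first pad position is excluded as well.
print(build_attention_mask(torch.tensor([[5, 6, 0, 0, 0]]), {"pad": 0}))
# tensor([[1, 1, 0, 0, 0]])
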
@@ -522,10 +529,14 @@ class SDTokenizer:
         for i, t_group in enumerate(tokens):
             #determine if we're going to try and keep the tokens in a single batch
             is_large = len(t_group) >= self.max_word_length
+            if self.end_token is not None:
+                has_end_token = 1
+            else:
+                has_end_token = 0

             while len(t_group) > 0:
-                if len(t_group) + len(batch) > self.max_length - 1:
-                    remaining_length = self.max_length - len(batch) - 1
+                if len(t_group) + len(batch) > self.max_length - has_end_token:
+                    remaining_length = self.max_length - len(batch) - has_end_token
                     #break word in two and add end token
                     if is_large:
                         batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
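
The tokenizer-side change mirrors this: the one slot previously reserved unconditionally for the end token in each batch is now reserved only when the tokenizer actually has an end token. A minimal sketch of the budget arithmetic follows; remaining_budget is a hypothetical helper (not part of SDTokenizer), and the numbers are just example values.

# Hypothetical helper illustrating the new length budget; max_length=77,
# batch_len=70 and the end token id 49407 are only example values.
def remaining_budget(batch_len, max_length, end_token):
    # Reserve a slot for the end token only if the tokenizer has one;
    # before this commit the "- 1" reservation was unconditional.
    has_end_token = 1 if end_token is not None else 0
    return max_length - batch_len - has_end_token

print(remaining_budget(batch_len=70, max_length=77, end_token=49407))  # -> 6
print(remaining_budget(batch_len=70, max_length=77, end_token=None))   # -> 7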