Make the text projection saved in the checkpoint the right format.

This commit is contained in:
comfyanonymous 2024-02-27 01:52:23 -05:00
parent 03c47fc0f2
commit b416be7d78
2 changed files with 5 additions and 1 deletions

View File

@ -237,6 +237,10 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
capture_qkv_bias[k_pre][code2idx[k_code]] = v capture_qkv_bias[k_pre][code2idx[k_code]] = v
continue continue
text_proj = "transformer.text_projection.weight"
if k.endswith(text_proj):
new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
new_state_dict[relabelled_key] = v new_state_dict[relabelled_key] = v

View File

@ -110,7 +110,7 @@ def clip_text_transformers_convert(sd, prefix_from, prefix_to):
tp = "{}text_projection".format(prefix_from) tp = "{}text_projection".format(prefix_from)
if tp in sd: if tp in sd:
sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1) sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
return sd return sd