Round CLIP position ids to fix float issues in some checkpoints.

This commit is contained in:
comfyanonymous 2023-01-28 00:19:33 -05:00
parent e615d40ca1
commit 2973ff24c5

View File

@ -32,6 +32,9 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
sd[y] = sd.pop(x)
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd:
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = sd['cond_stage_model.transformer.text_model.embeddings.position_ids'].round()
for x in load_state_dict_to:
x.load_state_dict(sd, strict=False)