Don't auto-convert CLIP and VAE weights to fp16 when saving a checkpoint.

This commit is contained in:
comfyanonymous 2024-06-12 01:07:58 -04:00
parent 32be358213
commit 1ddf512fdc

View File

@ -207,9 +207,6 @@ class BaseModel(torch.nn.Module):
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.get_dtype() == torch.float16:
extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)
if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])