diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 92f74c11..8df1f160 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion import torch import torch.nn as nn +import comfy.utils + def conv(n_in, n_out, **kwargs): return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) @@ -50,17 +52,9 @@ class TAESD(nn.Module): self.encoder = Encoder() self.decoder = Decoder() if encoder_path is not None: - if encoder_path.lower().endswith(".safetensors"): - import safetensors.torch - self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu")) - else: - self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) + self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) if decoder_path is not None: - if decoder_path.lower().endswith(".safetensors"): - import safetensors.torch - self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu")) - else: - self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) + self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) @staticmethod def scale_latents(x):