This commit is contained in:
comfyanonymous 2023-10-10 21:46:53 -04:00
parent 851bb87ca9
commit 5e885bd9c8

View File

@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion
import torch
import torch.nn as nn
import comfy.utils
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
@ -50,17 +52,9 @@ class TAESD(nn.Module):
self.encoder = Encoder()
self.decoder = Decoder()
if encoder_path is not None:
if encoder_path.lower().endswith(".safetensors"):
import safetensors.torch
self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu"))
else:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
if decoder_path.lower().endswith(".safetensors"):
import safetensors.torch
self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu"))
else:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod
def scale_latents(x):