support TAESD3 (#3738)

Dr.Lt.Data authored on 2024-06-16 15:03:53 +09:00, committed by GitHub
parent bb1969cab7
commit df7db0e027
5 changed files with 28 additions and 13 deletions

comfy/latent_formats.py

@@ -129,6 +129,7 @@ class SD3(LatentFormat):
             [-0.0749, -0.0634, -0.0456],
             [-0.1418, -0.1457, -0.1259]
         ]
+        self.taesd_decoder_name = "taesd3_decoder"
 
     def process_in(self, latent):
         return (latent - self.shift_factor) * self.scale_factor
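
The name set here is used as a filename prefix when the previewer searches models/vae_approx for decoder weights (see the warning string in latent_preview.py below). A minimal sketch of that kind of prefix lookup, with a hypothetical helper name (find_taesd_decoder is not part of this commit):

    import os

    def find_taesd_decoder(taesd_decoder_name, vae_approx_dir="models/vae_approx"):
        # Prefix match: "taesd3_decoder" matches e.g. "taesd3_decoder.safetensors"
        for fn in sorted(os.listdir(vae_approx_dir)):
            if fn.startswith(taesd_decoder_name):
                return os.path.join(vae_approx_dir, fn)
        return None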

comfy/sd.py

@@ -166,7 +166,7 @@ class CLIP:
         return self.patcher.get_key_patches()
 
 class VAE:
-    def __init__(self, sd=None, device=None, config=None, dtype=None):
+    def __init__(self, sd=None, device=None, config=None, dtype=None, latent_channels=4):
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)
@@ -174,7 +174,7 @@ class VAE:
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
         self.downscale_ratio = 8
         self.upscale_ratio = 8
-        self.latent_channels = 4
+        self.latent_channels = latent_channels
         self.output_channels = 3
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@@ -189,7 +189,7 @@ class VAE:
                 encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
                 decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
         elif "taesd_decoder.1.weight" in sd:
-            self.first_stage_model = comfy.taesd.taesd.TAESD()
+            self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
         elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
             self.first_stage_model = StageA()
             self.downscale_ratio = 4
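
The new latent_channels keyword defaults to 4, so existing callers are unchanged; only the TAESD3 path passes 16. A short usage sketch (sd1_vae_sd and taesd3_sd are hypothetical state-dict variables standing in for whatever the caller loaded):

    import comfy.sd

    vae_sd1 = comfy.sd.VAE(sd=sd1_vae_sd)                     # latent_channels == 4, as before
    vae_sd3 = comfy.sd.VAE(sd=taesd3_sd, latent_channels=16)  # builds a 16-channel TAESD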

comfy/taesd/taesd.py

@@ -25,18 +25,19 @@ class Block(nn.Module):
     def forward(self, x):
         return self.fuse(self.conv(x) + self.skip(x))
 
-def Encoder():
+def Encoder(latent_channels=4):
     return nn.Sequential(
         conv(3, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-        conv(64, 4),
+        conv(64, latent_channels),
     )
 
-def Decoder():
+
+def Decoder(latent_channels=4):
     return nn.Sequential(
-        Clamp(), conv(4, 64), nn.ReLU(),
+        Clamp(), conv(latent_channels, 64), nn.ReLU(),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
@@ -47,11 +48,11 @@ class TAESD(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
 
-    def __init__(self, encoder_path=None, decoder_path=None):
+    def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.taesd_encoder = Encoder()
-        self.taesd_decoder = Decoder()
+        self.taesd_encoder = Encoder(latent_channels=latent_channels)
+        self.taesd_decoder = Decoder(latent_channels=latent_channels)
         self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
         if encoder_path is not None:
             self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
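
Parameterizing Encoder and Decoder on latent_channels only changes the first and last convolutions; the 8x spatial scaling is untouched. A quick shape check, assuming comfy is importable (weights stay random because no checkpoint paths are given):

    import torch
    from comfy.taesd.taesd import TAESD

    model = TAESD(latent_channels=16)
    img = torch.randn(1, 3, 512, 512)
    lat = model.taesd_encoder(img)     # -> (1, 16, 64, 64): three stride-2 convs downscale 8x
    out = model.taesd_decoder(lat)     # -> (1, 3, 512, 512): three 2x upsamples restore it
    print(lat.shape, out.shape)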

latent_preview.py

@@ -64,7 +64,7 @@ def get_previewer(device, latent_format):
         if method == LatentPreviewMethod.TAESD:
             if taesd_decoder_path:
-                taesd = TAESD(None, taesd_decoder_path).to(device)
+                taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
                 previewer = TAESDPreviewerImpl(taesd)
             else:
                 logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
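
Before this change the previewer always built a 4-channel decoder, so a 16-channel SD3 latent failed at the decoder's first convolution. A small repro sketch of the old failure mode, assuming comfy is importable:

    import torch
    from comfy.taesd.taesd import Decoder

    old_decoder = Decoder()                  # default latent_channels=4, the old behavior
    sd3_latent = torch.randn(1, 16, 64, 64)
    try:
        old_decoder(sd3_latent)
    except RuntimeError as e:
        print("channel mismatch:", e)        # conv expects 4 input channels, got 16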

nodes.py

@@ -634,6 +634,8 @@ class VAELoader:
         sdxl_taesd_dec = False
         sd1_taesd_enc = False
         sd1_taesd_dec = False
+        sd3_taesd_enc = False
+        sd3_taesd_dec = False
 
         for v in approx_vaes:
             if v.startswith("taesd_decoder."):
@@ -644,10 +646,16 @@ class VAELoader:
                 sdxl_taesd_dec = True
             elif v.startswith("taesdxl_encoder."):
                 sdxl_taesd_enc = True
+            elif v.startswith("taesd3_decoder."):
+                sd3_taesd_dec = True
+            elif v.startswith("taesd3_encoder."):
+                sd3_taesd_enc = True
 
         if sd1_taesd_dec and sd1_taesd_enc:
             vaes.append("taesd")
         if sdxl_taesd_dec and sdxl_taesd_enc:
             vaes.append("taesdxl")
+        if sd3_taesd_dec and sd3_taesd_enc:
+            vaes.append("taesd3")
         return vaes
 
     @staticmethod
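
For "taesd3" to appear in the VAE list, both weight files must be present in models/vae_approx with these name prefixes, for example:

    models/vae_approx/taesd3_encoder.safetensors
    models/vae_approx/taesd3_decoder.safetensors
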
@@ -670,6 +678,8 @@ class VAELoader:
             sd["vae_scale"] = torch.tensor(0.18215)
         elif name == "taesdxl":
             sd["vae_scale"] = torch.tensor(0.13025)
+        elif name == "taesd3":
+            sd["vae_scale"] = torch.tensor(1.5305)
         return sd
 
     @classmethod
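
The 1.5305 here matches the scale_factor of the SD3 latent format, just as 0.18215 and 0.13025 match the SD1 and SDXL latent scales.
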
@@ -682,12 +692,15 @@ class VAELoader:
     #TODO: scale factor?
     def load_vae(self, vae_name):
-        if vae_name in ["taesd", "taesdxl"]:
+        if vae_name in ["taesd", "taesdxl", "taesd3"]:
             sd = self.load_taesd(vae_name)
         else:
             vae_path = folder_paths.get_full_path("vae", vae_name)
             sd = comfy.utils.load_torch_file(vae_path)
-        vae = comfy.sd.VAE(sd=sd)
+
+        latent_channels = 16 if vae_name == 'taesd3' else 4
+        vae = comfy.sd.VAE(sd=sd, latent_channels=latent_channels)
         return (vae,)
 
 class ControlNetLoader:
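
End to end, selecting the new option in the loader yields a 16-channel VAE. A minimal sketch, assuming the two taesd3 files above are installed and ComfyUI's folder paths are configured:

    loader = VAELoader()
    (vae,) = loader.load_vae("taesd3")
    print(vae.latent_channels)   # 16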