Refactor previewer model loading into the LatentFormat class

also cleans up some unused imports in latent_preview.py
This commit is contained in:
asagi4 2024-09-24 18:52:40 +03:00
parent 96be1c553c
commit 433a02b2e8
3 changed files with 40 additions and 39 deletions

View File

@ -232,7 +232,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd) or the Stable Cascade previewer, download [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and/or [previewer.safetensors](https://huggingface.co/stabilityai/stable-cascade/resolve/main/previewer.safetensors) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd) or the Stable Cascade previewer, download [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and/or [previewer.safetensors](https://huggingface.co/stabilityai/stable-cascade/resolve/main/previewer.safetensors) and place them in the `models/vae_approx` folder (save `previewer.safetensors` as `cascade_previewer.safetensors`). Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
## How to use TLS/SSL?
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`

View File

@ -1,10 +1,32 @@
import torch
import folder_paths
import logging
from comfy.taesd.taesd import TAESD
from comfy.ldm.cascade.stage_c_coder import Previewer
import comfy.utils
class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_rgb_factors = None
taesd_decoder_name = None
# Default if decoder name is defined
previewer_class = TAESD
def load_previewer(self, device):
model = None
if not self.taesd_decoder_name:
return None
filename = next((fn for fn in folder_paths.get_filename_list("vae_approx") if fn.startswith(self.taesd_decoder_name)), "")
model_path = folder_paths.get_full_path("vae_approx", filename)
if model_path:
model = self.previewer_class(decoder_path=model_path, latent_channels=self.latent_channels).to(device)
if not model:
logging.warning("Warning: Could not load previewer model: models/vae_approx/%s", self.taesd_decoder_name)
return model
def process_in(self, latent):
return latent * self.scale_factor
@ -73,8 +95,19 @@ class SD_X4(LatentFormat):
[ 0.2523, -0.0055, -0.1651]
]
class CascadePreviewWrapper(Previewer):
def __init__(self, decoder_path=None, **kwargs):
super().__init__()
self.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
self.eval()
def decode(self, latent):
return self(latent)
class SC_Prior(LatentFormat):
latent_channels = 16
taesd_decoder_name = "cascade_previewer"
previewer_class = CascadePreviewWrapper
def __init__(self):
self.scale_factor = 1.0
self.latent_rgb_factors = [
@ -95,7 +128,6 @@ class SC_Prior(LatentFormat):
[ 0.0542, 0.1545, 0.1325],
[-0.0352, -0.1672, -0.2541]
]
taesd_decoder_name = "previewer.safetensors"
class SC_B(LatentFormat):
def __init__(self):

View File

@ -1,14 +1,8 @@
import torch
from PIL import Image
import struct
import numpy as np
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.model_management
import folder_paths
import comfy.utils
import logging
from comfy.ldm.cascade.stage_c_coder import Previewer
MAX_PREVIEW_RESOLUTION = args.preview_size
@ -36,17 +30,6 @@ class TAESDPreviewerImpl(LatentPreviewer):
return preview_to_image(x_sample)
class StageCPreviewer(Previewer):
def __init__(self, path):
super().__init__()
sd = comfy.utils.load_torch_file(path, safe_load=True)
self.load_state_dict(sd, strict=True)
self.eval()
def decode(self, latent):
return self(latent)
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
@ -57,32 +40,18 @@ class Latent2RGBPreviewer(LatentPreviewer):
return preview_to_image(latent_image)
def get_previewer(device, latent_format):
def get_previewer(device, latent_format, method=None):
previewer = None
method = args.preview_method
if method is None:
method = args.preview_method
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = None
if latent_format.taesd_decoder_name is not None:
taesd_decoder_path = next(
(fn for fn in folder_paths.get_filename_list("vae_approx")
if fn.startswith(latent_format.taesd_decoder_name)),
""
)
taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
if 'previewer' in taesd_decoder_path:
taesd = StageCPreviewer(taesd_decoder_path).to(device)
else:
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
model = latent_format.load_previewer(device)
if model:
previewer = TAESDPreviewerImpl(model)
if previewer is None:
if latent_format.latent_rgb_factors is not None: