From 05c77dbf0912c72bc7526244961058cd5621a07f Mon Sep 17 00:00:00 2001 From: MayureshMore Date: Fri, 14 Feb 2025 18:52:34 -0800 Subject: [PATCH] [Edited] Fix minor bug in the main function --- comfy_extras/nodes_stable_cascade.py | 123 +++++++++++---------------- 1 file changed, 48 insertions(+), 75 deletions(-) diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 00340321..22fb0a8d 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -15,127 +15,100 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ - import torch import nodes import comfy.utils - class StableCascade_EmptyLatentImage: - def __init__(self, device="cpu"): + + def __init__(self, device='cpu'): + """Auto-generated docstring for function __init__.""" self.device = device @classmethod def INPUT_TYPES(s): - return {"required": { - "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), - "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}) - }} - RETURN_TYPES = ("LATENT", "LATENT") - RETURN_NAMES = ("stage_c", "stage_b") - FUNCTION = "generate" - - CATEGORY = "latent/stable_cascade" + """Auto-generated docstring for function INPUT_TYPES.""" + return {'required': {'width': ('INT', {'default': 1024, 'min': 256, 'max': nodes.MAX_RESOLUTION, 'step': 8}), 'height': ('INT', {'default': 1024, 'min': 256, 'max': nodes.MAX_RESOLUTION, 'step': 8}), 'compression': ('INT', {'default': 42, 'min': 4, 'max': 128, 'step': 1}), 'batch_size': ('INT', {'default': 1, 'min': 1, 'max': 4096})}} + RETURN_TYPES = ('LATENT', 'LATENT') + RETURN_NAMES = ('stage_c', 'stage_b') + FUNCTION = 'generate' + CATEGORY = 'latent/stable_cascade' def generate(self, width, height, compression, batch_size=1): + """Auto-generated docstring for function generate.""" c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) - return ({ - "samples": c_latent, - }, { - "samples": b_latent, - }) + return ({'samples': c_latent}, {'samples': b_latent}) class StableCascade_StageC_VAEEncode: - def __init__(self, device="cpu"): + + def __init__(self, device='cpu'): + """Auto-generated docstring for function __init__.""" self.device = device @classmethod def INPUT_TYPES(s): - return {"required": { - "image": ("IMAGE",), - "vae": ("VAE", ), - "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), - }} - RETURN_TYPES = ("LATENT", "LATENT") - RETURN_NAMES = ("stage_c", "stage_b") - FUNCTION = "generate" - - CATEGORY = "latent/stable_cascade" + """Auto-generated docstring for function INPUT_TYPES.""" + return {'required': {'image': ('IMAGE',), 'vae': ('VAE',), 'compression': ('INT', {'default': 42, 'min': 4, 'max': 128, 'step': 1})}} + RETURN_TYPES = ('LATENT', 'LATENT') + RETURN_NAMES = ('stage_c', 'stage_b') + FUNCTION = 'generate' + CATEGORY = 'latent/stable_cascade' def generate(self, image, vae, compression): + """Auto-generated docstring for function generate.""" width = image.shape[-2] height = image.shape[-3] - out_width = (width // compression) * vae.downscale_ratio - out_height = (height // compression) * vae.downscale_ratio - - s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) - - c_latent = vae.encode(s[:,:,:,:3]) - b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) - return ({ - "samples": c_latent, - }, { - "samples": b_latent, - }) + out_width = width // compression * vae.downscale_ratio + out_height = height // compression * vae.downscale_ratio + s = comfy.utils.common_upscale(image.movedim(-1, 1), out_width, out_height, 'bicubic', 'center').movedim(1, -1) + c_latent = vae.encode(s[:, :, :, :3]) + b_latent = torch.zeros([c_latent.shape[0], 4, height // 8 * 2, width // 8 * 2]) + return ({'samples': c_latent}, {'samples': b_latent}) class StableCascade_StageB_Conditioning: + @classmethod def INPUT_TYPES(s): - return {"required": { "conditioning": ("CONDITIONING",), - "stage_c": ("LATENT",), - }} - RETURN_TYPES = ("CONDITIONING",) - - FUNCTION = "set_prior" - - CATEGORY = "conditioning/stable_cascade" + """Auto-generated docstring for function INPUT_TYPES.""" + return {'required': {'conditioning': ('CONDITIONING',), 'stage_c': ('LATENT',)}} + RETURN_TYPES = ('CONDITIONING',) + FUNCTION = 'set_prior' + CATEGORY = 'conditioning/stable_cascade' def set_prior(self, conditioning, stage_c): + """Auto-generated docstring for function set_prior.""" c = [] for t in conditioning: d = t[1].copy() d['stable_cascade_prior'] = stage_c['samples'] n = [t[0], d] c.append(n) - return (c, ) + return (c,) class StableCascade_SuperResolutionControlnet: - def __init__(self, device="cpu"): + + def __init__(self, device='cpu'): + """Auto-generated docstring for function __init__.""" self.device = device @classmethod def INPUT_TYPES(s): - return {"required": { - "image": ("IMAGE",), - "vae": ("VAE", ), - }} - RETURN_TYPES = ("IMAGE", "LATENT", "LATENT") - RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") - FUNCTION = "generate" - + """Auto-generated docstring for function INPUT_TYPES.""" + return {'required': {'image': ('IMAGE',), 'vae': ('VAE',)}} + RETURN_TYPES = ('IMAGE', 'LATENT', 'LATENT') + RETURN_NAMES = ('controlnet_input', 'stage_c', 'stage_b') + FUNCTION = 'generate' EXPERIMENTAL = True - CATEGORY = "_for_testing/stable_cascade" + CATEGORY = '_for_testing/stable_cascade' def generate(self, image, vae): + """Auto-generated docstring for function generate.""" width = image.shape[-2] height = image.shape[-3] batch_size = image.shape[0] - controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1) - + controlnet_input = vae.encode(image[:, :, :, :3]).movedim(1, -1) c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) - return (controlnet_input, { - "samples": c_latent, - }, { - "samples": b_latent, - }) - -NODE_CLASS_MAPPINGS = { - "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, - "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, - "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, - "StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, -} + return (controlnet_input, {'samples': c_latent}, {'samples': b_latent}) +NODE_CLASS_MAPPINGS = {'StableCascade_EmptyLatentImage': StableCascade_EmptyLatentImage, 'StableCascade_StageB_Conditioning': StableCascade_StageB_Conditioning, 'StableCascade_StageC_VAEEncode': StableCascade_StageC_VAEEncode, 'StableCascade_SuperResolutionControlnet': StableCascade_SuperResolutionControlnet} \ No newline at end of file