[Edited] Fix minor bug in the main function

This commit is contained in:
MayureshMore 2025-02-14 18:52:34 -08:00
parent 6ad56a5961
commit 05c77dbf09

View File

@ -15,127 +15,100 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import nodes
import comfy.utils
class StableCascade_EmptyLatentImage:
def __init__(self, device="cpu"):
def __init__(self, device='cpu'):
"""Auto-generated docstring for function __init__."""
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
"""Auto-generated docstring for function INPUT_TYPES."""
return {'required': {'width': ('INT', {'default': 1024, 'min': 256, 'max': nodes.MAX_RESOLUTION, 'step': 8}), 'height': ('INT', {'default': 1024, 'min': 256, 'max': nodes.MAX_RESOLUTION, 'step': 8}), 'compression': ('INT', {'default': 42, 'min': 4, 'max': 128, 'step': 1}), 'batch_size': ('INT', {'default': 1, 'min': 1, 'max': 4096})}}
RETURN_TYPES = ('LATENT', 'LATENT')
RETURN_NAMES = ('stage_c', 'stage_b')
FUNCTION = 'generate'
CATEGORY = 'latent/stable_cascade'
def generate(self, width, height, compression, batch_size=1):
"""Auto-generated docstring for function generate."""
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
return ({
"samples": c_latent,
}, {
"samples": b_latent,
})
return ({'samples': c_latent}, {'samples': b_latent})
class StableCascade_StageC_VAEEncode:
def __init__(self, device="cpu"):
def __init__(self, device='cpu'):
"""Auto-generated docstring for function __init__."""
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
"""Auto-generated docstring for function INPUT_TYPES."""
return {'required': {'image': ('IMAGE',), 'vae': ('VAE',), 'compression': ('INT', {'default': 42, 'min': 4, 'max': 128, 'step': 1})}}
RETURN_TYPES = ('LATENT', 'LATENT')
RETURN_NAMES = ('stage_c', 'stage_b')
FUNCTION = 'generate'
CATEGORY = 'latent/stable_cascade'
def generate(self, image, vae, compression):
"""Auto-generated docstring for function generate."""
width = image.shape[-2]
height = image.shape[-3]
out_width = (width // compression) * vae.downscale_ratio
out_height = (height // compression) * vae.downscale_ratio
s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1)
c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({
"samples": c_latent,
}, {
"samples": b_latent,
})
out_width = width // compression * vae.downscale_ratio
out_height = height // compression * vae.downscale_ratio
s = comfy.utils.common_upscale(image.movedim(-1, 1), out_width, out_height, 'bicubic', 'center').movedim(1, -1)
c_latent = vae.encode(s[:, :, :, :3])
b_latent = torch.zeros([c_latent.shape[0], 4, height // 8 * 2, width // 8 * 2])
return ({'samples': c_latent}, {'samples': b_latent})
class StableCascade_StageB_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "conditioning": ("CONDITIONING",),
"stage_c": ("LATENT",),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "set_prior"
CATEGORY = "conditioning/stable_cascade"
"""Auto-generated docstring for function INPUT_TYPES."""
return {'required': {'conditioning': ('CONDITIONING',), 'stage_c': ('LATENT',)}}
RETURN_TYPES = ('CONDITIONING',)
FUNCTION = 'set_prior'
CATEGORY = 'conditioning/stable_cascade'
def set_prior(self, conditioning, stage_c):
"""Auto-generated docstring for function set_prior."""
c = []
for t in conditioning:
d = t[1].copy()
d['stable_cascade_prior'] = stage_c['samples']
n = [t[0], d]
c.append(n)
return (c, )
return (c,)
class StableCascade_SuperResolutionControlnet:
def __init__(self, device="cpu"):
def __init__(self, device='cpu'):
"""Auto-generated docstring for function __init__."""
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
}}
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
"""Auto-generated docstring for function INPUT_TYPES."""
return {'required': {'image': ('IMAGE',), 'vae': ('VAE',)}}
RETURN_TYPES = ('IMAGE', 'LATENT', 'LATENT')
RETURN_NAMES = ('controlnet_input', 'stage_c', 'stage_b')
FUNCTION = 'generate'
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade"
CATEGORY = '_for_testing/stable_cascade'
def generate(self, image, vae):
"""Auto-generated docstring for function generate."""
width = image.shape[-2]
height = image.shape[-3]
batch_size = image.shape[0]
controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1)
controlnet_input = vae.encode(image[:, :, :, :3]).movedim(1, -1)
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, {
"samples": c_latent,
}, {
"samples": b_latent,
})
NODE_CLASS_MAPPINGS = {
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
}
return (controlnet_input, {'samples': c_latent}, {'samples': b_latent})
NODE_CLASS_MAPPINGS = {'StableCascade_EmptyLatentImage': StableCascade_EmptyLatentImage, 'StableCascade_StageB_Conditioning': StableCascade_StageB_Conditioning, 'StableCascade_StageC_VAEEncode': StableCascade_StageC_VAEEncode, 'StableCascade_SuperResolutionControlnet': StableCascade_SuperResolutionControlnet}