diff --git a/comfy/sd.py b/comfy/sd.py index e98a3aa8..d4aa68b7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -717,6 +717,26 @@ class VAE: except: return None +class ForcedTiledVAE(VAE): + def __init__(self, tile_size, overlap, temporal_size=64, temporal_overlap=8, sd=None, device=None, config=None, dtype=None, metadata=None): + super().__init__(sd, device, config, dtype, metadata) + self.tile_size = tile_size + self.overlap = overlap + self.temporal_size = temporal_size + self.temporal_overlap = temporal_overlap + + @classmethod + def cast(cls, vae: VAE, tile_size, overlap, temporal_size=64, temporal_overlap=8): + vae.__class__ = cls + vae.tile_size = tile_size + vae.overlap = overlap + vae.temporal_size = temporal_size + vae.temporal_overlap = temporal_overlap + return vae + + def encode(self, pixel_samples): + return super().encode_tiled(pixel_samples, self.tile_size, self.tile_size, self.overlap, self.temporal_size, self.temporal_overlap) + class StyleModel: def __init__(self, model, device="cpu"): self.model = model diff --git a/nodes.py b/nodes.py index 95e831b8..81095929 100644 --- a/nodes.py +++ b/nodes.py @@ -356,6 +356,24 @@ class VAEEncodeTiled: t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) return ({"samples": t}, ) +class ForceTiledVAEEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": {"vae": ("VAE", ), + "tile_size": ("INT", {"default": 256, "min": 64, "max": 4096, "step": 32}), + "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), + }} + RETURN_TYPES = ("VAE",) + FUNCTION = "cast" + + CATEGORY = "_for_testing" + + def cast(self, vae, tile_size, overlap, temporal_size=64, temporal_overlap=8): + tiled_vae = comfy.sd.ForcedTiledVAE.cast(vae, tile_size, overlap, temporal_size, temporal_overlap) + return (tiled_vae, ) + class VAEEncodeForInpaint: @classmethod def INPUT_TYPES(s): @@ -1997,6 +2015,7 @@ NODE_CLASS_MAPPINGS = { "CLIPVisionLoader": CLIPVisionLoader, "VAEDecodeTiled": VAEDecodeTiled, "VAEEncodeTiled": VAEEncodeTiled, + "ForceTiledVAEEncode": ForceTiledVAEEncode, "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "GLIGENLoader": GLIGENLoader, "GLIGENTextBoxApply": GLIGENTextBoxApply, @@ -2078,6 +2097,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { # _for_testing "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", + "ForceTiledVAEEncode": "Force tiled VAE Encode", } EXTENSION_WEB_DIRS = {}