Merge 11a23056aa9032d630ab74dd168ae25c3a140dc8 into d8e5662822168101afb5e08a8ba75b6eefff6e02

This commit is contained in:
Anghammar 2025-05-18 11:16:42 +02:00 committed by GitHub
commit 04cc4ff452
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 0 deletions

View File

@ -717,6 +717,26 @@ class VAE:
except:
return None
class ForcedTiledVAE(VAE):
def __init__(self, tile_size, overlap, temporal_size=64, temporal_overlap=8, sd=None, device=None, config=None, dtype=None, metadata=None):
super().__init__(sd, device, config, dtype, metadata)
self.tile_size = tile_size
self.overlap = overlap
self.temporal_size = temporal_size
self.temporal_overlap = temporal_overlap
@classmethod
def cast(cls, vae: VAE, tile_size, overlap, temporal_size=64, temporal_overlap=8):
vae.__class__ = cls
vae.tile_size = tile_size
vae.overlap = overlap
vae.temporal_size = temporal_size
vae.temporal_overlap = temporal_overlap
return vae
def encode(self, pixel_samples):
return super().encode_tiled(pixel_samples, self.tile_size, self.tile_size, self.overlap, self.temporal_size, self.temporal_overlap)
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model

View File

@ -356,6 +356,24 @@ class VAEEncodeTiled:
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
return ({"samples": t}, )
class ForceTiledVAEEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"tile_size": ("INT", {"default": 256, "min": 64, "max": 4096, "step": 32}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
}}
RETURN_TYPES = ("VAE",)
FUNCTION = "cast"
CATEGORY = "_for_testing"
def cast(self, vae, tile_size, overlap, temporal_size=64, temporal_overlap=8):
tiled_vae = comfy.sd.ForcedTiledVAE.cast(vae, tile_size, overlap, temporal_size, temporal_overlap)
return (tiled_vae, )
class VAEEncodeForInpaint:
@classmethod
def INPUT_TYPES(s):
@ -1997,6 +2015,7 @@ NODE_CLASS_MAPPINGS = {
"CLIPVisionLoader": CLIPVisionLoader,
"VAEDecodeTiled": VAEDecodeTiled,
"VAEEncodeTiled": VAEEncodeTiled,
"ForceTiledVAEEncode": ForceTiledVAEEncode,
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply,
@ -2078,6 +2097,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
# _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
"ForceTiledVAEEncode": "Force tiled VAE Encode",
}
EXTENSION_WEB_DIRS = {}