Add ForcedTiledVAE subclass of VAE

This commit is contained in:
Oscar Anghammar 2025-05-03 17:31:36 +02:00
parent aee2908d03
commit 4321020417

View File

@ -717,6 +717,26 @@ class VAE:
except:
return None
class ForcedTiledVAE(VAE):
def __init__(self, tile_size, overlap, temporal_size=64, temporal_overlap=8, sd=None, device=None, config=None, dtype=None, metadata=None):
super().__init__(sd, device, config, dtype, metadata)
self.tile_size = tile_size
self.overlap = overlap
self.temporal_size = temporal_size
self.temporal_overlap = temporal_overlap
@classmethod
def cast(cls, vae: VAE, tile_size, overlap, temporal_size=64, temporal_overlap=8):
vae.__class__ = cls
vae.tile_size = tile_size
vae.overlap = overlap
vae.temporal_size = temporal_size
vae.temporal_overlap = temporal_overlap
return vae
def encode(self, pixel_samples):
return super().encode_tiled(pixel_samples, self.tile_size, self.tile_size, self.overlap, self.temporal_size, self.temporal_overlap)
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model