Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-04-16 08:33:29 +00:00.
Merge branch 'comfyanonymous:master' into chroma-support
Commit 1ca0353b96
@@ -847,6 +847,7 @@ class SpatialTransformer(nn.Module):
         if not isinstance(context, list):
             context = [context] * len(self.transformer_blocks)
         b, c, h, w = x.shape
+        transformer_options["activations_shape"] = list(x.shape)
         x_in = x
         x = self.norm(x)
         if not self.use_linear:
@@ -962,6 +963,7 @@ class SpatialVideoTransformer(SpatialTransformer):
         transformer_options={}
     ) -> torch.Tensor:
         _, _, h, w = x.shape
+        transformer_options["activations_shape"] = list(x.shape)
         x_in = x
         spatial_context = None
         if exists(context):
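Both hunks record the activation tensor's shape into the shared transformer_options dict as soon as it is known, so code running later in the same forward pass (attention patches, wrappers) can recover the spatial layout without re-deriving it. A minimal consumer sketch, assuming only that a patch sees the token tensor in the usual (b, h*w, c) layout and the same transformer_options dict these hunks mutate; the function name is illustrative, not part of the diff:

import torch

def tokens_to_spatial(tokens: torch.Tensor, transformer_options: dict) -> torch.Tensor:
    # activations_shape is [b, c, h, w], written by SpatialTransformer.forward
    # (SpatialVideoTransformer.forward also sees a 4D x, with frames folded
    # into the batch dimension).
    b, c, h, w = transformer_options["activations_shape"]
    # (b, h*w, c) -> (b, c, h*w) -> (b, c, h, w)
    return tokens.movedim(1, 2).reshape(b, c, h, w)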
comfy/sd.py (10 changed lines)
@@ -266,6 +266,7 @@ class VAE:
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
+        self.disable_offload = False
 
         self.downscale_index_formula = None
         self.upscale_index_formula = None
@@ -338,6 +339,7 @@ class VAE:
             self.process_output = lambda audio: audio
             self.process_input = lambda audio: audio
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+            self.disable_offload = True
         elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
             if "blocks.2.blocks.3.stack.5.weight" in sd:
                 sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
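For orientation: VAE.__init__ dispatches on tell-tale state-dict keys (as in the genmo mochi check above), and the audio branch is the one that flips disable_offload to True. A toy version of that key-sniffing pattern, just to make the shape of the logic explicit; the audio marker key below is invented, only the mochi key and the dispatch pattern come from the diff:

def classify_vae(sd: dict) -> tuple[str, bool]:
    # Returns (kind, disable_offload).
    if "decoder.audio_head.weight" in sd:                 # hypothetical audio marker
        return "audio", True                              # audio VAE: never offload
    if "decoder.blocks.2.blocks.3.stack.5.weight" in sd:  # real key, from the diff
        return "mochi", False                             # genmo mochi vae
    return "image", False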
@@ -516,7 +518,7 @@ class VAE:
         pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
@@ -545,7 +547,7 @@ class VAE:
     def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         self.throw_exception_if_invalid()
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
-        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
         dims = samples.ndim - 2
         args = {}
         if tile_x is not None:
@@ -579,7 +581,7 @@ class VAE:
             pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / max(1, memory_used))
             batch_number = max(1, batch_number)
@@ -613,7 +615,7 @@ class VAE:
             pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
 
         memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
-        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
 
         args = {}
         if tile_x is not None:
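All four call sites make the same one-line substitution: self.disable_offload is forwarded to model_management.load_models_gpu() as force_full_load, so a VAE that opts out of offloading (the audio VAE above) is loaded onto the device in full instead of having only memory_required bytes kept resident. A hedged sketch of how custom code could opt into the same behavior; the subclass is hypothetical, only the disable_offload attribute comes from this diff:

import comfy.sd

class PinnedVAE(comfy.sd.VAE):
    # Illustrative subclass: after construction, every decode()/encode()/
    # decode_tiled()/encode_tiled() call passes force_full_load=True.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.disable_offload = True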