From 3a100b9a550b9700d08eecb006b5accd65863925 Mon Sep 17 00:00:00 2001
From: comfyanonymous <comfyanonymous@protonmail.com>
Date: Fri, 4 Apr 2025 21:24:56 -0400
Subject: [PATCH] Disable partial offloading of audio VAE.

---
 comfy/sd.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index d096f496..4d3aef3e 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -265,6 +265,7 @@ class VAE:
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
+        self.disable_offload = False
 
         self.downscale_index_formula = None
         self.upscale_index_formula = None
@@ -337,6 +338,7 @@ class VAE:
                 self.process_output = lambda audio: audio
                 self.process_input = lambda audio: audio
                 self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+                self.disable_offload = True
             elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
                 if "blocks.2.blocks.3.stack.5.weight" in sd:
                     sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
@@ -515,7 +517,7 @@ class VAE:
         pixel_samples = None
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
@@ -544,7 +546,7 @@ class VAE:
     def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         self.throw_exception_if_invalid()
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
-        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
         dims = samples.ndim - 2
         args = {}
         if tile_x is not None:
@@ -578,7 +580,7 @@ class VAE:
             pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / max(1, memory_used))
             batch_number = max(1, batch_number)
@@ -612,7 +614,7 @@ class VAE:
             pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
 
         memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)  # TODO: calculate mem required for tile
-        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+        model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
 
         args = {}
         if tile_x is not None: