diff --git a/comfy/model_base.py b/comfy/model_base.py
index f45b375de..80f6667ec 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -627,3 +627,12 @@ class StableAudio1(BaseModel):
         cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out
+
+    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+        sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
+        d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
+        for k in d:
+            s = d[k]
+            for l in s:
+                sd["{}{}".format(k, l)] = s[l]
+        return sd
diff --git a/comfy/sd.py b/comfy/sd.py
index ea6e9b663..dda0887ba 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -236,7 +236,7 @@ class VAE:
             self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                         encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
                                                         decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
-        elif "decoder.layers.0.weight_v" in sd:
+        elif "decoder.layers.1.layers.0.beta" in sd:
             self.first_stage_model = AudioOobleckVAE()
             self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
             self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 761498dbc..21fdb7ec7 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -543,13 +543,16 @@ class StableAudio(supported_models_base.BASE):
         seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
         return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)
 
-
     def process_unet_state_dict(self, state_dict):
         for k in list(state_dict.keys()):
             if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
                 state_dict.pop(k)
         return state_dict
 
+    def process_unet_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "model.model."}
+        return utils.state_dict_prefix_replace(state_dict, replace_prefix)
+
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
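
For context on the changes above: `state_dict_for_saving` re-nests the two seconds-embedder state dicts under their original `conditioner.conditioners.*` prefixes, so a checkpoint saved by ComfyUI keeps the same key layout as an upstream Stable Audio checkpoint and `get_model` can strip those prefixes back off on reload. Below is a minimal sketch of that round trip, using toy tensors and hypothetical `prefix_keys`/`strip_prefix` helpers (stand-ins for the key loop in `state_dict_for_saving` and for `utils.state_dict_prefix_replace`, not ComfyUI's actual API):

```python
import torch

def prefix_keys(sd, prefix):
    # Save path: prepend the namespace to every key,
    # like the sd["{}{}".format(k, l)] loop in state_dict_for_saving.
    return {"{}{}".format(prefix, k): v for k, v in sd.items()}

def strip_prefix(sd, prefix):
    # Load path (what state_dict_prefix_replace with filter_keys=True does):
    # keep only keys under the namespace and drop the prefix.
    return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}

# Toy stand-in for seconds_start_embedder.state_dict().
embedder_sd = {"embedder.weight": torch.zeros(1, 768)}

# Saving nests the embedder weights under the original checkpoint namespace.
saved = prefix_keys(embedder_sd, "conditioner.conditioners.seconds_start.")
assert "conditioner.conditioners.seconds_start.embedder.weight" in saved

# Loading strips the same prefix back off, recovering the original keys.
loaded = strip_prefix(saved, "conditioner.conditioners.seconds_start.")
assert loaded.keys() == embedder_sd.keys()
```

The new `process_unet_state_dict_for_saving` applies the same idea with the empty-string prefix mapping `{"": "model.model."}`, which prepends `model.model.` to every diffusion-model key on save.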