From 8daedc5bf2ac106f1920c634866198c82e06997e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 27 Feb 2024 18:03:03 -0500 Subject: [PATCH] Auto detect playground v2.5 model. --- comfy/model_base.py | 6 +++++- comfy/supported_models.py | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 421f271b..170b1fd4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -15,9 +15,10 @@ class ModelType(Enum): V_PREDICTION = 2 V_PREDICTION_EDM = 3 STABLE_CASCADE = 4 + EDM = 5 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling +from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling def model_sampling(model_config, model_type): @@ -33,6 +34,9 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.STABLE_CASCADE: c = EPS s = StableCascadeSampling + elif model_type == ModelType.EDM: + c = EDM + s = ModelSamplingContinuousEDM class ModelSampling(s, c): pass diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 5d57a31a..74908216 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -163,7 +163,13 @@ class SDXL(supported_models_base.BASE): latent_format = latent_formats.SDXL def model_type(self, state_dict, prefix=""): - if "v_pred" in state_dict: + if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5 + self.latent_format = latent_formats.SDXL_Playground_2_5() + self.sampling_settings["sigma_data"] = 0.5 + self.sampling_settings["sigma_max"] = 80.0 + self.sampling_settings["sigma_min"] = 0.002 + return model_base.ModelType.EDM + elif "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION else: return model_base.ModelType.EPS