Auto detect playground v2.5 model.

This commit is contained in:
comfyanonymous 2024-02-27 18:03:03 -05:00
parent d46583ecec
commit 8daedc5bf2
2 changed files with 12 additions and 2 deletions

View File

@ -15,9 +15,10 @@ class ModelType(Enum):
V_PREDICTION = 2
V_PREDICTION_EDM = 3
STABLE_CASCADE = 4
EDM = 5
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
def model_sampling(model_config, model_type):
@ -33,6 +34,9 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.STABLE_CASCADE:
c = EPS
s = StableCascadeSampling
elif model_type == ModelType.EDM:
c = EDM
s = ModelSamplingContinuousEDM
class ModelSampling(s, c):
pass

View File

@ -163,7 +163,13 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL
def model_type(self, state_dict, prefix=""):
if "v_pred" in state_dict:
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
self.latent_format = latent_formats.SDXL_Playground_2_5()
self.sampling_settings["sigma_data"] = 0.5
self.sampling_settings["sigma_max"] = 80.0
self.sampling_settings["sigma_min"] = 0.002
return model_base.ModelType.EDM
elif "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION
else:
return model_base.ModelType.EPS