mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Auto detect playground v2.5 model.
This commit is contained in:
parent
d46583ecec
commit
8daedc5bf2
@ -15,9 +15,10 @@ class ModelType(Enum):
|
||||
V_PREDICTION = 2
|
||||
V_PREDICTION_EDM = 3
|
||||
STABLE_CASCADE = 4
|
||||
EDM = 5
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -33,6 +34,9 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
elif model_type == ModelType.EDM:
|
||||
c = EDM
|
||||
s = ModelSamplingContinuousEDM
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
|
@ -163,7 +163,13 @@ class SDXL(supported_models_base.BASE):
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if "v_pred" in state_dict:
|
||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||
self.latent_format = latent_formats.SDXL_Playground_2_5()
|
||||
self.sampling_settings["sigma_data"] = 0.5
|
||||
self.sampling_settings["sigma_max"] = 80.0
|
||||
self.sampling_settings["sigma_min"] = 0.002
|
||||
return model_base.ModelType.EDM
|
||||
elif "v_pred" in state_dict:
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
else:
|
||||
return model_base.ModelType.EPS
|
||||
|
Loading…
Reference in New Issue
Block a user