Move beta_schedule to model_config and allow disabling unet creation.

This commit is contained in:
comfyanonymous 2023-08-29 14:22:53 -04:00
parent 968078b149
commit 15adc3699f
2 changed files with 4 additions and 2 deletions

View File

@ -19,8 +19,9 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config, device=device)
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
self.model_type = model_type
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:

View File

@ -33,6 +33,7 @@ class BASE:
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
beta_schedule = "linear"
@classmethod
def matches(s, unet_config):