From d1d2fea806c07b1519634ec6dbc8c7f60dee8f4e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 25 Oct 2023 00:07:53 -0400 Subject: [PATCH] Pass extra conds directly to unet. --- comfy/model_base.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index edc246f8..ea3ea61f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -50,7 +50,7 @@ class BaseModel(torch.nn.Module): self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) - def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}, **kwargs): + def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): if c_concat is not None: xc = torch.cat([x] + [c_concat], dim=1) else: @@ -60,9 +60,10 @@ class BaseModel(torch.nn.Module): xc = xc.to(dtype) t = t.to(dtype) context = context.to(dtype) - if c_adm is not None: - c_adm = c_adm.to(dtype) - return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float() + extra_conds = {} + for o in kwargs: + extra_conds[o] = kwargs[o].to(dtype) + return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() def get_dtype(self): return self.diffusion_model.dtype @@ -107,7 +108,7 @@ class BaseModel(torch.nn.Module): out['c_concat'] = comfy.conds.CONDNoiseShape(data) adm = self.encode_adm(**kwargs) if adm is not None: - out['c_adm'] = comfy.conds.CONDRegular(adm) + out['y'] = comfy.conds.CONDRegular(adm) return out def load_model_weights(self, sd, unet_prefix=""):