Pass extra conds directly to unet.

This commit is contained in:
comfyanonymous 2023-10-25 00:07:53 -04:00
parent 036f88c621
commit d1d2fea806

View File

@ -50,7 +50,7 @@ class BaseModel(torch.nn.Module):
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
if c_concat is not None: if c_concat is not None:
xc = torch.cat([x] + [c_concat], dim=1) xc = torch.cat([x] + [c_concat], dim=1)
else: else:
@ -60,9 +60,10 @@ class BaseModel(torch.nn.Module):
xc = xc.to(dtype) xc = xc.to(dtype)
t = t.to(dtype) t = t.to(dtype)
context = context.to(dtype) context = context.to(dtype)
if c_adm is not None: extra_conds = {}
c_adm = c_adm.to(dtype) for o in kwargs:
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float() extra_conds[o] = kwargs[o].to(dtype)
return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
def get_dtype(self): def get_dtype(self):
return self.diffusion_model.dtype return self.diffusion_model.dtype
@ -107,7 +108,7 @@ class BaseModel(torch.nn.Module):
out['c_concat'] = comfy.conds.CONDNoiseShape(data) out['c_concat'] = comfy.conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs) adm = self.encode_adm(**kwargs)
if adm is not None: if adm is not None:
out['c_adm'] = comfy.conds.CONDRegular(adm) out['y'] = comfy.conds.CONDRegular(adm)
return out return out
def load_model_weights(self, sd, unet_prefix=""): def load_model_weights(self, sd, unet_prefix=""):