mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Pass extra conds directly to unet.
This commit is contained in:
parent
036f88c621
commit
d1d2fea806
@ -50,7 +50,7 @@ class BaseModel(torch.nn.Module):
|
||||
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}, **kwargs):
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([x] + [c_concat], dim=1)
|
||||
else:
|
||||
@ -60,9 +60,10 @@ class BaseModel(torch.nn.Module):
|
||||
xc = xc.to(dtype)
|
||||
t = t.to(dtype)
|
||||
context = context.to(dtype)
|
||||
if c_adm is not None:
|
||||
c_adm = c_adm.to(dtype)
|
||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra_conds[o] = kwargs[o].to(dtype)
|
||||
return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
@ -107,7 +108,7 @@ class BaseModel(torch.nn.Module):
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['c_adm'] = comfy.conds.CONDRegular(adm)
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
|
Loading…
Reference in New Issue
Block a user