mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Pass extra conds directly to unet.
This commit is contained in:
parent
036f88c621
commit
d1d2fea806
@ -50,7 +50,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||||
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}, **kwargs):
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
if c_concat is not None:
|
if c_concat is not None:
|
||||||
xc = torch.cat([x] + [c_concat], dim=1)
|
xc = torch.cat([x] + [c_concat], dim=1)
|
||||||
else:
|
else:
|
||||||
@ -60,9 +60,10 @@ class BaseModel(torch.nn.Module):
|
|||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
t = t.to(dtype)
|
t = t.to(dtype)
|
||||||
context = context.to(dtype)
|
context = context.to(dtype)
|
||||||
if c_adm is not None:
|
extra_conds = {}
|
||||||
c_adm = c_adm.to(dtype)
|
for o in kwargs:
|
||||||
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
|
extra_conds[o] = kwargs[o].to(dtype)
|
||||||
|
return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||||
|
|
||||||
def get_dtype(self):
|
def get_dtype(self):
|
||||||
return self.diffusion_model.dtype
|
return self.diffusion_model.dtype
|
||||||
@ -107,7 +108,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
||||||
adm = self.encode_adm(**kwargs)
|
adm = self.encode_adm(**kwargs)
|
||||||
if adm is not None:
|
if adm is not None:
|
||||||
out['c_adm'] = comfy.conds.CONDRegular(adm)
|
out['y'] = comfy.conds.CONDRegular(adm)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_model_weights(self, sd, unet_prefix=""):
|
def load_model_weights(self, sd, unet_prefix=""):
|
||||||
|
Loading…
Reference in New Issue
Block a user