mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
It doesn't make sense for c_crossattn and c_concat to be lists.
This commit is contained in:
parent
5f101f4da1
commit
1c012d69af
@ -50,10 +50,10 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
|
||||||
if c_concat is not None:
|
if c_concat is not None:
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + [c_concat], dim=1)
|
||||||
else:
|
else:
|
||||||
xc = x
|
xc = x
|
||||||
context = torch.cat(c_crossattn, 1)
|
context = c_crossattn
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype()
|
||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
t = t.to(dtype)
|
t = t.to(dtype)
|
||||||
|
@ -165,9 +165,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
c_crossattn_out.append(c)
|
c_crossattn_out.append(c)
|
||||||
|
|
||||||
if len(c_crossattn_out) > 0:
|
if len(c_crossattn_out) > 0:
|
||||||
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
|
out['c_crossattn'] = torch.cat(c_crossattn_out)
|
||||||
if len(c_concat) > 0:
|
if len(c_concat) > 0:
|
||||||
out['c_concat'] = [torch.cat(c_concat)]
|
out['c_concat'] = torch.cat(c_concat)
|
||||||
if len(c_adm) > 0:
|
if len(c_adm) > 0:
|
||||||
out['c_adm'] = torch.cat(c_adm)
|
out['c_adm'] = torch.cat(c_adm)
|
||||||
return out
|
return out
|
||||||
|
Loading…
Reference in New Issue
Block a user