It doesn't make sense for c_crossattn and c_concat to be lists.

This commit is contained in:
comfyanonymous 2023-08-31 13:25:00 -04:00
parent 5f101f4da1
commit 1c012d69af
2 changed files with 4 additions and 4 deletions

View File

@ -50,10 +50,10 @@ class BaseModel(torch.nn.Module):
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}): def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
if c_concat is not None: if c_concat is not None:
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + [c_concat], dim=1)
else: else:
xc = x xc = x
context = torch.cat(c_crossattn, 1) context = c_crossattn
dtype = self.get_dtype() dtype = self.get_dtype()
xc = xc.to(dtype) xc = xc.to(dtype)
t = t.to(dtype) t = t.to(dtype)

View File

@ -165,9 +165,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
c_crossattn_out.append(c) c_crossattn_out.append(c)
if len(c_crossattn_out) > 0: if len(c_crossattn_out) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn_out)] out['c_crossattn'] = torch.cat(c_crossattn_out)
if len(c_concat) > 0: if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)] out['c_concat'] = torch.cat(c_concat)
if len(c_adm) > 0: if len(c_adm) > 0:
out['c_adm'] = torch.cat(c_adm) out['c_adm'] = torch.cat(c_adm)
return out return out