From 1c012d69afa8bd92a007a3e468e2a1f874365d39 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 31 Aug 2023 13:25:00 -0400 Subject: [PATCH] It doesn't make sense for c_crossattn and c_concat to be lists. --- comfy/model_base.py | 4 ++-- comfy/samplers.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index acd4169a..677a23de 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -50,10 +50,10 @@ class BaseModel(torch.nn.Module): def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}): if c_concat is not None: - xc = torch.cat([x] + c_concat, dim=1) + xc = torch.cat([x] + [c_concat], dim=1) else: xc = x - context = torch.cat(c_crossattn, 1) + context = c_crossattn dtype = self.get_dtype() xc = xc.to(dtype) t = t.to(dtype) diff --git a/comfy/samplers.py b/comfy/samplers.py index 134336de..103ac33f 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -165,9 +165,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c_crossattn_out.append(c) if len(c_crossattn_out) > 0: - out['c_crossattn'] = [torch.cat(c_crossattn_out)] + out['c_crossattn'] = torch.cat(c_crossattn_out) if len(c_concat) > 0: - out['c_concat'] = [torch.cat(c_concat)] + out['c_concat'] = torch.cat(c_concat) if len(c_adm) > 0: out['c_adm'] = torch.cat(c_adm) return out