From 492db2de8db7e082addf131b40adb4a1b7535821 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 21 Sep 2023 01:14:42 -0400 Subject: [PATCH] Allow having a different pooled output for each image in a batch. --- comfy/model_base.py | 4 ++-- comfy/samplers.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index ca154dba..ed2dc83e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -181,7 +181,7 @@ class SDXLRefiner(BaseModel): out.append(self.embedder(torch.Tensor([crop_h]))) out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([aesthetic_score]))) - flat = torch.flatten(torch.cat(out))[None, ] + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) class SDXL(BaseModel): @@ -206,5 +206,5 @@ class SDXL(BaseModel): out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([target_height]))) out.append(self.embedder(torch.Tensor([target_width]))) - flat = torch.flatten(torch.cat(out))[None, ] + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) diff --git a/comfy/samplers.py b/comfy/samplers.py index 57673a02..e3192ca5 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -7,6 +7,7 @@ from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps import math from comfy import model_base +import comfy.utils def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) return abs(a*b) // math.gcd(a, b) @@ -538,7 +539,7 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type): if adm_out is not None: x[1] = x[1].copy() - x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device) + x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device) return conds