From 77a176f9e0f4777363a414fbb006cb133d31e034 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 2 Sep 2023 03:42:49 -0400
Subject: [PATCH] Use common function to reshape batch to.

---
 comfy/sample.py | 8 +++-----
 comfy/utils.py  | 7 +++++++
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/comfy/sample.py b/comfy/sample.py
index 79ea37e0d..e4730b189 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -1,6 +1,7 @@
 import torch
 import comfy.model_management
 import comfy.samplers
+import comfy.utils
 import math
 import numpy as np
 
@@ -28,8 +29,7 @@ def prepare_mask(noise_mask, shape, device):
     noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
     noise_mask = noise_mask.round()
     noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
-    if noise_mask.shape[0] < shape[0]:
-        noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
+    noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
     noise_mask = noise_mask.to(device)
     return noise_mask
 
@@ -37,9 +37,7 @@ def broadcast_cond(cond, batch, device):
     """broadcasts conditioning to the batch size"""
     copy = []
     for p in cond:
-        t = p[0]
-        if t.shape[0] < batch:
-            t = torch.cat([t] * batch)
+        t = comfy.utils.repeat_to_batch_size(p[0], batch)
         t = t.to(device)
         copy += [[t] + p[1:]]
     return copy
diff --git a/comfy/utils.py b/comfy/utils.py
index 693e2612d..47f4b9709 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -223,6 +223,13 @@ def unet_to_diffusers(unet_config):
 
     return diffusers_unet_map
 
+def repeat_to_batch_size(tensor, batch_size):
+    if tensor.shape[0] > batch_size:
+        return tensor[:batch_size]
+    elif tensor.shape[0] < batch_size:
+        return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
+    return tensor
+
 def convert_sd_to(state_dict, dtype):
     keys = list(state_dict.keys())
     for k in keys:
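
A quick illustration of the new helper (not part of the patch; a minimal sketch assuming a recent PyTorch and that comfy.utils is importable, with made-up tensor shapes). repeat_to_batch_size tiles the leading batch dimension and then truncates to exactly the requested size:

    import torch
    import comfy.utils

    # Batch of 2 repeated up to a requested batch size of 5:
    mask = torch.zeros((2, 4, 8, 8))
    out = comfy.utils.repeat_to_batch_size(mask, 5)
    print(out.shape)  # torch.Size([5, 4, 8, 8]): tiled to 6 rows, truncated to 5

    # A batch already larger than requested is simply sliced:
    big = torch.zeros((6, 4, 8, 8))
    print(comfy.utils.repeat_to_batch_size(big, 3).shape)  # torch.Size([3, 4, 8, 8])

Note that this also tightens the old broadcast_cond behavior: the previous torch.cat([t] * batch) produced t.shape[0] * batch rows whenever t held more than one conditioning entry, while the shared helper always returns exactly batch rows.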