Use common function to reshape batch to.

This commit is contained in:
comfyanonymous 2023-09-02 03:42:49 -04:00
parent 36ea8784a8
commit 77a176f9e0
2 changed files with 10 additions and 5 deletions

View File

@ -1,6 +1,7 @@
import torch
import comfy.model_management
import comfy.samplers
import comfy.utils
import math
import numpy as np
@ -28,8 +29,7 @@ def prepare_mask(noise_mask, shape, device):
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
if noise_mask.shape[0] < shape[0]:
noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
@ -37,9 +37,7 @@ def broadcast_cond(cond, batch, device):
"""broadcasts conditioning to the batch size"""
copy = []
for p in cond:
t = p[0]
if t.shape[0] < batch:
t = torch.cat([t] * batch)
t = comfy.utils.repeat_to_batch_size(p[0], batch)
t = t.to(device)
copy += [[t] + p[1:]]
return copy

View File

@ -223,6 +223,13 @@ def unet_to_diffusers(unet_config):
return diffusers_unet_map
def repeat_to_batch_size(tensor, batch_size):
if tensor.shape[0] > batch_size:
return tensor[:batch_size]
elif tensor.shape[0] < batch_size:
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
return tensor
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys: