Allow joining a batch of images with a single mask

Previously, JoinImageWithAlpha required a batch of images to match a batch of masks. But for some use cases it's easier to provide a batch of images and a single mask.

This change automatically repeats the mask for all images in a batch.

In the same spirit, PorterDuffImageComposite will now allow a single mask for a batch of images (for both src and dst).

But also, PorterDuffImageComposite will apply the same logic to src and dst: if src contains one image, and dst is a batch it will repeat src to match dst (or the opposite).
This commit is contained in:
Denys Smirnov 2024-05-06 12:33:16 +03:00
parent 565eb6d176
commit ed0c0d1c26

View File

@ -107,10 +107,24 @@ class PorterDuffImageComposite:
CATEGORY = "mask/compositing" CATEGORY = "mask/compositing"
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) batch_size = min(len(source), len(destination))
if batch_size == 1:
if len(source) != 1:
batch_size = len(source)
elif len(destination) != 1:
batch_size = len(destination)
out_images = [] out_images = []
out_alphas = [] out_alphas = []
if batch_size != 1:
if len(source) == 1:
source = source.repeat(batch_size, 1, 1, 1)
if len(destination) == 1:
destination = destination.repeat(batch_size, 1, 1, 1)
if len(source_alpha) == 1:
source_alpha = source_alpha.repeat(batch_size, 1, 1)
if len(destination_alpha) == 1:
destination_alpha = destination_alpha.repeat(batch_size, 1, 1)
for i in range(batch_size): for i in range(batch_size):
src_image = source[i] src_image = source[i]
dst_image = destination[i] dst_image = destination[i]
@ -180,6 +194,8 @@ class JoinImageWithAlpha:
batch_size = min(len(image), len(alpha)) batch_size = min(len(image), len(alpha))
out_images = [] out_images = []
if len(alpha) == 1 and batch_size != 1:
alpha = alpha.repeat(batch_size, 1, 1, 1)
alpha = 1.0 - resize_mask(alpha, image.shape[1:]) alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size): for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))