Corrected joining images with alpha (for RGBA input), and checking scaling conditions

This commit is contained in:
MoonRide303 2023-09-24 00:12:55 +02:00
parent 585fb0475b
commit 214ca7197e

View File

@ -113,19 +113,21 @@ class PorterDuffImageComposite:
src_image = source[i]
dst_image = destination[i]
assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
src_alpha = source_alpha[i].unsqueeze(2)
dst_alpha = destination_alpha[i].unsqueeze(2)
if dst_alpha.shape != dst_image.shape:
upscale_input = dst_alpha[None,:,:,:].permute(0, 3, 1, 2)
if dst_alpha.shape[:2] != dst_image.shape[:2]:
upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2)
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
if src_image.shape != dst_image.shape:
upscale_input = src_image[None,:,:,:].permute(0, 3, 1, 2)
upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2)
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
if src_alpha.shape != dst_alpha.shape:
upscale_input = src_alpha[None,:,:,:].permute(0, 3, 1, 2)
upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2)
upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
@ -177,7 +179,7 @@ class JoinImageWithAlpha:
out_images = []
for i in range(batch_size):
out_images.append(torch.cat((image[i], alpha[i].unsqueeze(2)), dim=2))
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),)
return result