diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index d266cd29..f33ed1be 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -2,10 +2,14 @@ import comfy.utils import comfy_extras.nodes_post_processing import torch -def reshape_latent_to(target_shape, latent): + +def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: - latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") - return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) + latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center") + if repeat_batch: + return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) + else: + return latent class LatentAdd: @@ -116,8 +120,7 @@ class LatentBatch: s1 = samples1["samples"] s2 = samples2["samples"] - if s1.shape[1:] != s2.shape[1:]: - s2 = comfy.utils.common_upscale(s2, s1.shape[-1], s1.shape[-2], "bilinear", "center") + s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False) s = torch.cat((s1, s2), dim=0) samples_out["samples"] = s samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])