diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 4370516b9..ff27b09a8 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -31,6 +31,7 @@ class EPS: return model_input - model_output * sigma def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) if max_denoise: noise = noise * torch.sqrt(1.0 + sigma ** 2.0) else: @@ -61,9 +62,11 @@ class CONST: return model_input - model_output * sigma def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) return sigma * noise + (1.0 - sigma) * latent_image def inverse_noise_scaling(self, sigma, latent): + sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1)) return latent / (1.0 - sigma) class ModelSamplingDiscrete(torch.nn.Module):