diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index 9895d67d..eeeeaea0 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -416,9 +416,8 @@ class LTXVModel(torch.nn.Module):
 
         self.patchifier = SymmetricPatchifier(1)
 
-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
         patches_replace = transformer_options.get("patches_replace", {})
-        image_noise_scale = transformer_options.get("image_noise_scale", 0.15)
 
         indices_grid = self.patchifier.get_grid(
             orig_num_frames=x.shape[2],
@@ -433,20 +432,21 @@ class LTXVModel(torch.nn.Module):
             ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
             input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
             ts *= input_ts
-            ts[:, :, 0] = 0.0
+            ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
             timestep = self.patchifier.patchify(ts)
             input_x = x.clone()
             x[:, :, 0] = guiding_latent[:, :, 0]
 
-            if image_noise_scale > 0:
+            if guiding_latent_noise_scale > 0:
                 if self.generator is None:
                     self.generator = torch.Generator(device=x.device).manual_seed(42)
                 elif self.generator.device != x.device:
                     self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
 
                 noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
-                guiding_noise = image_noise_scale * (input_ts ** 2) * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
+                scale = guiding_latent_noise_scale * (input_ts ** 2)
+                guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
 
-                x[:, :, 0] += guiding_noise[:, :, 0]
+                x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
 
         orig_shape = list(x.shape)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 8f37af66..f90ceebb 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -804,5 +804,9 @@ class LTXV(BaseModel):
         if guiding_latent is not None:
             out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
 
+        guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
+        if guiding_latent_noise_scale is not None:
+            out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
+
         out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
         return out
diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index bbb7a4dd..dec91241 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -1,4 +1,3 @@
-import io
 import nodes
 import node_helpers
 import torch
@@ -33,7 +32,9 @@ class LTXVImgToVideo:
                              "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                              "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                              "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
+                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             "image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
+                             }}
 
     RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
     RETURN_NAMES = ("positive", "negative", "latent")
@@ -41,12 +42,12 @@ class LTXVImgToVideo:
     CATEGORY = "conditioning/video_models"
     FUNCTION = "generate"
 
-    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
+    def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale):
         pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
-        positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
-        negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})
+        positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
+        negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
 
         latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
         latent[:, :, :t.shape[2]] = t
@@ -78,7 +79,6 @@ class ModelSamplingLTXV:
         return {"required": {"model": ("MODEL",),
                              "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
                              "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
-                             "image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 100, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
                              },
                "optional": {"latent": ("LATENT",), }
                }
@@ -88,7 +88,7 @@ class ModelSamplingLTXV:
 
     CATEGORY = "advanced/model"
 
-    def patch(self, model, max_shift, base_shift, image_noise_scale, latent=None):
+    def patch(self, model, max_shift, base_shift, latent=None):
         m = model.clone()
 
         if latent is None:
@@ -111,7 +111,6 @@ class ModelSamplingLTXV:
 
         model_sampling = ModelSamplingAdvanced(model.model.model_config)
         model_sampling.set_parameters(shift=shift)
         m.add_object_patch("model_sampling", model_sampling)
-        m.model_options.setdefault("transformer_options", {})["image_noise_scale"] = image_noise_scale
 
         return (m, )
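The core behavioral change above is in `LTXVModel.forward`: frame 0 previously had noise added on top of the guiding latent (with its per-frame timestep pinned to 0), whereas it is now linearly interpolated toward fresh noise with strength `scale = guiding_latent_noise_scale * t**2`, and the frame-0 timestep is set to that same `scale`. A minimal standalone sketch of the new rule follows; it is not ComfyUI code, and `noise_guiding_frame` is a hypothetical helper named only for illustration:

```python
import torch

def noise_guiding_frame(x, guiding_latent, input_ts, noise_scale, generator=None):
    """Replace frame 0 of x with a noised copy of the guiding latent.

    x, guiding_latent: [B, C, T, H, W] latents; input_ts: the timestep
    reshaped to [B, 1, 1, 1, 1], matching input_ts in the diff.
    """
    x = x.clone()
    x[:, :, 0] = guiding_latent[:, :, 0]
    if noise_scale > 0:
        # Noise strength grows quadratically with the timestep: the guiding
        # frame is noisier early in sampling and nearly clean near the end.
        scale = noise_scale * (input_ts ** 2)
        noise = torch.randn([x.shape[0], x.shape[1], 1, x.shape[3], x.shape[4]],
                            device=x.device, generator=generator)
        # Old rule: x[:, :, 0] += scale * noise (purely additive).
        # New rule: lerp toward the noise, keeping magnitude bounded.
        x[:, :, 0] = scale[:, :, 0] * noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
    return x

# Example: 128-channel LTXV-style latent, 9 frames.
x = torch.randn(1, 128, 9, 16, 24)
guiding = torch.randn(1, 128, 9, 16, 24)
t = torch.full((1, 1, 1, 1, 1), 0.9)  # high timestep -> strong noising
out = noise_guiding_frame(x, guiding, t, noise_scale=0.15)
```

Since the new rule blends rather than adds, `scale` now acts as an interpolation weight, which is presumably why the parameter moved to `LTXVImgToVideo` with its maximum tightened from 100 to 1.0.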