diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index f49cef95..9895d67d 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -379,6 +379,7 @@ class LTXVModel(torch.nn.Module): positional_embedding_max_pos=[20, 2048, 2048], dtype=None, device=None, operations=None, **kwargs): super().__init__() + self.generator = None self.dtype = dtype self.out_channels = in_channels self.inner_dim = num_attention_heads * attention_head_dim @@ -417,6 +418,7 @@ class LTXVModel(torch.nn.Module): def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs): patches_replace = transformer_options.get("patches_replace", {}) + image_noise_scale = transformer_options.get("image_noise_scale", 0.15) indices_grid = self.patchifier.get_grid( orig_num_frames=x.shape[2], @@ -435,6 +437,17 @@ class LTXVModel(torch.nn.Module): timestep = self.patchifier.patchify(ts) input_x = x.clone() x[:, :, 0] = guiding_latent[:, :, 0] + if image_noise_scale > 0: + if self.generator is None: + self.generator = torch.Generator(device=x.device).manual_seed(42) + elif self.generator.device != x.device: + self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state()) + + noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]] + guiding_noise = image_noise_scale * (input_ts ** 2) * torch.randn(size=noise_shape, device=x.device, generator=self.generator) + + x[:, :, 0] += guiding_noise[:, :, 0] + orig_shape = list(x.shape) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index e6a48fc4..bbb7a4dd 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,3 +1,4 @@ +import io import nodes import node_helpers import torch @@ -77,6 +78,7 @@ class ModelSamplingLTXV: return {"required": { "model": ("MODEL",), "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}), "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}), + "image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 100, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."}) }, "optional": {"latent": ("LATENT",), } } @@ -86,7 +88,7 @@ class ModelSamplingLTXV: CATEGORY = "advanced/model" - def patch(self, model, max_shift, base_shift, latent=None): + def patch(self, model, max_shift, base_shift, image_noise_scale, latent=None): m = model.clone() if latent is None: @@ -109,6 +111,8 @@ class ModelSamplingLTXV: model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling.set_parameters(shift=shift) m.add_object_patch("model_sampling", model_sampling) + m.model_options.setdefault("transformer_options", {})["image_noise_scale"] = image_noise_scale + return (m, )