Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-11 02:15:17 +00:00)
ltxv: add noise to guidance image to ensure generated motion. (#5937)
This commit is contained in:
Parent: 1e21f4c14e
Commit: 005d2d3a13
@ -379,6 +379,7 @@ class LTXVModel(torch.nn.Module):
|
|||||||
positional_embedding_max_pos=[20, 2048, 2048],
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
dtype=None, device=None, operations=None, **kwargs):
|
dtype=None, device=None, operations=None, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.generator = None
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.out_channels = in_channels
|
self.out_channels = in_channels
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
@ -417,6 +418,7 @@ class LTXVModel(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
image_noise_scale = transformer_options.get("image_noise_scale", 0.15)
|
||||||
|
|
||||||
indices_grid = self.patchifier.get_grid(
|
indices_grid = self.patchifier.get_grid(
|
||||||
orig_num_frames=x.shape[2],
|
orig_num_frames=x.shape[2],
|
||||||
@ -435,6 +437,17 @@ class LTXVModel(torch.nn.Module):
|
|||||||
timestep = self.patchifier.patchify(ts)
|
timestep = self.patchifier.patchify(ts)
|
||||||
input_x = x.clone()
|
input_x = x.clone()
|
||||||
x[:, :, 0] = guiding_latent[:, :, 0]
|
x[:, :, 0] = guiding_latent[:, :, 0]
|
||||||
|
if image_noise_scale > 0:
|
||||||
|
if self.generator is None:
|
||||||
|
self.generator = torch.Generator(device=x.device).manual_seed(42)
|
||||||
|
elif self.generator.device != x.device:
|
||||||
|
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
|
||||||
|
|
||||||
|
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
|
||||||
|
guiding_noise = image_noise_scale * (input_ts ** 2) * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
|
||||||
|
|
||||||
|
x[:, :, 0] += guiding_noise[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
orig_shape = list(x.shape)
|
orig_shape = list(x.shape)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import io
|
||||||
import nodes
|
import nodes
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import torch
|
import torch
|
||||||
@ -77,6 +78,7 @@ class ModelSamplingLTXV:
|
|||||||
return {"required": { "model": ("MODEL",),
|
return {"required": { "model": ("MODEL",),
|
||||||
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
|
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||||
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
|
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
|
||||||
|
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 100, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
|
||||||
},
|
},
|
||||||
"optional": {"latent": ("LATENT",), }
|
"optional": {"latent": ("LATENT",), }
|
||||||
}
|
}
|
||||||
@ -86,7 +88,7 @@ class ModelSamplingLTXV:
|
|||||||
|
|
||||||
CATEGORY = "advanced/model"
|
CATEGORY = "advanced/model"
|
||||||
|
|
||||||
def patch(self, model, max_shift, base_shift, latent=None):
|
def patch(self, model, max_shift, base_shift, image_noise_scale, latent=None):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
||||||
if latent is None:
|
if latent is None:
|
||||||
@ -109,6 +111,8 @@ class ModelSamplingLTXV:
|
|||||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||||
model_sampling.set_parameters(shift=shift)
|
model_sampling.set_parameters(shift=shift)
|
||||||
m.add_object_patch("model_sampling", model_sampling)
|
m.add_object_patch("model_sampling", model_sampling)
|
||||||
|
m.model_options.setdefault("transformer_options", {})["image_noise_scale"] = image_noise_scale
|
||||||
|
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user