import io
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.utils
import math
import numpy as np
import av
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords


class EmptyLTXVLatentVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/video/ltxv"

    def generate(self, width, height, length, batch_size=1):
        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        return ({"samples": latent},)


class LTXVImgToVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING",),
                             "negative": ("CONDITIONING",),
                             "vae": ("VAE",),
                             "image": ("IMAGE",),
                             "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
        pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        latent[:, :, :t.shape[2]] = t

        conditioning_latent_frames_mask = torch.ones(
            (batch_size, 1, latent.shape[2], 1, 1),
            dtype=torch.float32,
            device=latent.device,
        )
        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0

        return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask},)


def conditioning_get_any_value(conditioning, key, default=None):
    for t in conditioning:
        if key in t[1]:
            return t[1][key]
    return default


def get_noise_mask(latent):
    noise_mask = latent.get("noise_mask", None)
    latent_image = latent["samples"]
    if noise_mask is None:
        batch_size, _, latent_length, _, _ = latent_image.shape
        noise_mask = torch.ones(
            (batch_size, 1, latent_length, 1, 1),
            dtype=torch.float32,
            device=latent_image.device,
        )
    else:
        noise_mask = noise_mask.clone()
    return noise_mask


def get_keyframe_idxs(cond):
    keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
    if keyframe_idxs is None:
        return None, 0
    num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
    return keyframe_idxs, num_keyframes
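# A minimal sketch of the latent dict convention the helpers above operate on
# (illustrative only; shapes follow EmptyLTXVLatentVideo):
#
#   latent = {"samples": torch.zeros([1, 128, 13, 16, 24])}  # 97 frames at 512x768
#   mask = get_noise_mask(latent)  # -> ones of shape [1, 1, 13, 1, 1]
#
# A noise_mask value of 1.0 marks a latent frame as "generate", while values near
# 0.0 protect conditioning frames (see LTXVImgToVideo above and LTXVAddGuide below).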
class LTXVAddGuide:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING",),
                             "negative": ("CONDITIONING",),
                             "vae": ("VAE",),
                             "latent": ("LATENT",),
                             "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames. "
                                                            "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
                             "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
                                                   "tooltip": "Frame index to start the conditioning at. Must be divisible by 8. "
                                                              "If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. "
                                                              "Negative values are counted from the end of the video."}),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }
                }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def __init__(self):
        self._num_prefix_frames = 2
        self._patchifier = SymmetricPatchifier(1)

    def encode(self, vae, latent_width, latent_height, images, scale_factors):
        time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
        # Trim to k * time_scale_factor + 1 frames, then resize to the latent resolution before encoding.
        images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor,
                                            latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        return encode_pixels, t

    def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
        time_scale_factor, _, _ = scale_factors
        _, num_keyframes = get_keyframe_idxs(cond)
        latent_count = latent_length - num_keyframes
        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
        frame_idx = frame_idx // time_scale_factor * time_scale_factor  # frame index must be divisible by 8

        latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor

        return frame_idx, latent_idx

    def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
        keyframe_idxs, _ = get_keyframe_idxs(cond)
        _, latent_coords = self._patchifier.patchify(guiding_latent)
        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
        pixel_coords[:, 0] += frame_idx
        if keyframe_idxs is None:
            keyframe_idxs = pixel_coords
        else:
            keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})

    def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
        positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
        negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)

        mask = torch.full(
            (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = torch.cat([latent_image, guiding_latent], dim=2)
        noise_mask = torch.cat([noise_mask, mask], dim=2)
        return positive, negative, latent_image, noise_mask

    def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
        cond_length = guiding_latent.shape[2]
        assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."

        mask = torch.full(
            (noise_mask.shape[0], 1, cond_length, 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = latent_image.clone()
        noise_mask = noise_mask.clone()

        latent_image[:, :, latent_idx:latent_idx + cond_length] = guiding_latent
        noise_mask[:, :, latent_idx:latent_idx + cond_length] = mask

        return latent_image, noise_mask

    def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
        scale_factors = vae.downscale_index_formula
        latent_image = latent["samples"]
        noise_mask = get_noise_mask(latent)

        _, _, latent_length, latent_height, latent_width = latent_image.shape
        image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)

        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
        assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."

        if frame_idx == 0:
            latent_image, noise_mask = self.replace_latent_frames(latent_image, noise_mask, t, latent_idx, strength)
            return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)

        # For guides starting mid-sequence, the first latent frames are appended as
        # keyframes with explicit pixel coordinates; the remainder overwrites in place.
        num_prefix_frames = min(self._num_prefix_frames, t.shape[2])

        positive, negative, latent_image, noise_mask = self.append_keyframe(
            positive,
            negative,
            frame_idx,
            latent_image,
            noise_mask,
            t[:, :, :num_prefix_frames],
            strength,
            scale_factors,
        )

        latent_idx += num_prefix_frames

        t = t[:, :, num_prefix_frames:]
        if t.shape[2] == 0:
            return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)

        latent_image, noise_mask = self.replace_latent_frames(
            latent_image,
            noise_mask,
            t,
            latent_idx,
            strength,
        )

        return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
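# Worked example of the index math in LTXVAddGuide.get_latent_index, assuming the
# usual LTXV temporal scale factor of 8 (values are illustrative only):
#
#   frame_idx = -1, latent_length = 13, num_keyframes = 0
#   -> latent_count = 13
#   -> frame_idx = max((13 - 1) * 8 + 1 - 1, 0) = 96  # counted back from the end
#   -> frame_idx = 96 // 8 * 8 = 96                   # snapped down to a multiple of 8
#   -> latent_idx = (96 + 8 - 1) // 8 = 12            # the last latent frame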
class LTXVCropGuides:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING",),
                             "negative": ("CONDITIONING",),
                             "latent": ("LATENT",),
                             }
                }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    CATEGORY = "conditioning/video_models"
    FUNCTION = "crop"

    def __init__(self):
        self._patchifier = SymmetricPatchifier(1)

    def crop(self, positive, negative, latent):
        latent_image = latent["samples"].clone()
        noise_mask = get_noise_mask(latent)

        _, num_keyframes = get_keyframe_idxs(positive)
        if num_keyframes == 0:
            # Nothing was appended by LTXVAddGuide; slicing with :-0 would empty the tensor.
            return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)

        latent_image = latent_image[:, :, :-num_keyframes]
        noise_mask = noise_mask[:, :, :-num_keyframes]

        positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
        negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})

        return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)


class LTXVConditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING",),
                             "negative": ("CONDITIONING",),
                             "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "append"
    CATEGORY = "conditioning/video_models"

    def append(self, positive, negative, frame_rate):
        positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
        negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
        return (positive, negative)


class ModelSamplingLTXV:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step": 0.01}),
                             "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step": 0.01}),
                             },
                "optional": {"latent": ("LATENT",), }
                }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, max_shift, base_shift, latent=None):
        m = model.clone()

        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        # Interpolate the shift linearly between (1024 tokens, base_shift) and (4096 tokens, max_shift).
        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        shift = tokens * mm + b

        sampling_base = comfy.model_sampling.ModelSamplingFlux
        sampling_type = comfy.model_sampling.CONST

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(shift=shift)
        m.add_object_patch("model_sampling", model_sampling)

        return (m,)
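# Sanity check of the shift interpolation above with the default widget values
# (base_shift=0.95 at 1024 tokens, max_shift=2.05 at 4096 tokens); numbers are
# illustrative only:
#
#   mm = (2.05 - 0.95) / (4096 - 1024)  # ~3.58e-4 shift per token
#   b = 0.95 - mm * 1024                # ~0.583
#   1024 * mm + b == 0.95 and 4096 * mm + b == 2.05
#
# tokens is the number of latent positions, math.prod(samples.shape[2:]),
# e.g. 13 * 16 * 24 = 4992 for a 97-frame 512x768 video.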
class LTXVScheduler:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step": 0.01}),
                     "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step": 0.01}),
                     "stretch": ("BOOLEAN", {
                         "default": True,
                         "tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
                     }),
                     "terminal": ("FLOAT", {
                         "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
                         "tooltip": "The terminal value of the sigmas after stretching."
                     }),
                     },
                "optional": {"latent": ("LATENT",), }
                }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"

    FUNCTION = "get_sigmas"

    def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Same linear interpolation of the shift as in ModelSamplingLTXV above.
        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        sigma_shift = tokens * mm + b

        power = 1
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that their final value matches the given terminal value.
        if stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            stretched = 1.0 - (one_minus_z / scale_factor)
            sigmas[non_zero_mask] = stretched

        return (sigmas,)
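# A rough sketch of what get_sigmas returns, assuming steps=4, sigma_shift=1.0 and
# terminal=0.1 (illustrative values, not the defaults):
#
#   linspace:   [1.0, 0.75, 0.5, 0.25, 0.0]
#   shifted:    [1.0, 0.891, 0.731, 0.475, 0.0]  # e/(e + (1/s - 1)), e = exp(1.0)
#   stretched:  [1.0, 0.813, 0.539, 0.1, 0.0]    # last nonzero sigma == terminal
#
# The shift biases the schedule toward high noise levels; stretching rescales
# 1 - sigma over the nonzero entries so the tail lands exactly on `terminal`.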
def encode_single_frame(output_file, image_array: np.ndarray, crf):
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream(
            "h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]
        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
            format="yuv420p"
        )
        container.mux(stream.encode(av_frame))
        container.mux(stream.encode())  # flush the encoder
    finally:
        container.close()


def decode_single_frame(video_file):
    container = av.open(video_file)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        frame = next(container.decode(stream))
    finally:
        container.close()
    return frame.to_ndarray(format="rgb24")


def preprocess(image: torch.Tensor, crf=29):
    if crf == 0:
        return image

    # Round-trip the frame through an in-memory single-frame H.264 encode/decode at
    # the given CRF. Width and height are cropped to even values, as required by yuv420p.
    image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
    with io.BytesIO() as output_file:
        encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()
    with io.BytesIO(video_bytes) as video_file:
        image_array = decode_single_frame(video_file)
    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
    return tensor


class LTXVPreprocess:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "img_compression": (
                    "INT",
                    {
                        "default": 35,
                        "min": 0,
                        "max": 100,
                        "tooltip": "Amount of compression to apply on image.",
                    },
                ),
            }
        }

    FUNCTION = "preprocess"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("output_image",)
    CATEGORY = "image"

    def preprocess(self, image, img_compression):
        output_image = image
        if img_compression > 0:
            output_images = []
            for i in range(image.shape[0]):
                output_images.append(preprocess(image[i], img_compression))
            output_image = torch.stack(output_images)
        return (output_image,)


NODE_CLASS_MAPPINGS = {
    "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
    "LTXVImgToVideo": LTXVImgToVideo,
    "ModelSamplingLTXV": ModelSamplingLTXV,
    "LTXVConditioning": LTXVConditioning,
    "LTXVScheduler": LTXVScheduler,
    "LTXVAddGuide": LTXVAddGuide,
    "LTXVPreprocess": LTXVPreprocess,
    "LTXVCropGuides": LTXVCropGuides,
}
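# A rough sketch of how these nodes typically chain in an image-to-video graph
# (the wiring below is a hypothetical example, not enforced by this module):
#
#   LTXVPreprocess(image)                     -> image with H.264-style artifacts
#   LTXVImgToVideo(pos, neg, vae, image, ...) -> pos, neg, latent
#   LTXVConditioning(pos, neg, frame_rate)    -> pos, neg carrying the frame rate
#   LTXVAddGuide(pos, neg, vae, latent, ...)  -> optional extra keyframe guidance
#   ...sample with ModelSamplingLTXV patching the model and LTXVScheduler sigmas...
#   LTXVCropGuides(pos, neg, latent)          -> strip appended guide frames before
#                                                decoding with the VAE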