import io
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.utils
import math
import numpy as np
import av
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords

class EmptyLTXVLatentVideo:
    """Node that creates an all-zero LTXV video latent of the requested size."""

    @classmethod
    def INPUT_TYPES(s):
        # Width and height share the same bounds and step.
        spatial = {"min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}
        required = {
            "width": ("INT", {"default": 768, **spatial}),
            "height": ("INT", {"default": 512, **spatial}),
            "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        }
        return {"required": required}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/video/ltxv"

    def generate(self, width, height, length, batch_size=1):
        """Return a one-tuple holding ``{"samples": zeros}``.

        The tensor has 128 channels, ``((length - 1) // 8) + 1`` latent frames
        and a spatial size of ``height // 32`` by ``width // 32``; it is
        allocated on the intermediate device reported by comfy.
        """
        latent_frames = ((length - 1) // 8) + 1
        shape = [batch_size, 128, latent_frames, height // 32, width // 32]
        samples = torch.zeros(shape, device=comfy.model_management.intermediate_device())
        return ({"samples": samples}, )


class LTXVImgToVideo:
    """Node that seeds an LTXV video latent with a VAE-encoded start image."""

    @classmethod
    def INPUT_TYPES(s):
        # Width and height share the same bounds and step.
        spatial = {"min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}
        required = {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "vae": ("VAE",),
            "image": ("IMAGE",),
            "width": ("INT", {"default": 768, **spatial}),
            "height": ("INT", {"default": 512, **spatial}),
            "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
        }
        return {"required": required}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
        """Encode ``image`` and place it at the start of an otherwise empty latent.

        Returns ``(positive, negative, latent)``; the conditionings pass through
        untouched. The latent dict carries a ``noise_mask`` that is 0 over the
        encoded frames and 1 over the remaining frames.
        """
        # Resize with the last axis moved to position 1, then move it back.
        resized = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center")
        # Only the first three entries of the last axis are handed to the VAE.
        rgb = resized.movedim(1, -1)[:, :, :, :3]
        encoded = vae.encode(rgb)
        encoded_frames = encoded.shape[2]

        latent_shape = [batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32]
        samples = torch.zeros(latent_shape, device=comfy.model_management.intermediate_device())
        samples[:, :, :encoded_frames] = encoded

        frame_mask = torch.ones(
            (batch_size, 1, samples.shape[2], 1, 1),
            dtype=torch.float32,
            device=samples.device,
        )
        frame_mask[:, :, :encoded_frames] = 0

        return (positive, negative, {"samples": samples, "noise_mask": frame_mask}, )


def conditioning_get_any_value(conditioning, key, default=None):
    """Return the value of ``key`` from the first conditioning entry that has it.

    Each entry is indexed as ``entry[1]`` to reach its options mapping. When no
    entry contains ``key``, ``default`` is returned. A stored ``None`` counts as
    a hit and is returned as-is.
    """
    hits = (entry[1][key] for entry in conditioning if key in entry[1])
    return next(hits, default)


def get_noise_mask(latent):
    """Return a private noise mask for ``latent``.

    If the latent dict already has a ``"noise_mask"``, a clone of it is
    returned so callers can modify it freely. Otherwise an all-ones float32
    mask of shape ``(batch, 1, frames, 1, 1)`` is built from the shape of
    ``latent["samples"]`` on the same device.
    """
    samples = latent["samples"]
    existing = latent.get("noise_mask", None)
    if existing is not None:
        return existing.clone()

    batch_size, _, latent_length, _, _ = samples.shape
    return torch.ones(
        (batch_size, 1, latent_length, 1, 1),
        dtype=torch.float32,
        device=samples.device,
    )

def get_keyframe_idxs(cond):
    """Look up the keyframe coordinates stored in a conditioning.

    Returns ``(keyframe_idxs, num_keyframes)``. ``num_keyframes`` is the number
    of distinct values found at index 0 along dim 1 of the stored tensor, or
    ``(None, 0)`` when the conditioning holds no ``"keyframe_idxs"`` value.
    """
    idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
    if idxs is not None:
        frame_coords = idxs[:, 0]
        return idxs, torch.unique(frame_coords).shape[0]
    return None, 0

class LTXVAddGuide:
    """Node that conditions an LTXV latent video on a guide image or clip.

    The guide is VAE-encoded. Its first ``_num_prefix_frames`` latent frames
    are appended to the end of the latent as keyframes, with their pixel
    coordinates recorded in both conditionings under ``"keyframe_idxs"``. Any
    remaining guide latent frames overwrite the latent in place.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE",),
                             "latent": ("LATENT",),
                             "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames. " \
                                                 "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
                             "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
                                                   "tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
                                                   "If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
                                                   "Negative values are counted from the end of the video."}),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }
            }

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def __init__(self):
        # Number of leading guide latent frames that are appended as keyframes
        # instead of being written into the latent sequence.
        self._num_prefix_frames = 2
        self._patchifier = SymmetricPatchifier(1)

    def encode(self, vae, latent_width, latent_height, images, scale_factors):
        """Trim, resize and VAE-encode the guide frames.

        The clip is cut to ``k * time_scale_factor + 1`` frames, resized to the
        latent size multiplied by the spatial scale factors, and encoded.

        Returns:
            ``(encode_pixels, t)``: the pixels given to the VAE and the
            resulting latent.
        """
        time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
        images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
        # Only the first three entries of the last axis are handed to the VAE.
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        return encode_pixels, t

    def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
        """Resolve ``frame_idx`` to an aligned pixel frame and its latent index.

        Keyframes already recorded in ``cond`` are subtracted from
        ``latent_length`` first. A negative ``frame_idx`` counts back from the
        end of the video (clamped at 0). The result is rounded down to a
        multiple of the temporal scale factor.

        Returns:
            ``(frame_idx, latent_idx)``.
        """
        time_scale_factor, _, _ = scale_factors
        _, num_keyframes = get_keyframe_idxs(cond)
        latent_count = latent_length - num_keyframes
        if frame_idx < 0:
            # Pixel length of the video; use the VAE's temporal factor rather
            # than a hard-coded 8 so it agrees with the rounding below.
            pixel_length = (latent_count - 1) * time_scale_factor + 1
            frame_idx = max(pixel_length + frame_idx, 0)
        # The frame index must be divisible by the temporal scale factor.
        frame_idx = frame_idx // time_scale_factor * time_scale_factor

        latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor

        return frame_idx, latent_idx

    def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
        """Record the pixel coordinates of ``guiding_latent`` in ``cond``.

        The coordinates come from patchifying the guiding latent and converting
        to pixel space; index 0 along dim 1 is offset by ``frame_idx``. They are
        concatenated along dim 2 to any coordinates already stored under
        ``"keyframe_idxs"``. Returns the updated conditioning.
        """
        keyframe_idxs, _ = get_keyframe_idxs(cond)
        _, latent_coords = self._patchifier.patchify(guiding_latent)
        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
        pixel_coords[:, 0] += frame_idx
        if keyframe_idxs is None:
            keyframe_idxs = pixel_coords
        else:
            keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})

    def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
        """Append ``guiding_latent`` to the end of the latent as keyframes.

        Both conditionings get the keyframe coordinates added, and the noise
        mask is extended with ``1.0 - strength`` for the appended frames.

        Returns:
            ``(positive, negative, latent_image, noise_mask)``.
        """
        positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
        negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)

        mask = torch.full(
            (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = torch.cat([latent_image, guiding_latent], dim=2)
        noise_mask = torch.cat([noise_mask, mask], dim=2)
        return positive, negative, latent_image, noise_mask

    def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
        """Overwrite latent frames starting at ``latent_idx`` with the guide.

        Works on clones, so the inputs are left untouched. The matching noise
        mask frames are set to ``1.0 - strength``.

        Returns:
            ``(latent_image, noise_mask)``.

        Raises:
            AssertionError: if the guide does not fit in the latent sequence.
        """
        cond_length = guiding_latent.shape[2]
        assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."

        mask = torch.full(
            (noise_mask.shape[0], 1, cond_length, 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = latent_image.clone()
        noise_mask = noise_mask.clone()

        latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
        noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask

        return latent_image, noise_mask

    def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
        """Add the guide to the latent and conditionings.

        Returns:
            ``(positive, negative, latent)`` where ``latent`` is a dict with
            ``"samples"`` and ``"noise_mask"``.

        Raises:
            AssertionError: if the encoded guide does not fit in the latent
            sequence at the requested position.
        """
        scale_factors = vae.downscale_index_formula
        latent_image = latent["samples"]
        noise_mask = get_noise_mask(latent)

        _, _, latent_length, latent_height, latent_width = latent_image.shape
        # The resized pixels are not needed here, only the encoded latent.
        _, t = self.encode(vae, latent_width, latent_height, image, scale_factors)

        frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
        assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."

        num_prefix_frames = min(self._num_prefix_frames, t.shape[2])

        # The leading guide frames are appended as keyframes.
        positive, negative, latent_image, noise_mask = self.append_keyframe(
            positive,
            negative,
            frame_idx,
            latent_image,
            noise_mask,
            t[:, :, :num_prefix_frames],
            strength,
            scale_factors,
        )

        latent_idx += num_prefix_frames

        # Whatever is left of the guide replaces latent frames in place.
        t = t[:, :, num_prefix_frames:]
        if t.shape[2] == 0:
            return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)

        latent_image, noise_mask = self.replace_latent_frames(
            latent_image,
            noise_mask,
            t,
            latent_idx,
            strength,
        )

        return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)


class LTXVCropGuides:
    """Node that strips appended keyframe guides from a latent and its conditionings."""

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "latent": ("LATENT",),
        }
        return {"required": required}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "crop"

    def __init__(self):
        self._patchifier = SymmetricPatchifier(1)

    def crop(self, positive, negative, latent):
        """Drop the trailing keyframe frames and clear ``"keyframe_idxs"``.

        The keyframe count is read from ``positive``. When it is zero the
        conditionings are returned unchanged together with a cloned latent and
        its noise mask. Otherwise that many frames are cut from the end of both
        the samples and the mask, and ``"keyframe_idxs"`` is set to ``None`` in
        both conditionings.
        """
        samples = latent["samples"].clone()
        mask = get_noise_mask(latent)

        _, keyframe_count = get_keyframe_idxs(positive)
        if keyframe_count != 0:
            samples = samples[:, :, :-keyframe_count]
            mask = mask[:, :, :-keyframe_count]

            cleared = {"keyframe_idxs": None}
            positive = node_helpers.conditioning_set_values(positive, cleared)
            negative = node_helpers.conditioning_set_values(negative, cleared)

        return (positive, negative, {"samples": samples, "noise_mask": mask},)


class LTXVConditioning:
    """Node that stores a frame rate in the positive and negative conditionings."""

    @classmethod
    def INPUT_TYPES(s):
        frame_rate_opts = {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}
        required = {
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "frame_rate": ("FLOAT", frame_rate_opts),
        }
        return {"required": required}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "append"

    CATEGORY = "conditioning/video_models"

    def append(self, positive, negative, frame_rate):
        """Return both conditionings with ``"frame_rate"`` set to ``frame_rate``."""
        updated = tuple(
            node_helpers.conditioning_set_values(cond, {"frame_rate": frame_rate})
            for cond in (positive, negative)
        )
        return updated


class ModelSamplingLTXV:
    """Node that patches a model with a token-count dependent sampling shift."""

    @classmethod
    def INPUT_TYPES(s):
        # Both shift inputs share the same bounds and step.
        shift_bounds = {"min": 0.0, "max": 100.0, "step": 0.01}
        return {
            "required": {
                "model": ("MODEL",),
                "max_shift": ("FLOAT", {"default": 2.05, **shift_bounds}),
                "base_shift": ("FLOAT", {"default": 0.95, **shift_bounds}),
            },
            "optional": {"latent": ("LATENT",), },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, max_shift, base_shift, latent=None):
        """Return a clone of ``model`` whose ``model_sampling`` object is replaced.

        The shift is linear in the token count: ``base_shift`` at 1024 tokens
        and ``max_shift`` at 4096 tokens. The token count is the product of the
        latent's dimensions from index 2 onward, or 4096 when no latent is
        given.
        """
        patched = model.clone()

        tokens = 4096 if latent is None else math.prod(latent["samples"].shape[2:])

        # Line through (1024, base_shift) and (4096, max_shift).
        low_tokens = 1024
        high_tokens = 4096
        slope = (max_shift - base_shift) / (high_tokens - low_tokens)
        intercept = base_shift - slope * low_tokens
        shift = (tokens) * slope + intercept

        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingFlux, comfy.model_sampling.CONST):
            pass

        sampler = ModelSamplingAdvanced(model.model.model_config)
        sampler.set_parameters(shift=shift)
        patched.add_object_patch("model_sampling", sampler)

        return (patched, )


class LTXVScheduler:
    """Node that builds a shifted (and optionally stretched) sigma schedule."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
                     "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
                     "stretch": ("BOOLEAN", {
                        "default": True,
                        "tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
                    }),
                     "terminal": (
                        "FLOAT",
                        {
                            "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
                            "tooltip": "The terminal value of the sigmas after stretching."
                        },
                    ),
                    },
                "optional": {"latent": ("LATENT",), }
               }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"

    FUNCTION = "get_sigmas"

    def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
        """Compute ``steps + 1`` sigmas running from 1 down to 0.

        A linear ramp is warped by a shift that depends on the token count
        (``base_shift`` at 1024 tokens, ``max_shift`` at 4096 tokens; 4096
        tokens are assumed when no latent is given). With ``stretch`` enabled
        the non-zero sigmas are rescaled so the last of them equals
        ``terminal``.

        Returns:
            A one-tuple holding the sigma tensor.
        """
        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Line through (1024, base_shift) and (4096, max_shift).
        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        sigma_shift = (tokens) * mm + b

        power = 1
        # The zero entry is passed through untouched; 1 / 0 would give inf.
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that its final value matches the given terminal value.
        if stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            # With a single non-zero sigma (steps == 1), or when the shifted
            # sigmas saturate at 1.0, there is no range to stretch and the
            # rescale below would compute 0 / 0 and fill the schedule with
            # NaN. Leave the sigmas as they are in that case.
            if non_zero_sigmas.numel() > 1 and one_minus_z[-1] > 0:
                scale_factor = one_minus_z[-1] / (1.0 - terminal)
                stretched = 1.0 - (one_minus_z / scale_factor)
                sigmas[non_zero_mask] = stretched

        return (sigmas,)

def encode_single_frame(output_file, image_array: np.ndarray, crf):
    """Write ``image_array`` as a one-frame h264 video in an mp4 container.

    Args:
        output_file: destination handed to ``av.open`` for writing.
        image_array: image interpreted as ``rgb24``; its first two dimensions
            are used as height and width of the stream.
        crf: quality value passed to the encoder as the ``crf`` option.
    """
    encoder_options = {"crf": str(crf), "preset": "veryfast"}
    # The container is closed on exit, including when encoding fails.
    with av.open(output_file, "w", format="mp4") as container:
        stream = container.add_stream("h264", rate=1, options=encoder_options)
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]

        rgb_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24")
        yuv_frame = rgb_frame.reformat(format="yuv420p")

        container.mux(stream.encode(yuv_frame))
        # Encoding with no frame flushes what the encoder still holds.
        container.mux(stream.encode())


def decode_single_frame(video_file):
    """Decode the first frame of the first video stream as an ``rgb24`` array.

    ``video_file`` is whatever ``av.open`` accepts for reading. The container
    is closed before the frame is converted to an ndarray.
    """
    with av.open(video_file) as container:
        video_stream = next(s for s in container.streams if s.type == "video")
        first_frame = next(container.decode(video_stream))
    return first_frame.to_ndarray(format="rgb24")


def preprocess(image: torch.Tensor, crf=29):
    """Round-trip one image through single-frame h264 compression.

    Args:
        image: a single image tensor indexed as ``[row, column, ...]`` whose
            values are scaled by 255 before encoding (so expected in [0, 1]).
        crf: compression level forwarded to ``encode_single_frame``; ``0``
            disables compression.

    Returns:
        ``image`` itself when ``crf == 0``. Otherwise a new tensor with the
        dtype and device of ``image``, divided back by 255. Its first two
        dimensions are cropped down to even sizes.
    """
    if crf == 0:
        return image

    # Crop to even height/width -- presumably required by the yuv420p encode
    # path in encode_single_frame.
    even_height = (image.shape[0] // 2) * 2
    even_width = (image.shape[1] // 2) * 2
    cropped = image[:even_height, :even_width]

    # Clamp before the uint8 cast: values outside [0, 1] would otherwise wrap
    # around (or be undefined) instead of saturating.
    image_array = (cropped.clamp(0.0, 1.0) * 255.0).byte().cpu().numpy()

    with io.BytesIO() as output_file:
        encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()
    with io.BytesIO(video_bytes) as video_file:
        image_array = decode_single_frame(video_file)
    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
    return tensor


class LTXVPreprocess:
    """Node that applies video-style compression to each image of a batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "img_compression": (
                    "INT",
                    {
                        "default": 35,
                        "min": 0,
                        "max": 100,
                        "tooltip": "Amount of compression to apply on image.",
                    },
                ),
            }
        }

    FUNCTION = "preprocess"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("output_image",)
    CATEGORY = "image"

    def preprocess(self, image, img_compression):
        """Compress every image of the batch and return them stacked.

        Args:
            image: batch of images, iterated along its first dimension.
            img_compression: value forwarded as ``crf`` to the module-level
                ``preprocess`` function; ``0`` means no compression.

        Returns:
            A one-tuple with the processed batch. With ``img_compression`` of
            0 the input batch is returned unchanged.
        """
        # 0 is an allowed input value; previously it fell through to the
        # return statement with ``output_images`` never assigned.
        if img_compression <= 0:
            return (image,)

        # Inside a method, ``preprocess`` resolves to the module-level function.
        output_images = [
            preprocess(image[i], img_compression) for i in range(image.shape[0])
        ]
        return (torch.stack(output_images),)


# Maps each node identifier string to the node class defined in this module.
# NOTE(review): presumably read by the ComfyUI node loader to register these
# nodes -- confirm against the loader before renaming or reordering keys.
NODE_CLASS_MAPPINGS = {
    "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
    "LTXVImgToVideo": LTXVImgToVideo,
    "ModelSamplingLTXV": ModelSamplingLTXV,
    "LTXVConditioning": LTXVConditioning,
    "LTXVScheduler": LTXVScheduler,
    "LTXVAddGuide": LTXVAddGuide,
    "LTXVPreprocess": LTXVPreprocess,
    "LTXVCropGuides": LTXVCropGuides,
}