migrating to updated nodes

Ubuntu 2024-12-20 14:47:30 +00:00 committed by Ubuntu
parent 20833211f0
commit 7c7d9f54da
8 changed files with 1216 additions and 177 deletions

View File

@ -0,0 +1,99 @@
import io
import av
import comfy.latent_formats
import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import comfy.sd
import comfy.supported_models_base
import comfy.utils
import numpy as np
import torch
from ltx_video.models.autoencoders.vae_encode import get_vae_size_scale_factor
def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream(
"h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
)
stream.height = image_array.shape[0]
stream.width = image_array.shape[1]
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
format="yuv420p"
)
container.mux(stream.encode(av_frame))
container.mux(stream.encode())
finally:
container.close()
def decode_single_frame(video_file):
container = av.open(video_file)
try:
stream = next(s for s in container.streams if s.type == "video")
frame = next(container.decode(stream))
finally:
container.close()
return frame.to_ndarray(format="rgb24")
def videofy(image: torch.Tensor, crf=29):
if crf == 0:
return image
image_array = (image * 255.0).byte().cpu().numpy()
with io.BytesIO() as output_file:
encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue()
with io.BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor
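# Illustrative usage sketch (not part of the node code): videofy() bakes H.264 compression
# artifacts into a single frame by round-tripping it through an in-memory MP4, assuming a
# (height, width, 3) float tensor in [0, 1]:
#   frame = torch.rand(256, 256, 3)
#   degraded = videofy(frame, crf=29)   # same shape, dtype and device, with codec artifacts
#   untouched = videofy(frame, crf=0)   # crf == 0 skips the round trip entirely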
def pad_tensor(tensor, target_len):
dim = 2
repeat_factor = target_len - tensor.shape[dim]  # number of frames to pad by repeating the last element
last_element = tensor.select(dim, -1).unsqueeze(dim)
padding = last_element.repeat(1, 1, repeat_factor, 1, 1)
return torch.cat([tensor, padding], dim=dim)
def encode_media_conditioning(
init_media, vae, width, height, frames_number, image_compression, initial_latent
):
pixels = comfy.utils.common_upscale(
init_media.movedim(-1, 1), width, height, "bilinear", ""
).movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
if image_compression > 0:
for i in range(encode_pixels.shape[0]):
image = videofy(encode_pixels[i], image_compression)
encode_pixels[i] = image
encoded_latents = vae.encode(encode_pixels).float()
video_scale_factor, _, _ = get_vae_size_scale_factor(vae.first_stage_model)
video_scale_factor = video_scale_factor if frames_number > 1 else 1
target_len = (frames_number // video_scale_factor) + 1
encoded_latents = encoded_latents[:, :, :target_len]
if initial_latent is None:
initial_latent = encoded_latents
else:
if encoded_latents.shape[2] > initial_latent.shape[2]:
initial_latent = pad_tensor(initial_latent, encoded_latents.shape[2])
initial_latent[:, :, : encoded_latents.shape[2], ...] = encoded_latents
init_image_frame_number = init_media.shape[0]
if init_image_frame_number == 1:
result = pad_tensor(initial_latent, target_len)
elif init_image_frame_number % 8 != 1:
result = pad_tensor(initial_latent, target_len)
else:
result = initial_latent
return result
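# Worked example (illustrative): for a 97-frame request with the LTX temporal scale factor of 8,
# target_len = (97 // 8) + 1 = 13 latent frames; a single conditioning image (or a frame count
# that is not congruent to 1 mod 8) is padded to that length by repeating its last latent frame
# via pad_tensor().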

View File

@ -0,0 +1,219 @@
import math
from contextlib import nullcontext
import comfy.latent_formats
import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import comfy.model_sampling
import comfy.sd
import comfy.supported_models_base
import comfy.utils
import torch
import torch.nn as nn
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
class LTXVModelConfig:
def __init__(self, latent_channels, dtype):
self.unet_config = {}
self.unet_extra_config = {}
self.latent_format = comfy.latent_formats.LatentFormat()
self.latent_format.latent_channels = latent_channels
self.manual_cast_dtype = dtype
self.sampling_settings = {"multiplier": 1.0}
self.memory_usage_factor = 2.7
# denoiser is handled by extension
self.unet_config["disable_unet_model_creation"] = True
class LTXVSampling(torch.nn.Module, comfy.model_sampling.CONST):
def __init__(self, condition_mask, guiding_latent=None):
super().__init__()
self.condition_mask = condition_mask
self.guiding_latent = guiding_latent
self.set_parameters(shift=1.0, multiplier=1)
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
self.shift = shift
self.multiplier = multiplier
ts = self.sigma((torch.arange(0, timesteps + 1, 1) / timesteps) * multiplier)
self.register_buffer("sigmas", ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * self.multiplier
def sigma(self, timestep):
return timestep
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
def calculate_input(self, sigma, noise):
if self.guiding_latent is not None:
noise = (
noise * (1 - self.condition_mask)
+ self.guiding_latent * self.condition_mask
)
return noise
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
self.condition_mask = self.condition_mask.to(latent_image.device)
scaled = latent_image * (1 - sigma) + noise * sigma
result = latent_image * self.condition_mask + scaled * (1 - self.condition_mask)
return result
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
result = model_input - model_output * sigma
# In order for d * dT to be zero in the Euler step, we need to set the result equal to the input in the first latent frame.
if self.guiding_latent is not None:
result = (
result * (1 - self.condition_mask)
+ self.guiding_latent * self.condition_mask
)
else:
result = (
result * (1 - self.condition_mask) + model_input * self.condition_mask
)
return result
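# Worked equation (illustrative): with model_output = v (the predicted velocity) and
# model_input = x_t, the flow-matching denoised estimate is x0 = x_t - sigma * v, e.g. at
# sigma = 0.8, x0 = x_t - 0.8 * v. Positions flagged by condition_mask are then pinned back to
# the guiding latent (or to the model input when no guiding latent is set), which keeps the
# conditioned latent frame fixed throughout sampling.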
class LTXVModel(comfy.model_base.BaseModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_sampling = LTXVSampling(torch.zeros([1]))
class LTXVTransformer3D(nn.Module):
def __init__(
self,
transformer: Transformer3DModel,
patchifier: SymmetricPatchifier,
conditioning_mask,
latent_frame_rate,
vae_scale_factor,
):
super().__init__()
self.dtype = transformer.dtype
self.transformer = transformer
self.patchifier = patchifier
self.conditioning_mask = conditioning_mask
self.latent_frame_rate = latent_frame_rate
self.vae_scale_factor = vae_scale_factor
def indices_grid(
self,
latent_shape,
device,
):
use_rope = self.transformer.use_rope
scale_grid = (
(1 / self.latent_frame_rate, self.vae_scale_factor, self.vae_scale_factor)
if use_rope
else None
)
indices_grid = self.patchifier.get_grid(
orig_num_frames=latent_shape[2],
orig_height=latent_shape[3],
orig_width=latent_shape[4],
batch_size=latent_shape[0],
scale_grid=scale_grid,
device=device,
)
return indices_grid
def wrapped_transformer(
self,
latent,
timesteps,
context,
indices_grid,
skip_layer_mask=None,
skip_layer_strategy=None,
img_hw=None,
aspect_ratio=None,
mixed_precision=True,
**kwargs,
):
# infer mask from context padding, assumes padding vectors are all zero.
latent = latent.to(self.transformer.dtype)
latent_patchified = self.patchifier.patchify(latent)
context_mask = (context != 0).any(dim=2).to(self.transformer.dtype)
if mixed_precision:
context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
else:
context_manager = nullcontext()
with context_manager:
noise_pred = self.transformer(
latent_patchified.to(self.transformer.dtype).to(
self.transformer.device
),
indices_grid.to(self.transformer.device),
encoder_hidden_states=context.to(self.transformer.device),
encoder_attention_mask=context_mask.to(self.transformer.device).to(
torch.int64
),
timestep=timesteps,
skip_layer_mask=skip_layer_mask,
skip_layer_strategy=skip_layer_strategy,
return_dict=False,
)[0]
result = self.patchifier.unpatchify(
latents=noise_pred,
output_height=latent.shape[3],
output_width=latent.shape[4],
output_num_frames=latent.shape[2],
out_channels=latent.shape[1] // math.prod(self.patchifier.patch_size),
)
return result
def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, **kwargs):
transformer_options = kwargs.get("transformer_options", {})
ptb_index = transformer_options.get("ptb_index", None)
mixed_precision = transformer_options.get("mixed_precision", False)
cond_or_uncond = transformer_options.get("cond_or_uncond", [])
skip_block_list = transformer_options.get("skip_block_list", [])
skip_layer_strategy = transformer_options.get("skip_layer_strategy", None)
mask = self.patchifier.patchify(self.conditioning_mask).squeeze(-1).to(x.device)
ndim_mask = mask.ndimension()
expanded_timesteps = timesteps.view(timesteps.size(0), *([1] * (ndim_mask - 1)))
timesteps_masked = expanded_timesteps * (1 - mask)
skip_layer_mask = None
if ptb_index is not None and ptb_index in cond_or_uncond:
skip_layer_mask = self.transformer.create_skip_layer_mask(
skip_block_list,
1,
len(cond_or_uncond),
len(cond_or_uncond) - 1 - cond_or_uncond.index(ptb_index),
)
result = self.wrapped_transformer(
x,
timesteps_masked,
context,
indices_grid=self.indices_grid(x.shape, x.device),
mixed_precision=mixed_precision,
skip_layer_mask=skip_layer_mask,
skip_layer_strategy=skip_layer_strategy,
)
return result

View File

@ -1,13 +1,47 @@
import json
import math
from pathlib import Path
import comfy
import comfy.model_management
import comfy.model_patcher
from comfy_extras.nodes_custom_sampler import Noise_RandomNoise
import latent_preview
from .modules.video_model import inject_model
from .modules.model_2 import LTXVModel, LTXVModelConfig, LTXVTransformer3D
from ltx_video.models.autoencoders.causal_video_autoencoder import (
CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
import folder_paths
import node_helpers
import torch
import comfy
import safetensors.torch
from safetensors import safe_open
from .vae import MD_VideoVAE
from .stg import STGGuider
from ltx_video.models.autoencoders.vae_encode import get_vae_size_scale_factor
from .modules.img2vid import encode_media_conditioning
from .modules.model_2 import LTXVSampling
def get_normal_shift(
n_tokens: int,
min_tokens: int = 1024,
max_tokens: int = 4096,
min_shift: float = 0.95,
max_shift: float = 2.05,
) -> float:
m = (max_shift - min_shift) / (max_tokens - min_tokens)
b = min_shift - m * min_tokens
return m * n_tokens + b
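# Worked example (illustrative): get_normal_shift() interpolates linearly between
# (1024 tokens, shift 0.95) and (4096 tokens, shift 2.05), so the midpoint of 2560 tokens maps
# to a shift of exactly 1.5; token counts outside that range extrapolate along the same line.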
class MD_LoadVideoModel:
"""
@ -33,19 +67,167 @@ class MD_LoadVideoModel:
FUNCTION = "load_model"
CATEGORY = "MemeDeck"
def load_model(self, chkpt_name, clip_name):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", chkpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
model = out[0]
vae = out[2]
# def load_model(self, chkpt_name, clip_name):
# ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", chkpt_name)
# out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
# model = out[0]
# vae = out[2]
# clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
# clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=8)
# # modify model
# model.model.diffusion_model = inject_model(model.model.diffusion_model)
# return (model, clip, vae, )
def load_model(self, chkpt_name, clip_name):
dtype = torch.float32
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
ckpt_path = Path(folder_paths.get_full_path("checkpoints", chkpt_name))
vae_config = None
unet_config = None
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
if metadata is not None:
config_metadata = metadata.get("config", None)
if config_metadata is not None:
config_metadata = json.loads(config_metadata)
vae_config = config_metadata.get("vae", None)
unet_config = config_metadata.get("transformer", None)
weights = safetensors.torch.load_file(ckpt_path, device="cpu")
vae = self._load_vae(weights, vae_config)
num_latent_channels = vae.first_stage_model.config.latent_channels
model = self._load_unet(
load_device,
offload_device,
weights,
num_latent_channels,
dtype=dtype,
config=unet_config,
)
# Load clip
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=8)
# modify model
model.model.diffusion_model = inject_model(model.model.diffusion_model)
# model.model.diffusion_model = inject_model(model.model.diffusion_model)
return (model, clip, vae, )
# -----------------------------
def _load_vae(self, weights, config=None):
if config is None:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"blocks": [
["res_x", 4],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x", 3],
["res_x", 4],
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
}
vae_prefix = "vae."
vae = MD_VideoVAE.from_config_and_state_dict(
vae_class=CausalVideoAutoencoder,
config=config,
state_dict={
key.removeprefix(vae_prefix): value
for key, value in weights.items()
if key.startswith(vae_prefix)
},
)
return vae
def _load_unet(
self,
load_device,
offload_device,
weights,
num_latent_channels,
dtype,
config=None,
):
if config is None:
config = {
"_class_name": "Transformer3DModel",
"_diffusers_version": "0.25.1",
"_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256",
"activation_fn": "gelu-approximate",
"attention_bias": True,
"attention_head_dim": 64,
"attention_type": "default",
"caption_channels": 4096,
"cross_attention_dim": 2048,
"double_self_attention": False,
"dropout": 0.0,
"in_channels": 128,
"norm_elementwise_affine": False,
"norm_eps": 1e-06,
"norm_num_groups": 32,
"num_attention_heads": 32,
"num_embeds_ada_norm": 1000,
"num_layers": 28,
"num_vector_embeds": None,
"only_cross_attention": False,
"out_channels": 128,
"project_to_2d_pos": True,
"upcast_attention": False,
"use_linear_projection": False,
"qk_norm": "rms_norm",
"standardization_norm": "rms_norm",
"positional_embedding_type": "rope",
"positional_embedding_theta": 10000.0,
"positional_embedding_max_pos": [20, 2048, 2048],
"timestep_scale_multiplier": 1000,
}
transformer = Transformer3DModel.from_config(config)
unet_prefix = "model.diffusion_model."
transformer.load_state_dict(
{
key.removeprefix(unet_prefix): value
for key, value in weights.items()
if key.startswith(unet_prefix)
}
)
transformer.to(dtype).to(load_device).eval()
patchifier = SymmetricPatchifier(1)
diffusion_model = LTXVTransformer3D(transformer, patchifier, None, None, None)
model = LTXVModel(
LTXVModelConfig(num_latent_channels, dtype=dtype),
model_type=comfy.model_base.ModelType.FLOW,
device=comfy.model_management.get_torch_device(),
)
model.diffusion_model = diffusion_model
patcher = comfy.model_patcher.ModelPatcher(model, load_device, offload_device)
return patcher
class LatentGuide(torch.nn.Module):
def __init__(self, latent: torch.Tensor, index) -> None:
@ -82,14 +264,9 @@ class MD_ImgToVideo:
"default": 24,
"description": "The fps of the video."
}),
# LATENT GUIDE INPUTS
"add_latent_guide_index": ("INT", {
"default": 0,
"description": "The index of the latent to add to the guide."
}),
"add_latent_guide_insert": ("BOOLEAN", {
"default": False,
"description": "Whether to add the latent to the guide."
"crf": ("FLOAT", {
"default": 28.0,
"description": "The cfg of the image."
}),
# SCHEDULER INPUTS
"steps": ("INT", {
@ -135,35 +312,40 @@ class MD_ImgToVideo:
},
}
RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "SIGMAS", "LATENT", "STRING")
RETURN_NAMES = ("model", "positive", "negative", "sigmas", "latent", "img2vid_metadata")
RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "SIGMAS", "LATENT", "STRING", "GUIDER")
RETURN_NAMES = ("model", "positive", "negative", "sigmas", "latent", "img2vid_metadata", "guider")
FUNCTION = "img_to_video"
CATEGORY = "MemeDeck"
def img_to_video(self, model, positive, negative, vae, image, width, height, length, fps, add_latent_guide_index, add_latent_guide_insert, steps, max_shift, base_shift, stretch, terminal, attention_override, attention_adjustment_scale, attention_adjustment_rescale, attention_adjustment_cfg):
batch_size = 1
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
latent[:, :, :t.shape[2]] = t
latent_samples = {"samples": latent}
def img_to_video(self, model, positive, negative, vae, image, width, height, length, fps, crf, steps, max_shift, base_shift, stretch, terminal, attention_override, attention_adjustment_scale, attention_adjustment_rescale, attention_adjustment_cfg):
batch_size = 1
# pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
# # get pixels from image
# # encode_pixels = image[:, :, :, :3]
# encode_pixels = pixels[:, :, :, :3]
# t = vae.encode(encode_pixels)
# positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
# negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})
# # latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
# # latent[:, :, :t.shape[2]] = t
# # latent_samples = {"samples": latent}
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": fps})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": fps})
# positive = node_helpers.conditioning_set_values(positive, {"frame_rate": fps})
# negative = node_helpers.conditioning_set_values(negative, {"frame_rate": fps})
# 2. add latent guide
model, latent_updated = self.add_latent_guide(model, latent_samples, latent_samples, add_latent_guide_index, add_latent_guide_insert)
# 1. apply stg
model = self.apply_stg(model, "attention", attention_override)
# 3. apply attention override
attn_override_layers = self.attention_override(attention_override)
model = self.apply_attention_override(model, attention_adjustment_scale, attention_adjustment_rescale, attention_adjustment_cfg, attn_override_layers)
# 5. configure scheduler
sigmas = self.get_sigmas(steps, max_shift, base_shift, stretch, terminal, latent_updated)
# 2. configure sizes
model, latent_updated, sigma_shift = self.configure_sizes(model, vae, "Custom", width, height, length, fps, batch_size, mixed_precision=True, img_compression=crf, conditioning=image, initial_latent=None)
# 3. shift sigmas - model, scheduler, steps, denoise
scheduler_sigmas = self.get_sigmas(model, "normal", steps, 1.0)
sigmas = self.shift_sigmas(scheduler_sigmas, sigma_shift, stretch, terminal)
# 4. get guider
guider = self.get_guider(model, positive, negative, cfg=attention_adjustment_cfg, stg=attention_adjustment_scale, rescale=attention_adjustment_rescale)
# collect all parameters (width, height, fps, crf, etc.) into the img2vid metadata
img2vid_metadata = {
@ -183,134 +365,136 @@ class MD_ImgToVideo:
}
json_img2vid_metadata = json.dumps(img2vid_metadata)
return (model, positive, negative, sigmas, latent_updated, json_img2vid_metadata)
return (model, positive, negative, sigmas, latent_updated, json_img2vid_metadata, guider)
# -----------------------------
# Attention functions
# -----------------------------
# 1. Add latent guide
def add_latent_guide(self, model, latent, image_latent, index, insert):
image_latent = image_latent['samples']
latent = latent['samples'].clone()
# # Convert negative index to positive
# if insert:
# index = max(0, min(index, latent.shape[2])) # Clamp index
# latent = torch.cat([
# latent[:,:,:index],
# image_latent[:,:,0:1],
# latent[:,:,index:]
# ], dim=2)
# else:
# latent[:,:,index] = image_latent[:,:,0]
if insert:
# Handle insertion
if index == 0:
# Insert at beginning
latent = torch.cat([image_latent[:,:,0:1], latent], dim=2)
elif index >= latent.shape[2] or index < 0:
# Append to end
latent = torch.cat([latent, image_latent[:,:,0:1]], dim=2)
else:
# Insert in middle
latent = torch.cat([
latent[:,:,:index],
image_latent[:,:,0:1],
latent[:,:,index:]
], dim=2)
# 1 ------------------------------------------------------------
def apply_stg(self, model, stg_mode: str, block_indices: str):
skip_block_list = [int(i.strip()) for i in block_indices.split(",")]
stg_mode = (
SkipLayerStrategy.Attention
if stg_mode == "attention"
else SkipLayerStrategy.Residual
)
new_model = model.clone()
new_model.model_options["transformer_options"]["skip_layer_strategy"] = stg_mode
if "skip_block_list" in new_model.model_options["transformer_options"]:
skip_block_list.extend(
new_model.model_options["transformer_options"]["skip_block_list"]
)
new_model.model_options["transformer_options"][
"skip_block_list"
] = skip_block_list
return new_model
# 2 ------------------------------------------------------------
def configure_sizes(
self,
model,
vae,
preset,
width,
height,
frames_number,
frame_rate,
batch,
mixed_precision,
img_compression,
conditioning=None,
initial_latent=None,
):
load_device = comfy.model_management.get_torch_device()
if preset != "Custom":
preset = preset.split("|")
width, height = map(int, preset[0].strip().split("x"))
frames_number = int(preset[1].strip())
latent_shape, latent_frame_rate = self.latent_shape_and_frame_rate(
vae, batch, height, width, frames_number, frame_rate
)
mask_shape = [
latent_shape[0],
1,
latent_shape[2],
latent_shape[3],
latent_shape[4],
]
conditioning_mask = torch.zeros(mask_shape, device=load_device)
initial_latent = (
None
if initial_latent is None
else initial_latent["samples"].to(load_device)
)
guiding_latent = None
if conditioning is not None:
latent = encode_media_conditioning(
conditioning,
vae,
width,
height,
frames_number,
image_compression=img_compression,
initial_latent=initial_latent,
)
conditioning_mask[:, :, 0] = 1.0
guiding_latent = latent[:, :, :1, ...]
else:
# Original replacement behavior
latent[:,:,index] = image_latent[:,:,0]
model = model.clone()
guiding_latent = LatentGuide(image_latent, index)
model.set_model_patch(guiding_latent, 'guiding_latents')
return (model, {"samples": latent},)
# 2. Apply attention override
def is_integer(self, string):
try:
int(string)
return True
except ValueError:
return False
latent = torch.zeros(latent_shape, dtype=torch.float32, device=load_device)
if initial_latent is not None:
latent[:, :, : initial_latent.shape[2], ...] = initial_latent
_, vae_scale_factor, _ = get_vae_size_scale_factor(vae.first_stage_model)
patcher = model.clone()
patcher.add_object_patch("diffusion_model.conditioning_mask", conditioning_mask)
patcher.add_object_patch("diffusion_model.latent_frame_rate", latent_frame_rate)
patcher.add_object_patch("diffusion_model.vae_scale_factor", vae_scale_factor)
patcher.add_object_patch(
"model_sampling", LTXVSampling(conditioning_mask, guiding_latent)
)
patcher.model_options.setdefault("transformer_options", {})[
"mixed_precision"
] = mixed_precision
num_latent_patches = latent_shape[2] * latent_shape[3] * latent_shape[4]
return (patcher, {"samples": latent}, get_normal_shift(num_latent_patches))
def attention_override(self, layers: str = "14"):
try:
return set(map(int, layers.split(',')))
except ValueError:
return set()
# layers_map = set([])
# return set(map(int, layers.split(',')))
# for block in layers.split(','):
# block = block.strip()
# if self.is_integer(block):
# layers_map.add(block)
def latent_shape_and_frame_rate(
self, vae, batch, height, width, frames_number, frame_rate
):
video_scale_factor, vae_scale_factor, _ = get_vae_size_scale_factor(
vae.first_stage_model
)
video_scale_factor = video_scale_factor if frames_number > 1 else 1
# return layers_map
latent_height = height // vae_scale_factor
latent_width = width // vae_scale_factor
latent_channels = vae.first_stage_model.config.latent_channels
latent_num_frames = math.floor(frames_number / video_scale_factor) + 1
latent_frame_rate = frame_rate / video_scale_factor
latent_shape = [
batch,
latent_channels,
latent_num_frames,
latent_height,
latent_width,
]
return latent_shape, latent_frame_rate
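# Worked example (illustrative, assuming the stock LTX factors of 32 spatial and 8 temporal):
# a 768x512, 97-frame, 24 fps request gives latent_shape = [1, 128, 13, 16, 24]
# (13 = floor(97 / 8) + 1) and latent_frame_rate = 24 / 8 = 3.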
def apply_attention_override(self, model, scale, rescale, cfg, attention_override: set):
m = model.clone()
# 3 ------------------------------------------------------------
def get_sigmas(self, model, scheduler, steps, denoise):
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return torch.FloatTensor([])
total_steps = int(steps/denoise)
def pag_fn(q, k,v, heads, attn_precision=None, transformer_options=None):
return v
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
uncond_pred = args["uncond_denoised"]
len_conds = 1 if args.get('uncond', None) is None else 2
cond = args["cond"]
sigma = args["sigma"]
model_options = args["model_options"].copy()
x = args["input"]
if scale == 0:
if len_conds == 1:
return cond_pred
return uncond_pred + (cond_pred - uncond_pred)
for block_idx in attention_override:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, pag_fn, f"layer", "self_attn", int(block_idx))
(perturbed,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
output = uncond_pred + cfg * (cond_pred - uncond_pred) \
+ scale * (cond_pred - perturbed)
if rescale > 0:
factor = cond_pred.std() / output.std()
factor = rescale * factor + (1 - rescale)
output = output * factor
return output
m.set_model_sampler_post_cfg_function(post_cfg_function)
return m
sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
sigmas = sigmas[-(steps + 1):]
return sigmas
# -----------------------------
# Scheduler
# -----------------------------
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
if latent is None:
tokens = 4096
else:
tokens = math.prod(latent["samples"].shape[2:])
sigmas = torch.linspace(1.0, 0.0, steps + 1)
x1 = 1024
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
sigma_shift = (tokens) * mm + b
def shift_sigmas(self, sigmas, sigma_shift, stretch, terminal):
power = 1
sigmas = torch.where(
sigmas != 0,
@ -328,6 +512,183 @@ class MD_ImgToVideo:
sigmas[non_zero_mask] = stretched
return sigmas
# 4 ------------------------------------------------------------
def get_guider(self, model, positive, negative, cfg, stg, rescale):
guider = STGGuider(model)
guider.set_conds(positive, negative)
guider.set_cfg(cfg, stg, rescale)
return guider
# def img_to_video(self, model, positive, negative, vae, image, width, height, length, fps, add_latent_guide_index, add_latent_guide_insert, steps, max_shift, base_shift, stretch, terminal, attention_override, attention_adjustment_scale, attention_adjustment_rescale, attention_adjustment_cfg):
# batch_size = 1
# pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
# encode_pixels = pixels[:, :, :, :3]
# t = vae.encode(encode_pixels)
# positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
# negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})
# latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
# latent[:, :, :t.shape[2]] = t
# latent_samples = {"samples": latent}
# positive = node_helpers.conditioning_set_values(positive, {"frame_rate": fps})
# negative = node_helpers.conditioning_set_values(negative, {"frame_rate": fps})
# # 2. add latent guide
# model, latent_updated = self.add_latent_guide(model, latent_samples, latent_samples, add_latent_guide_index, add_latent_guide_insert)
# # 3. apply attention override
# attn_override_layers = self.attention_override(attention_override)
# model = self.apply_attention_override(model, attention_adjustment_scale, attention_adjustment_rescale, attention_adjustment_cfg, attn_override_layers)
# # 4. configure scheduler
# sigmas = self.get_sigmas(steps, max_shift, base_shift, stretch, terminal, latent_updated)
# # all parameters starting with width, height, fps, crf, etc
# img2vid_metadata = {
# "width": width,
# "height": height,
# "length": length,
# "fps": fps,
# "steps": steps,
# "max_shift": max_shift,
# "base_shift": base_shift,
# "stretch": stretch,
# "terminal": terminal,
# "attention_override": attention_override,
# "attention_adjustment_scale": attention_adjustment_scale,
# "attention_adjustment_rescale": attention_adjustment_rescale,
# "attention_adjustment_cfg": attention_adjustment_cfg,
# }
# json_img2vid_metadata = json.dumps(img2vid_metadata)
# return (model, positive, negative, sigmas, latent_updated, json_img2vid_metadata)
# -----------------------------
# Attention functions
# -----------------------------
# 1. Add latent guide
# def add_latent_guide(self, model, latent, image_latent, index, insert):
# image_latent = image_latent['samples']
# latent = latent['samples'].clone()
# if insert:
# # Handle insertion
# if index == 0:
# # Insert at beginning
# latent = torch.cat([image_latent[:,:,0:1], latent], dim=2)
# elif index >= latent.shape[2] or index < 0:
# # Append to end
# latent = torch.cat([latent, image_latent[:,:,0:1]], dim=2)
# else:
# # Insert in middle
# latent = torch.cat([
# latent[:,:,:index],
# image_latent[:,:,0:1],
# latent[:,:,index:]
# ], dim=2)
# else:
# # Original replacement behavior
# latent[:,:,index] = image_latent[:,:,0]
# model = model.clone()
# guiding_latent = LatentGuide(image_latent, index)
# model.set_model_patch(guiding_latent, 'guiding_latents')
# return (model, {"samples": latent},)
# 2. Apply attention override
# def is_integer(self, string):
# try:
# int(string)
# return True
# except ValueError:
# return False
# def attention_override(self, layers: str = "14"):
# try:
# return set(map(int, layers.split(',')))
# except ValueError:
# return set()
# def apply_attention_override(self, model, scale, rescale, cfg, attention_override: set):
# m = model.clone()
# def pag_fn(q, k,v, heads, attn_precision=None, transformer_options=None):
# return v
# def post_cfg_function(args):
# model = args["model"]
# cond_pred = args["cond_denoised"]
# uncond_pred = args["uncond_denoised"]
# len_conds = 1 if args.get('uncond', None) is None else 2
# cond = args["cond"]
# sigma = args["sigma"]
# model_options = args["model_options"].copy()
# x = args["input"]
# if scale == 0:
# if len_conds == 1:
# return cond_pred
# return uncond_pred + (cond_pred - uncond_pred)
# for block_idx in attention_override:
# model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, pag_fn, f"layer", "self_attn", int(block_idx))
# (perturbed,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
# output = uncond_pred + cfg * (cond_pred - uncond_pred) \
# + scale * (cond_pred - perturbed)
# if rescale > 0:
# factor = cond_pred.std() / output.std()
# factor = rescale * factor + (1 - rescale)
# output = output * factor
# return output
# m.set_model_sampler_post_cfg_function(post_cfg_function)
# return m
# -----------------------------
# Scheduler
# -----------------------------
# def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
# if latent is None:
# tokens = 4096
# else:
# tokens = math.prod(latent["samples"].shape[2:])
# sigmas = torch.linspace(1.0, 0.0, steps + 1)
# x1 = 1024
# x2 = 4096
# mm = (max_shift - base_shift) / (x2 - x1)
# b = base_shift - mm * x1
# sigma_shift = (tokens) * mm + b
# power = 1
# sigmas = torch.where(
# sigmas != 0,
# math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
# 0,
# )
# # Stretch sigmas so that its final value matches the given terminal value.
# if stretch:
# non_zero_mask = sigmas != 0
# non_zero_sigmas = sigmas[non_zero_mask]
# one_minus_z = 1.0 - non_zero_sigmas
# scale_factor = one_minus_z[-1] / (1.0 - terminal)
# stretched = 1.0 - (one_minus_z / scale_factor)
# sigmas[non_zero_mask] = stretched
# return sigmas
KSAMPLER_NAMES = ["euler", "ddim", "euler_ancestral", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
@ -346,6 +707,7 @@ class MD_VideoSampler:
"model": ("MODEL",),
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"guider": ("GUIDER",),
"sigmas": ("SIGMAS",),
"latent_image": ("LATENT",),
"sampler": (KSAMPLER_NAMES, ),
@ -367,34 +729,65 @@ class MD_VideoSampler:
FUNCTION = "video_sampler"
CATEGORY = "MemeDeck"
def video_sampler(self, model, positive, negative, sigmas, latent_image, sampler, noise_seed, cfg):
def video_sampler(self, model, positive, negative, guider, sigmas, latent_image, sampler, noise_seed, cfg):
# latent = latent_image
# latent_image = latent["samples"]
# latent = latent.copy()
# latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
# latent["samples"] = latent_image
# sampler_name = sampler
# noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
# sampler = comfy.samplers.sampler_object(sampler)
# noise_mask = None
# if "noise_mask" in latent:
# noise_mask = latent["noise_mask"]
# x0_output = {}
# callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
# disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
# samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
# out = latent.copy()
# out["samples"] = samples
# if "x0" in x0_output:
# out_denoised = latent.copy()
# out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
# else:
# out_denoised = out
noise = Noise_RandomNoise(noise_seed)
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
latent["samples"] = latent_image
sampler_name = sampler
noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
sampler = comfy.samplers.sampler_object(sampler)
sampler_name = sampler
sampler = comfy.samplers.sampler_object(sampler)
noise_mask = None
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
samples = samples.to(comfy.model_management.intermediate_device())
out = latent.copy()
out["samples"] = samples
if "x0" in x0_output:
out_denoised = latent.copy()
out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
else:
out_denoised = out
# return (out, out_denoised)
sampler_metadata = {
"sampler": sampler_name,

View File

@ -21,7 +21,7 @@ WATERMARK = """
<path d="M60.0859 196.8C65.9526 179.067 71.5526 161.667 76.8859 144.6C79.1526 137.4 81.4859 129.867 83.8859 122C86.2859 114.133 88.6859 106.333 91.0859 98.6C93.4859 90.8667 95.6859 83.4 97.6859 76.2C99.8193 69 101.686 62.3333 103.286 56.2C110.619 56.2 117.553 55.8 124.086 55C130.619 54.2 137.686 53.4667 145.286 52.8C144.886 55.7333 144.419 59.0667 143.886 62.8C143.486 66.4 142.953 70.2 142.286 74.2C141.753 78.2 141.153 82.3333 140.486 86.6C139.819 90.8667 139.019 96.3333 138.086 103C137.153 109.667 135.886 118 134.286 128H136.886C140.753 117.867 143.953 109.467 146.486 102.8C149.019 96 151.086 90.4667 152.686 86.2C154.286 81.9333 155.886 77.8 157.486 73.8C159.219 69.6667 160.819 65.8 162.286 62.2C163.886 58.4667 165.353 55.2 166.686 52.4C170.019 52.1333 173.153 51.8 176.086 51.4C179.019 51 181.953 50.6 184.886 50.2C187.819 49.6667 190.753 49.2 193.686 48.8C196.753 48.2667 200.086 47.6667 203.686 47C202.353 54.7333 201.086 62.6667 199.886 70.8C198.686 78.9333 197.619 87.0667 196.686 95.2C195.753 103.2 194.819 111.133 193.886 119C193.086 126.867 192.353 134.333 191.686 141.4C190.086 157.933 188.686 174.067 187.486 189.8L152.686 196C152.686 195.333 152.753 193.533 152.886 190.6C153.153 187.667 153.419 184.067 153.686 179.8C154.086 175.533 154.553 170.8 155.086 165.6C155.753 160.4 156.353 155.2 156.886 150C157.553 144.8 158.219 139.8 158.886 135C159.553 130.067 160.219 125.867 160.886 122.4H159.086C157.219 128 155.153 133.933 152.886 140.2C150.619 146.333 148.286 152.6 145.886 159C143.619 165.4 141.353 171.667 139.086 177.8C136.819 183.933 134.819 189.8 133.086 195.4C128.419 195.533 124.419 195.733 121.086 196C117.753 196.133 113.886 196.333 109.486 196.6L115.886 122.4H112.886C112.619 124.133 111.953 127.067 110.886 131.2C109.819 135.2 108.553 139.867 107.086 145.2C105.753 150.4 104.286 155.867 102.686 161.6C101.086 167.2 99.5526 172.467 98.0859 177.4C96.7526 182.2 95.6193 186.2 94.6859 189.4C93.7526 192.467 93.2193 194.2 93.0859 194.6L60.0859 196.8Z" fill="white"/>
</svg>
"""
WATERMARK_SIZE = 32
WATERMARK_SIZE = 28
class MD_SaveAnimatedWEBP:
def __init__(self):
@ -40,7 +40,7 @@ class MD_SaveAnimatedWEBP:
"lossless": ("BOOLEAN", {"default": False}),
"quality": ("INT", {"default": 90, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
"crf": ("INT",),
"crf": ("FLOAT",),
"motion_prompt": ("STRING", ),
"negative_prompt": ("STRING", ),
"img2vid_metadata": ("STRING", ),
@ -67,7 +67,7 @@ class MD_SaveAnimatedWEBP:
pil_images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in images]
first_image = pil_images[0]
padding = 12
padding = 8
x = first_image.width - WATERMARK_SIZE - padding
y = first_image.height - WATERMARK_SIZE - padding
first_image_background_brightness = self.analyze_background_brightness(first_image, x, y, WATERMARK_SIZE)
@ -116,6 +116,66 @@ class MD_SaveAnimatedWEBP:
},
}
# def save_images(self, images, fps, filename_prefix, lossless, quality, method, crf=None, motion_prompt=None, negative_prompt=None, img2vid_metadata=None, sampler_metadata=None):
# start_time = time.time()
# method = self.methods.get(method)
# filename_prefix += self.prefix_append
# full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
# results = []
# # Vectorized conversion to PIL images
# pil_images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in images]
# first_image = pil_images[0]
# padding = 12
# x = first_image.width - WATERMARK_SIZE - padding
# y = first_image.height - WATERMARK_SIZE - padding
# first_image_background_brightness = self.analyze_background_brightness(first_image, x, y, WATERMARK_SIZE)
# watermarked_images = [self.add_watermark_to_image(img, first_image_background_brightness) for img in pil_images]
# metadata = pil_images[0].getexif()
# num_frames = len(pil_images)
# json_metadata = {
# "crf": crf,
# "motion_prompt": motion_prompt,
# "negative_prompt": negative_prompt,
# "img2vid_metadata": json.loads(img2vid_metadata),
# "sampler_metadata": json.loads(sampler_metadata),
# }
# # Optimized saving logic
# if num_frames == 1: # Single image, save once
# file = f"{filename}_{counter:05}_.webp"
# watermarked_images[0].save(os.path.join(full_output_folder, file), exif=metadata, lossless=lossless, quality=quality, method=method)
# results.append({
# "filename": file,
# "subfolder": subfolder,
# "type": self.type,
# })
# else: # multiple images, save as animation
# file = f"{filename}_{counter:05}_.webp"
# watermarked_images[0].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0 / fps), append_images=watermarked_images[1:], exif=metadata, lossless=lossless, quality=quality, method=method)
# results.append({
# "filename": file,
# "subfolder": subfolder,
# "type": self.type,
# })
# animated = num_frames != 1
# end_time = time.time()
# logger.info(f"Save images took: {end_time - start_time} seconds")
# return {
# "ui": {
# "images": results,
# "animated": (animated,),
# "metadata": (json.dumps(json_metadata),)
# },
# }
def add_watermark_to_image(self, img, background_brightness=None):
"""
Adds a watermark to a single PIL Image.
@ -127,7 +187,7 @@ class MD_SaveAnimatedWEBP:
A PIL Image object with the watermark added.
"""
padding = 12
padding = 8
x = img.width - WATERMARK_SIZE - padding
y = img.height - WATERMARK_SIZE - padding

View File

@ -143,7 +143,7 @@ class MD_CompressAdjustNode:
},
}
RETURN_TYPES = ("IMAGE", "INT", "INT", "INT")
RETURN_TYPES = ("IMAGE", "FLOAT", "INT", "INT")
RETURN_NAMES = ("adjusted_image", "crf", "width", "height")
FUNCTION = "tensor_to_video_and_back"
CATEGORY = "MemeDeck"
@ -291,11 +291,11 @@ class MD_CompressAdjustNode:
image_cv2 = cv2.cvtColor(np.array(tensor2pil(image)), cv2.COLOR_RGB2BGR)
# calculate the crf based on the image
analysis_results = self.analyze_compression_artifacts(image_cv2, width=width, height=height)
desired_crf = self.calculate_crf(analysis_results, self.ideal_blockiness, self.ideal_edge_density,
calculated_crf = self.calculate_crf(analysis_results, self.ideal_blockiness, self.ideal_edge_density,
self.ideal_color_variation, self.blockiness_weight,
self.edge_density_weight, self.color_variation_weight)
logger.info(f"detected crf: {desired_crf}")
logger.info(f"detected crf: {calculated_crf}")
args = [
utils.ffmpeg_path,
"-v", "error",

View File

@ -0,0 +1,103 @@
import comfy.samplers
import comfy.utils
import torch
from comfy.model_patcher import ModelPatcher
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
def stg(
noise_pred_pos,
noise_pred_neg,
noise_pred_perturbed,
cfg_scale,
stg_scale,
rescale_scale,
):
noise_pred = (
noise_pred_neg
+ cfg_scale * (noise_pred_pos - noise_pred_neg)
+ stg_scale * (noise_pred_pos - noise_pred_perturbed)
)
if rescale_scale != 0:
factor = noise_pred_pos.std() / noise_pred.std()
factor = rescale_scale * factor + (1 - rescale_scale)
noise_pred = noise_pred * factor
return noise_pred
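# Worked example (illustrative): with cfg_scale = 3, stg_scale = 1 and rescale_scale = 0,
# stg() returns neg + 3 * (pos - neg) + (pos - perturbed); stg_scale = 0 reduces it to plain
# CFG, and a non-zero rescale_scale pulls the combined prediction's std back toward that of
# the positive prediction.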
class STGGuider(comfy.samplers.CFGGuider):
def set_conds(self, positive, negative):
self.inner_set_conds(
{"positive": positive, "negative": negative, "perturbed": positive}
)
def set_cfg(self, cfg, stg_scale, rescale_scale: float = None):
self.cfg = cfg
self.stg_scale = stg_scale
self.rescale_scale = rescale_scale
def predict_noise(
self,
x: torch.Tensor,
timestep: torch.Tensor,
model_options: dict = {},
seed=None,
):
# in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg
# but we'd rather do a single batch of sampling pos, neg, and perturbed, so we call calc_cond_batch([perturbed,pos,neg]) directly
perturbed_cond = self.conds.get("perturbed", None)
positive_cond = self.conds.get("positive", None)
negative_cond = self.conds.get("negative", None)
noise_pred_neg = 0
# no similar optimization for stg=0, use CFG guider instead.
if self.cfg > 1:
model_options["transformer_options"]["ptb_index"] = 2
(noise_pred_perturbed, noise_pred_pos, noise_pred_neg) = (
comfy.samplers.calc_cond_batch(
self.inner_model,
[perturbed_cond, positive_cond, negative_cond],
x,
timestep,
model_options,
)
)
else:
model_options["transformer_options"]["ptb_index"] = 1
(noise_pred_perturbed, noise_pred_pos) = comfy.samplers.calc_cond_batch(
self.inner_model,
[perturbed_cond, positive_cond],
x,
timestep,
model_options,
)
stg_result = stg(
noise_pred_pos,
noise_pred_neg,
noise_pred_perturbed,
self.cfg,
self.stg_scale,
self.rescale_scale,
)
# normally this would be done in cfg_function, but we skipped
# that for efficiency: we can compute the noise predictions in
# a single call to calc_cond_batch() (rather than two)
# so we replicate the hook here
for fn in model_options.get("sampler_post_cfg_function", []):
args = {
"denoised": stg_result,
"cond": positive_cond,
"uncond": negative_cond,
"model": self.inner_model,
"uncond_denoised": noise_pred_neg,
"cond_denoised": noise_pred_pos,
"sigma": timestep,
"model_options": model_options,
"input": x,
# not in the original call in samplers.py:cfg_function, but made available for future hooks
"perturbed_cond": positive_cond,
"perturbed_cond_denoised": noise_pred_perturbed,
}
stg_result = fn(args)
return stg_result

View File

@ -0,0 +1,164 @@
from copy import copy
import comfy.latent_formats
import comfy.model_base
import comfy.model_management
import comfy.model_patcher
import comfy.sd
import comfy.supported_models_base
import comfy.utils
import torch
from diffusers.image_processor import VaeImageProcessor
from ltx_video.models.autoencoders.vae_encode import (
get_vae_size_scale_factor,
vae_decode,
vae_encode,
)
class MD_VideoVAE(comfy.sd.VAE):
def __init__(self, decode_timestep=0.05, decode_noise_scale=0.025, seed=42):
self.device = comfy.model_management.vae_device()
self.offload_device = comfy.model_management.vae_offload_device()
self.decode_timestep = decode_timestep
self.decode_noise_scale = decode_noise_scale
self.seed = seed
@classmethod
def from_pretrained(cls, vae_class, model_path, dtype=torch.bfloat16):
instance = cls()
model = vae_class.from_pretrained(
pretrained_model_name_or_path=model_path,
revision=None,
torch_dtype=dtype,
load_in_8bit=False,
)
instance._finalize_model(model)
return instance
@classmethod
def from_config_and_state_dict(
cls, vae_class, config, state_dict, dtype=torch.bfloat16
):
instance = cls()
model = vae_class.from_config(config)
model.load_state_dict(state_dict)
model.to(dtype)
instance._finalize_model(model)
return instance
def _finalize_model(self, model):
self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
model
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.first_stage_model = model.eval().to(self.device)
# Assumes that the input samples have dimensions in the following order:
# (batch, channels, frames, height, width).
def decode(self, samples_in):
is_video = samples_in.shape[2] > 1
decode_timestep = self.decode_timestep
if getattr(self.first_stage_model.decoder, "timestep_conditioning", False):
samples_in = self.add_noise(
decode_timestep, self.decode_noise_scale, self.seed, samples_in
)
else:
decode_timestep = None
result = vae_decode(
samples_in.to(self.device),
vae=self.first_stage_model,
is_video=is_video,
vae_per_channel_normalize=True,
timestep=decode_timestep,
)
result = self.image_processor.postprocess(
result, output_type="pt", do_denormalize=[True]
)
return result.squeeze(0).permute(1, 2, 3, 0).to(torch.float32)
@staticmethod
def add_noise(decode_timestep, decode_noise_scale, seed, latents):
generator = torch.Generator(device="cpu").manual_seed(seed)
noise = torch.randn(
latents.size(),
generator=generator,
device=latents.device,
dtype=latents.dtype,
)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * latents.shape[0]
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * latents.shape[0]
decode_timestep = torch.tensor(decode_timestep).to(latents.device)
decode_noise_scale = torch.tensor(decode_noise_scale).to(latents.device)[
:, None, None, None, None
]
latents = latents * (1 - decode_noise_scale) + noise * decode_noise_scale
return latents
# The underlying VAE expects inputs in (b, c, n, h, w) order and a specific dtype.
# In Comfy, however, the convention is (n, h, w, c).
def encode(self, pixel_samples):
preprocessed = self.image_processor.preprocess(
pixel_samples.permute(3, 0, 1, 2)
)
input = preprocessed.unsqueeze(0).to(torch.bfloat16).to(self.device)
latents = vae_encode(
input, self.first_stage_model, vae_per_channel_normalize=True
).to(comfy.model_management.get_torch_device())
return latents
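# Shape walk-through (illustrative): a Comfy-style batch of 49 frames at 512x768,
# i.e. (49, 512, 768, 3), is permuted to (3, 49, 512, 768) and unsqueezed to
# (1, 3, 49, 512, 768) before being passed to vae_encode in the VAE's native
# (b, c, n, h, w) layout.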
# class DecoderNoise:
# @classmethod
# def INPUT_TYPES(cls):
# return {
# "required": {
# "vae": ("VAE",),
# "timestep": (
# "FLOAT",
# {
# "default": 0.05,
# "min": 0.0,
# "max": 1.0,
# "step": 0.01,
# "tooltip": "The timestep used for decoding the noise.",
# },
# ),
# "scale": (
# "FLOAT",
# {
# "default": 0.025,
# "min": 0.0,
# "max": 1.0,
# "step": 0.001,
# "tooltip": "The scale of the noise added to the decoder.",
# },
# ),
# "seed": (
# "INT",
# {
# "default": 42,
# "min": 0,
# "max": 0xFFFFFFFFFFFFFFFF,
# "tooltip": "The random seed used for creating the noise.",
# },
# ),
# }
# }
# FUNCTION = "add_noise"
# RETURN_TYPES = ("VAE",)
# CATEGORY = "lightricks/LTXV"
# def add_noise(self, vae, timestep, scale, seed):
# result = copy(vae)
# result.decode_timestep = timestep
# result.decode_noise_scale = scale
# result.seed = seed
# return (result,)

View File

@ -34,4 +34,5 @@ aio_pika
torchao
insightface
onnxruntime-gpu
cairosvg
cairosvg
lxml