From e5c3f4b87febd790b316f82813ba8d89d275fee4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Nov 2024 17:17:11 -0500 Subject: [PATCH] LTXV lowvram fixes. --- comfy/ldm/lightricks/model.py | 4 ++-- comfy/ldm/lightricks/vae/causal_conv3d.py | 4 +++- comfy/ldm/lightricks/vae/causal_video_autoencoder.py | 4 ++-- comfy/ldm/lightricks/vae/conv_nd_factory.py | 11 ++++++----- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 2792384d..87ed0995 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa @@ -479,7 +479,7 @@ class LTXVModel(torch.nn.Module): # 3. Output scale_shift_values = ( - self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = self.norm_out(x) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index 146dea19..c572e7e8 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -2,6 +2,8 @@ from typing import Tuple, Union import torch import torch.nn as nn +import comfy.ops +ops = comfy.ops.disable_weight_init class CausalConv3d(nn.Module): @@ -29,7 +31,7 @@ class CausalConv3d(nn.Module): width_pad = kernel_size[2] // 2 padding = (0, height_pad, width_pad) - self.conv = nn.Conv3d( + self.conv = ops.Conv3d( in_channels, out_channels, kernel_size, diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 4138fdf3..33b2c2d4 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -628,10 +628,10 @@ class processor(nn.Module): self.register_buffer("channel", torch.empty(128)) def un_normalize(self, x): - return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1) + return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) def normalize(self, x): - return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1) + return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): def __init__(self): diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py index 389f8165..c5f067bf 100644 --- a/comfy/ldm/lightricks/vae/conv_nd_factory.py +++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py @@ -4,7 +4,8 @@ import torch from .dual_conv3d import DualConv3d from .causal_conv3d import CausalConv3d - +import comfy.ops +ops = comfy.ops.disable_weight_init def make_conv_nd( dims: Union[int, Tuple[int, int]], @@ -19,7 +20,7 @@ def make_conv_nd( causal=False, ): if dims == 2: - return torch.nn.Conv2d( + return ops.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -41,7 +42,7 @@ def make_conv_nd( groups=groups, bias=bias, ) - return torch.nn.Conv3d( + return ops.Conv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -71,11 +72,11 @@ def make_linear_nd( bias=True, ): if dims == 2: - return torch.nn.Conv2d( + return ops.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias ) elif dims == 3 or dims == (2, 1): - return torch.nn.Conv3d( + return ops.Conv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias ) else: