LTXV lowvram fixes.

This commit is contained in:
comfyanonymous 2024-11-22 17:17:11 -05:00
parent bc6be6c11e
commit e5c3f4b87f
4 changed files with 13 additions and 10 deletions

View File

@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
@ -479,7 +479,7 @@ class LTXVModel(torch.nn.Module):
# 3. Output # 3. Output
scale_shift_values = ( scale_shift_values = (
self.scale_shift_table[None, None] + embedded_timestep[:, :, None] self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
) )
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x) x = self.norm_out(x)

View File

@ -2,6 +2,8 @@ from typing import Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module): class CausalConv3d(nn.Module):
@ -29,7 +31,7 @@ class CausalConv3d(nn.Module):
width_pad = kernel_size[2] // 2 width_pad = kernel_size[2] // 2
padding = (0, height_pad, width_pad) padding = (0, height_pad, width_pad)
self.conv = nn.Conv3d( self.conv = ops.Conv3d(
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,

View File

@ -628,10 +628,10 @@ class processor(nn.Module):
self.register_buffer("channel", torch.empty(128)) self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x): def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1) return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
def normalize(self, x): def normalize(self, x):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1) return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module): class VideoVAE(nn.Module):
def __init__(self): def __init__(self):

View File

@ -4,7 +4,8 @@ import torch
from .dual_conv3d import DualConv3d from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d from .causal_conv3d import CausalConv3d
import comfy.ops
ops = comfy.ops.disable_weight_init
def make_conv_nd( def make_conv_nd(
dims: Union[int, Tuple[int, int]], dims: Union[int, Tuple[int, int]],
@ -19,7 +20,7 @@ def make_conv_nd(
causal=False, causal=False,
): ):
if dims == 2: if dims == 2:
return torch.nn.Conv2d( return ops.Conv2d(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
@ -41,7 +42,7 @@ def make_conv_nd(
groups=groups, groups=groups,
bias=bias, bias=bias,
) )
return torch.nn.Conv3d( return ops.Conv3d(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
@ -71,11 +72,11 @@ def make_linear_nd(
bias=True, bias=True,
): ):
if dims == 2: if dims == 2:
return torch.nn.Conv2d( return ops.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
) )
elif dims == 3 or dims == (2, 1): elif dims == 3 or dims == (2, 1):
return torch.nn.Conv3d( return ops.Conv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
) )
else: else: