mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
LTXV lowvram fixes.
This commit is contained in:
parent
bc6be6c11e
commit
e5c3f4b87f
@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||||
|
|
||||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
|
||||||
|
|
||||||
@ -479,7 +479,7 @@ class LTXVModel(torch.nn.Module):
|
|||||||
|
|
||||||
# 3. Output
|
# 3. Output
|
||||||
scale_shift_values = (
|
scale_shift_values = (
|
||||||
self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
|
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
|
||||||
)
|
)
|
||||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||||
x = self.norm_out(x)
|
x = self.norm_out(x)
|
||||||
|
@ -2,6 +2,8 @@ from typing import Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
|
||||||
class CausalConv3d(nn.Module):
|
class CausalConv3d(nn.Module):
|
||||||
@ -29,7 +31,7 @@ class CausalConv3d(nn.Module):
|
|||||||
width_pad = kernel_size[2] // 2
|
width_pad = kernel_size[2] // 2
|
||||||
padding = (0, height_pad, width_pad)
|
padding = (0, height_pad, width_pad)
|
||||||
|
|
||||||
self.conv = nn.Conv3d(
|
self.conv = ops.Conv3d(
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
|
@ -628,10 +628,10 @@ class processor(nn.Module):
|
|||||||
self.register_buffer("channel", torch.empty(128))
|
self.register_buffer("channel", torch.empty(128))
|
||||||
|
|
||||||
def un_normalize(self, x):
|
def un_normalize(self, x):
|
||||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)
|
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||||
|
|
||||||
def normalize(self, x):
|
def normalize(self, x):
|
||||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)
|
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||||
|
|
||||||
class VideoVAE(nn.Module):
|
class VideoVAE(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -4,7 +4,8 @@ import torch
|
|||||||
|
|
||||||
from .dual_conv3d import DualConv3d
|
from .dual_conv3d import DualConv3d
|
||||||
from .causal_conv3d import CausalConv3d
|
from .causal_conv3d import CausalConv3d
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
def make_conv_nd(
|
def make_conv_nd(
|
||||||
dims: Union[int, Tuple[int, int]],
|
dims: Union[int, Tuple[int, int]],
|
||||||
@ -19,7 +20,7 @@ def make_conv_nd(
|
|||||||
causal=False,
|
causal=False,
|
||||||
):
|
):
|
||||||
if dims == 2:
|
if dims == 2:
|
||||||
return torch.nn.Conv2d(
|
return ops.Conv2d(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
@ -41,7 +42,7 @@ def make_conv_nd(
|
|||||||
groups=groups,
|
groups=groups,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
return torch.nn.Conv3d(
|
return ops.Conv3d(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
@ -71,11 +72,11 @@ def make_linear_nd(
|
|||||||
bias=True,
|
bias=True,
|
||||||
):
|
):
|
||||||
if dims == 2:
|
if dims == 2:
|
||||||
return torch.nn.Conv2d(
|
return ops.Conv2d(
|
||||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||||
)
|
)
|
||||||
elif dims == 3 or dims == (2, 1):
|
elif dims == 3 or dims == (2, 1):
|
||||||
return torch.nn.Conv3d(
|
return ops.Conv3d(
|
||||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user