Speed up model loading a bit.

The default PyTorch Linear initializes its weights on construction, which is useless and slow here because the values are overwritten as soon as the checkpoint is loaded.
comfyanonymous 2023-06-14 11:17:59 -04:00
parent 84f13f828a
commit 6971646b8b
2 changed files with 43 additions and 25 deletions
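
For context on why skipping initialization helps: torch.nn.Linear calls reset_parameters() in its constructor, filling the weight with kaiming_uniform_ and the bias with a uniform distribution, work that is thrown away once the checkpoint values are copied in. A rough timing sketch of the effect (illustrative only, not from the commit; layer sizes and counts are made up, and comfy.ops is assumed to be importable from this commit onward):

import time
import torch
import comfy.ops  # the uninitialized Linear added by this commit

def build(linear_cls, n=20, dim=2048):
    # Construct a stack of reasonably large layers with the given class.
    return [linear_cls(dim, dim) for _ in range(n)]

for cls in (torch.nn.Linear, comfy.ops.Linear):
    start = time.time()
    build(cls)
    # torch.nn.Linear pays for kaiming_uniform_ (and a bias fill) on every layer;
    # comfy.ops.Linear only allocates storage with torch.empty.
    print(cls, time.time() - start)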

comfy/ldm/modules/attention.py

@@ -10,6 +10,7 @@ from .diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
+import comfy.ops
 from . import tomesd
@@ -52,7 +53,7 @@ def init_(tensor):
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -65,14 +66,14 @@ class FeedForward(nn.Module):
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
+            comfy.ops.Linear(dim, inner_dim),
             nn.GELU()
         ) if not glu else GEGLU(dim, inner_dim)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
+            comfy.ops.Linear(inner_dim, dim_out)
         )

     def forward(self, x):
@@ -154,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
@@ -251,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
@@ -349,12 +350,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
@@ -407,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
@@ -456,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
@@ -601,7 +602,7 @@ class SpatialTransformer(nn.Module):
                                      stride=1,
                                      padding=0)
         else:
-            self.proj_in = nn.Linear(in_channels, inner_dim)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim)

         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
@@ -615,7 +616,7 @@ class SpatialTransformer(nn.Module):
                                                  stride=1,
                                                  padding=0))
         else:
-            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+            self.proj_out = zero_module(comfy.ops.Linear(in_channels, inner_dim))
         self.use_linear = use_linear

     def forward(self, x, context=None, transformer_options={}):
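
Since comfy.ops.Linear (defined in the new file below) keeps nn.Linear's constructor signature, parameter names, and forward computation, every replacement above is a drop-in swap. A quick equivalence sketch, assuming comfy.ops is importable; the sizes are arbitrary:

import torch
import comfy.ops

ref = torch.nn.Linear(320, 640, bias=False)    # initialized by PyTorch
fast = comfy.ops.Linear(320, 640, bias=False)  # uninitialized storage
fast.load_state_dict(ref.state_dict())         # same parameter names, so this just copies the weight

x = torch.randn(2, 320)
print(torch.allclose(ref(x), fast(x)))         # True: identical forward pass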

comfy/ops.py (new file, 17 additions)

@@ -0,0 +1,17 @@
+import torch
+
+class Linear(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, bias: bool = True,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, input):
+        return torch.nn.functional.linear(input, self.weight, self.bias)
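
A minimal usage sketch of the class above (illustrative, with made-up shapes and random tensors standing in for real checkpoint data): the garbage values left by torch.empty never reach inference because load_state_dict() overwrites them before the first forward call.

import torch
import comfy.ops

layer = comfy.ops.Linear(768, 320)  # no time spent initializing weight or bias

# In ComfyUI these tensors come from the checkpoint file; random values stand in here.
layer.load_state_dict({
    "weight": torch.randn(320, 768),
    "bias": torch.zeros(320),
})

with torch.no_grad():
    out = layer(torch.randn(1, 768))
print(out.shape)  # torch.Size([1, 320])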