Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-03-15 05:57:20 +00:00)
Speed up model loading a bit.
The default PyTorch nn.Linear initializes its weights in the constructor, which is useless and slow here: the values are immediately overwritten when the model's weights are loaded from the checkpoint.
This commit is contained in:
parent 84f13f828a
commit 6971646b8b
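What the change buys: torch.nn.Linear runs reset_parameters() in its constructor (Kaiming-uniform weight init plus bias init), and all of that work is discarded as soon as the checkpoint's weights are loaded into the module. The snippet below is not part of the commit; it is a rough sketch of the cost, with UninitLinear standing in for the comfy.ops.Linear introduced further down, and the layer count and width picked arbitrarily so the gap is visible.

import time
import torch

class UninitLinear(torch.nn.Module):
    # Stand-in for comfy.ops.Linear: allocate parameter storage, skip initialization.
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight, self.bias)

def build_time(cls, n=200, dim=4096):
    # Construct n layers and report the wall time; nn.Linear also runs Kaiming init for each.
    start = time.perf_counter()
    layers = [cls(dim, dim) for _ in range(n)]
    return time.perf_counter() - start

print("nn.Linear:    %.3fs" % build_time(torch.nn.Linear))
print("UninitLinear: %.3fs" % build_time(UninitLinear))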
@@ -10,6 +10,7 @@ from .diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention

 from comfy import model_management
+import comfy.ops

 from . import tomesd

@@ -52,7 +53,7 @@ def init_(tensor):
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -65,14 +66,14 @@ class FeedForward(nn.Module):
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
+            comfy.ops.Linear(dim, inner_dim),
             nn.GELU()
         ) if not glu else GEGLU(dim, inner_dim)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
+            comfy.ops.Linear(inner_dim, dim_out)
         )

     def forward(self, x):
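A quick sanity check that the swap in GEGLU and FeedForward is a drop-in replacement (a sketch under the assumption that comfy.ops is importable from a ComfyUI checkout; it is not code from the commit): the new layer exposes the same parameter names and shapes as nn.Linear, so existing state dicts load without any key remapping.

import torch
import comfy.ops

a = torch.nn.Linear(320, 1280)
b = comfy.ops.Linear(320, 1280)

# Same parameter keys and shapes as the stock layer...
assert set(a.state_dict()) == set(b.state_dict()) == {"weight", "bias"}
assert all(a.state_dict()[k].shape == b.state_dict()[k].shape for k in a.state_dict())

# ...so a state dict made for one loads into the other, and the outputs agree.
b.load_state_dict(a.state_dict())
x = torch.randn(2, 320)
torch.testing.assert_close(a(x), b(x))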
@@ -154,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )

@@ -251,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )

@@ -349,12 +350,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )

@@ -407,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
@@ -456,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)

-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
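All of the attention projections above are built with bias=False. In the new layer that takes the register_parameter('bias', None) branch, so the module behaves like a bias-free nn.Linear and its state dict carries only a weight entry. A small sketch, under the same importability assumption as above:

import torch
import comfy.ops

to_q = comfy.ops.Linear(320, 320, bias=False)
assert to_q.bias is None                             # registered as None, not a tensor
assert list(to_q.state_dict().keys()) == ["weight"]  # no 'bias' entry, like nn.Linear(..., bias=False)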
@@ -601,7 +602,7 @@ class SpatialTransformer(nn.Module):
                                      stride=1,
                                      padding=0)
         else:
-            self.proj_in = nn.Linear(in_channels, inner_dim)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim)

         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
@@ -615,7 +616,7 @@ class SpatialTransformer(nn.Module):
                                                   stride=1,
                                                   padding=0))
         else:
-            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+            self.proj_out = zero_module(comfy.ops.Linear(in_channels, inner_dim))
         self.use_linear = use_linear

     def forward(self, x, context=None, transformer_options={}):
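The proj_out case is wrapped in zero_module, which zeroes every parameter of the module it is given, so the arbitrary values left in the torch.empty() storage are overwritten either way. A sketch with a local stand-in for that helper (not the commit's own code):

import torch
import comfy.ops

def zero_module(module):
    # Local stand-in for the behaviour relied on above: zero all parameters, return the module.
    for p in module.parameters():
        p.detach().zero_()
    return module

proj_out = zero_module(comfy.ops.Linear(320, 320))
assert int(torch.count_nonzero(proj_out.weight)) == 0
assert int(torch.count_nonzero(proj_out.bias)) == 0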
comfy/ops.py (new file, +17 lines)
@@ -0,0 +1,17 @@
+import torch
+
+class Linear(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, bias: bool = True,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, input):
+        return torch.nn.functional.linear(input, self.weight, self.bias)
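Usage-wise the new class is nn.Linear without reset_parameters(): its parameters come from torch.empty() and hold arbitrary values until real weights are loaded, so a layer built this way is only meaningful after load_state_dict(). A short sketch (the tensors below are stand-ins for checkpoint values, not anything from the commit):

import torch
import comfy.ops

layer = comfy.ops.Linear(768, 3072)
# At this point the outputs are undefined: the parameters are uninitialized memory.

state = {"weight": torch.randn(3072, 768) * 0.02,   # stand-in for checkpoint tensors
         "bias": torch.zeros(3072)}
layer.load_state_dict(state)

out = layer(torch.randn(1, 768))                    # well-defined once weights are loaded
print(out.shape)                                    # torch.Size([1, 3072])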