diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 573f4e1c6..4670ca578 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -10,6 +10,7 @@
 from .diffusionmodules.util import checkpoint
 from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
+import comfy.ops
 
 from . import tomesd
 
@@ -52,7 +53,7 @@ def init_(tensor):
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
 
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -65,14 +66,14 @@ class FeedForward(nn.Module):
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            nn.Linear(dim, inner_dim),
+            comfy.ops.Linear(dim, inner_dim),
             nn.GELU()
         ) if not glu else GEGLU(dim, inner_dim)
 
         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim_out)
+            comfy.ops.Linear(inner_dim, dim_out)
         )
 
     def forward(self, x):
@@ -154,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
 
@@ -251,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
 
@@ -349,12 +350,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
         self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
 
@@ -407,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):
@@ -456,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
 
-        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):
@@ -601,7 +602,7 @@ class SpatialTransformer(nn.Module):
                                      stride=1,
                                      padding=0)
         else:
-            self.proj_in = nn.Linear(in_channels, inner_dim)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
@@ -615,7 +616,7 @@ class SpatialTransformer(nn.Module):
                                                   stride=1,
                                                   padding=0))
         else:
-            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+            self.proj_out = zero_module(comfy.ops.Linear(in_channels, inner_dim))
         self.use_linear = use_linear
 
     def forward(self, x, context=None, transformer_options={}):
diff --git a/comfy/ops.py b/comfy/ops.py
new file mode 100644
index 000000000..0654dbcd9
--- /dev/null
+++ b/comfy/ops.py
@@ -0,0 +1,17 @@
+import torch
+
+class Linear(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, bias: bool = True,
+                 device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, input):
+        return torch.nn.functional.linear(input, self.weight, self.bias)
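
Note (not part of the diff above): the new comfy.ops.Linear copies torch.nn.Linear's constructor but omits the reset_parameters() call, so the torch.empty() weight and bias tensors are left uninitialized rather than getting the usual Kaiming-uniform init. A likely motivation is cheaper model construction when every parameter is about to be overwritten by checkpoint weights anyway; the diff itself does not state this. The sketch below only illustrates that the replacement keeps the same parameter names and shapes as the stock layer, so state_dict loading and the forward pass are unchanged. It assumes a ComfyUI checkout is importable as comfy, and the 320/640 sizes are arbitrary example values.

    import torch
    import comfy.ops  # assumes a ComfyUI checkout on the Python path

    # comfy.ops.Linear skips torch.nn.Linear's reset_parameters(), so its
    # weight starts as uninitialized torch.empty() memory and must be filled
    # (e.g. from a checkpoint) before the layer is used.
    layer = comfy.ops.Linear(320, 640, bias=False)      # example sizes
    reference = torch.nn.Linear(320, 640, bias=False)

    # Parameter names and shapes match torch.nn.Linear, so state dicts load 1:1.
    layer.load_state_dict(reference.state_dict())

    x = torch.randn(2, 320)
    # With identical weights, both layers call F.linear and give the same output.
    assert torch.allclose(layer(x), reference(x))

The zero_module(comfy.ops.Linear(...)) call in SpatialTransformer also still behaves as before, since zero_module overwrites the parameters regardless of how they were (not) initialized.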