Support for Control Loras.

Control loras are controlnets where some of the weights are stored in
"lora" format: an up and a down low-rank matrix that, when multiplied
together and added to the unet weight, give the controlnet weight.

This allows a much smaller memory footprint, depending on the rank of
the matrices.

These controlnets are used just like regular ones.
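
For reference, the reconstruction performed at inference time by the ControlLoraOps layers added at the end of this diff boils down to the sketch below (the function name and the idea of merging the weight ahead of time are illustrative only; the actual code computes this on the fly inside the Linear/Conv2d forward passes):

import torch

def merged_controlnet_weight(unet_weight, up, down):
    # Flatten the low-rank factors to 2D, multiply them, reshape the product
    # to the shape of the base unet weight and add it on top, mirroring what
    # ControlLoraOps.Linear / ControlLoraOps.Conv2d do in their forward().
    delta = torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1))
    return unet_weight + delta.reshape(unet_weight.shape).type(unet_weight.dtype)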
comfyanonymous 2023-08-18 02:46:11 -04:00
parent 39ac856a33
commit d6e4b342e6
6 changed files with 216 additions and 92 deletions
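
Loading is unchanged on the caller side: load_controlnet() detects a control lora checkpoint by its "lora_controlnet" key and returns the new ControlLora class instead of building a regular ControlNet. A minimal usage sketch follows (the module path and the checkpoint filename are assumptions for illustration; the diff does not name the files it touches):

import comfy.sd  # assumption: the module that exposes load_controlnet

# Checkpoints carrying the "lora_controlnet" key are routed to ControlLora;
# anything else keeps going through the existing ControlNet code path.
control = comfy.sd.load_controlnet("control_lora_canny_rank128.safetensors")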

View File

@ -6,8 +6,6 @@ import torch as th
import torch.nn as nn import torch.nn as nn
from ..ldm.modules.diffusionmodules.util import ( from ..ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module, zero_module,
timestep_embedding, timestep_embedding,
) )
@ -15,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
from ..ldm.modules.attention import SpatialTransformer from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists from ..ldm.util import exists
import comfy.ops
class ControlledUnetModel(UNetModel): class ControlledUnetModel(UNetModel):
#implemented in the ldm unet #implemented in the ldm unet
@ -55,6 +53,8 @@ class ControlNet(nn.Module):
use_linear_in_transformer=False, use_linear_in_transformer=False,
adm_in_channels=None, adm_in_channels=None,
transformer_depth_middle=None, transformer_depth_middle=None,
device=None,
operations=comfy.ops,
): ):
super().__init__() super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true" assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@ -117,9 +117,9 @@ class ControlNet(nn.Module):
time_embed_dim = model_channels * 4 time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential( self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim), operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
linear(time_embed_dim, time_embed_dim), operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
) )
if self.num_classes is not None: if self.num_classes is not None:
@ -132,9 +132,9 @@ class ControlNet(nn.Module):
assert adm_in_channels is not None assert adm_in_channels is not None
self.label_emb = nn.Sequential( self.label_emb = nn.Sequential(
nn.Sequential( nn.Sequential(
linear(adm_in_channels, time_embed_dim), operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
linear(time_embed_dim, time_embed_dim), operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
) )
) )
else: else:
@ -143,28 +143,28 @@ class ControlNet(nn.Module):
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList(
[ [
TimestepEmbedSequential( TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1) operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
) )
] ]
) )
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
self.input_hint_block = TimestepEmbedSequential( self.input_hint_block = TimestepEmbedSequential(
conv_nd(dims, hint_channels, 16, 3, padding=1), operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1), operations.conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2), operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1), operations.conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1, stride=2), operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1), operations.conv_nd(dims, 96, 96, 3, padding=1),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1, stride=2), operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
nn.SiLU(), nn.SiLU(),
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
) )
self._feature_size = model_channels self._feature_size = model_channels
@ -182,6 +182,7 @@ class ControlNet(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
operations=operations
) )
] ]
ch = mult * model_channels ch = mult * model_channels
@ -204,11 +205,11 @@ class ControlNet(nn.Module):
SpatialTransformer( SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint use_checkpoint=use_checkpoint, operations=operations
) )
) )
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch)) self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self._feature_size += ch self._feature_size += ch
input_block_chans.append(ch) input_block_chans.append(ch)
if level != len(channel_mult) - 1: if level != len(channel_mult) - 1:
@ -224,16 +225,17 @@ class ControlNet(nn.Module):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
down=True, down=True,
operations=operations
) )
if resblock_updown if resblock_updown
else Downsample( else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
) )
) )
) )
ch = out_ch ch = out_ch
input_block_chans.append(ch) input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch)) self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
ds *= 2 ds *= 2
self._feature_size += ch self._feature_size += ch
@ -253,11 +255,12 @@ class ControlNet(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
operations=operations
), ),
SpatialTransformer( # always uses a self-attn SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint use_checkpoint=use_checkpoint, operations=operations
), ),
ResBlock( ResBlock(
ch, ch,
@ -266,13 +269,14 @@ class ControlNet(nn.Module):
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
operations=operations
), ),
) )
self.middle_block_out = self.make_zero_conv(ch) self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self._feature_size += ch self._feature_size += ch
def make_zero_conv(self, channels): def make_zero_conv(self, channels, operations=None):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, context, y=None, **kwargs): def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)

View File

@ -10,7 +10,6 @@ from .diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management from comfy import model_management
import comfy.ops
if model_management.xformers_enabled(): if model_management.xformers_enabled():
import xformers import xformers
@ -52,9 +51,9 @@ def init_(tensor):
# feedforward # feedforward
class GEGLU(nn.Module): class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
def forward(self, x): def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1) x, gate = self.proj(x).chunk(2, dim=-1)
@ -62,19 +61,19 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential( project_in = nn.Sequential(
comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device), operations.Linear(dim, inner_dim, dtype=dtype, device=device),
nn.GELU() nn.GELU()
) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device) ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in,
nn.Dropout(dropout), nn.Dropout(dropout),
comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device) operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
) )
def forward(self, x): def forward(self, x):
@ -148,7 +147,7 @@ class SpatialSelfAttention(nn.Module):
class CrossAttentionBirchSan(nn.Module): class CrossAttentionBirchSan(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -156,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
self.heads = heads self.heads = heads
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout) nn.Dropout(dropout)
) )
@ -245,7 +244,7 @@ class CrossAttentionBirchSan(nn.Module):
class CrossAttentionDoggettx(nn.Module): class CrossAttentionDoggettx(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -253,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
self.heads = heads self.heads = heads
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout) nn.Dropout(dropout)
) )
@ -343,7 +342,7 @@ class CrossAttentionDoggettx(nn.Module):
return self.to_out(r2) return self.to_out(r2)
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -351,12 +350,12 @@ class CrossAttention(nn.Module):
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
self.heads = heads self.heads = heads
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout) nn.Dropout(dropout)
) )
@ -399,7 +398,7 @@ class CrossAttention(nn.Module):
class MemoryEfficientCrossAttention(nn.Module): class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads.") f"{heads} heads.")
@ -409,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
self.heads = heads self.heads = heads
self.dim_head = dim_head self.dim_head = dim_head
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None): def forward(self, x, context=None, value=None, mask=None):
@ -450,7 +449,7 @@ class MemoryEfficientCrossAttention(nn.Module):
return self.to_out(out) return self.to_out(out)
class CrossAttentionPytorch(nn.Module): class CrossAttentionPytorch(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
super().__init__() super().__init__()
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@ -458,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
self.heads = heads self.heads = heads
self.dim_head = dim_head self.dim_head = dim_head
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None): def forward(self, x, context=None, value=None, mask=None):
@ -508,14 +507,14 @@ else:
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False, dtype=None, device=None): disable_self_attn=False, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device) # is a self-attention if not self.disable_self_attn context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device) self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device) # is self-attn if context is none heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device) self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device) self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device) self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
@ -648,7 +647,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head, def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False, disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None, device=None): use_checkpoint=True, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
if exists(context_dim) and not isinstance(context_dim, list): if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth context_dim = [context_dim] * depth
@ -656,26 +655,26 @@ class SpatialTransformer(nn.Module):
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype, device=device) self.norm = Normalize(in_channels, dtype=dtype, device=device)
if not use_linear: if not use_linear:
self.proj_in = nn.Conv2d(in_channels, self.proj_in = operations.Conv2d(in_channels,
inner_dim, inner_dim,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, dtype=dtype, device=device) padding=0, dtype=dtype, device=device)
else: else:
self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device) disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
for d in range(depth)] for d in range(depth)]
) )
if not use_linear: if not use_linear:
self.proj_out = nn.Conv2d(inner_dim,in_channels, self.proj_out = operations.Conv2d(inner_dim,in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, dtype=dtype, device=device) padding=0, dtype=dtype, device=device)
else: else:
self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.use_linear = use_linear self.use_linear = use_linear
def forward(self, x, context=None, transformer_options={}): def forward(self, x, context=None, transformer_options={}):

View File

@ -8,8 +8,6 @@ import torch.nn.functional as F
from .util import ( from .util import (
checkpoint, checkpoint,
conv_nd,
linear,
avg_pool_nd, avg_pool_nd,
zero_module, zero_module,
normalization, normalization,
@ -17,7 +15,7 @@ from .util import (
) )
from ..attention import SpatialTransformer from ..attention import SpatialTransformer
from comfy.ldm.util import exists from comfy.ldm.util import exists
import comfy.ops
class TimestepBlock(nn.Module): class TimestepBlock(nn.Module):
""" """
@ -72,14 +70,14 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None): def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.use_conv = use_conv self.use_conv = use_conv
self.dims = dims self.dims = dims
if use_conv: if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device) self.conv = operations.conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
def forward(self, x, output_shape=None): def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
@ -108,7 +106,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions. downsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None): def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
@ -116,7 +114,7 @@ class Downsample(nn.Module):
self.dims = dims self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2) stride = 2 if dims != 3 else (1, 2, 2)
if use_conv: if use_conv:
self.op = conv_nd( self.op = operations.conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
) )
else: else:
@ -158,6 +156,7 @@ class ResBlock(TimestepBlock):
down=False, down=False,
dtype=None, dtype=None,
device=None, device=None,
operations=None
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -171,7 +170,7 @@ class ResBlock(TimestepBlock):
self.in_layers = nn.Sequential( self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device), nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device), operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
) )
self.updown = up or down self.updown = up or down
@ -187,7 +186,7 @@ class ResBlock(TimestepBlock):
self.emb_layers = nn.Sequential( self.emb_layers = nn.Sequential(
nn.SiLU(), nn.SiLU(),
linear( operations.Linear(
emb_channels, emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
), ),
@ -197,18 +196,18 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device) operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
), ),
) )
if self.out_channels == channels: if self.out_channels == channels:
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
elif use_conv: elif use_conv:
self.skip_connection = conv_nd( self.skip_connection = operations.conv_nd(
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
) )
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
def forward(self, x, emb): def forward(self, x, emb):
""" """
@ -317,6 +316,7 @@ class UNetModel(nn.Module):
adm_in_channels=None, adm_in_channels=None,
transformer_depth_middle=None, transformer_depth_middle=None,
device=None, device=None,
operations=comfy.ops,
): ):
super().__init__() super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true" assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@ -379,9 +379,9 @@ class UNetModel(nn.Module):
time_embed_dim = model_channels * 4 time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential( self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
) )
if self.num_classes is not None: if self.num_classes is not None:
@ -394,9 +394,9 @@ class UNetModel(nn.Module):
assert adm_in_channels is not None assert adm_in_channels is not None
self.label_emb = nn.Sequential( self.label_emb = nn.Sequential(
nn.Sequential( nn.Sequential(
linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device), operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device), operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
) )
) )
else: else:
@ -405,7 +405,7 @@ class UNetModel(nn.Module):
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList(
[ [
TimestepEmbedSequential( TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device) operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
) )
] ]
) )
@ -426,6 +426,7 @@ class UNetModel(nn.Module):
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
operations=operations,
) )
] ]
ch = mult * model_channels ch = mult * model_channels
@ -447,7 +448,7 @@ class UNetModel(nn.Module):
layers.append(SpatialTransformer( layers.append(SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
) )
) )
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -468,10 +469,11 @@ class UNetModel(nn.Module):
down=True, down=True,
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
operations=operations
) )
if resblock_updown if resblock_updown
else Downsample( else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
) )
) )
) )
@ -498,11 +500,12 @@ class UNetModel(nn.Module):
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
operations=operations
), ),
SpatialTransformer( # always uses a self-attn SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
), ),
ResBlock( ResBlock(
ch, ch,
@ -513,6 +516,7 @@ class UNetModel(nn.Module):
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
operations=operations
), ),
) )
self._feature_size += ch self._feature_size += ch
@ -532,6 +536,7 @@ class UNetModel(nn.Module):
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
operations=operations
) )
] ]
ch = model_channels * mult ch = model_channels * mult
@ -554,7 +559,7 @@ class UNetModel(nn.Module):
SpatialTransformer( SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
) )
) )
if level and i == self.num_res_blocks[level]: if level and i == self.num_res_blocks[level]:
@ -571,9 +576,10 @@ class UNetModel(nn.Module):
up=True, up=True,
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
operations=operations
) )
if resblock_updown if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device) else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations)
) )
ds //= 2 ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers)) self.output_blocks.append(TimestepEmbedSequential(*layers))
@ -582,12 +588,12 @@ class UNetModel(nn.Module):
self.out = nn.Sequential( self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device), nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device), nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
) )

View File

@@ -21,6 +21,11 @@ class Conv2d(torch.nn.Conv2d):
     def reset_parameters(self):
         return None

+def conv_nd(dims, *args, **kwargs):
+    if dims == 2:
+        return Conv2d(*args, **kwargs)
+    else:
+        raise ValueError(f"unsupported dimensions: {dims}")

 @contextmanager
 def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way

View File

@@ -478,7 +478,7 @@ def pre_run_control(model, conds):
         timestep_end = None
         percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
         if 'control' in x[1]:
-            x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function)
+            x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)

 def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
     cond_cnets = []

View File

@@ -844,9 +844,119 @@ class ControlNet(ControlBase):
             out.append(self.control_model_wrapped)
         return out

+class ControlLoraOps:
+    class Linear(torch.nn.Module):
+        def __init__(self, in_features: int, out_features: int, bias: bool = True,
+                     device=None, dtype=None) -> None:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super().__init__()
+            self.in_features = in_features
+            self.out_features = out_features
+            self.weight = None
+            self.up = None
+            self.down = None
+            self.bias = None
+
+        def forward(self, input):
+            if self.up is not None:
+                return torch.nn.functional.linear(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias)
+            else:
+                return torch.nn.functional.linear(input, self.weight, self.bias)
+
+    class Conv2d(torch.nn.Module):
+        def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            dilation=1,
+            groups=1,
+            bias=True,
+            padding_mode='zeros',
+            device=None,
+            dtype=None
+        ):
+            super().__init__()
+            self.in_channels = in_channels
+            self.out_channels = out_channels
+            self.kernel_size = kernel_size
+            self.stride = stride
+            self.padding = padding
+            self.dilation = dilation
+            self.transposed = False
+            self.output_padding = 0
+            self.groups = groups
+            self.padding_mode = padding_mode
+
+            self.weight = None
+            self.bias = None
+            self.up = None
+            self.down = None
+
+        def forward(self, input):
+            if self.up is not None:
+                return torch.nn.functional.conv2d(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
+            else:
+                return torch.nn.functional.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+    def conv_nd(self, dims, *args, **kwargs):
+        if dims == 2:
+            return self.Conv2d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
+
+class ControlLora(ControlNet):
+    def __init__(self, control_weights, global_average_pooling=False, device=None):
+        ControlBase.__init__(self, device)
+        self.control_weights = control_weights
+        self.global_average_pooling = global_average_pooling
+
+    def pre_run(self, model, percent_to_timestep_function):
+        super().pre_run(model, percent_to_timestep_function)
+        controlnet_config = model.model_config.unet_config.copy()
+        controlnet_config.pop("out_channels")
+        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
+        controlnet_config["operations"] = ControlLoraOps()
+        self.control_model = cldm.ControlNet(**controlnet_config)
+        if model_management.should_use_fp16():
+            self.control_model.half()
+        self.control_model.to(model_management.get_torch_device())
+        diffusion_model = model.diffusion_model
+        sd = diffusion_model.state_dict()
+        cm = self.control_model.state_dict()
+
+        for k in sd:
+            try:
+                set_attr(self.control_model, k, sd[k])
+            except:
+                pass
+
+        for k in self.control_weights:
+            if k not in {"lora_controlnet"}:
+                set_attr(self.control_model, k, self.control_weights[k].to(model_management.get_torch_device()))
+
+    def copy(self):
+        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
+        self.copy_to(c)
+        return c
+
+    def cleanup(self):
+        del self.control_model
+        self.control_model = None
+        super().cleanup()
+
+    def get_models(self):
+        out = ControlBase.get_models(self)
+        return out
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+    if "lora_controlnet" in controlnet_data:
+        return ControlLora(controlnet_data)
+
     controlnet_config = None
     if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format