2023-01-03 06:53:32 +00:00
import torch
2023-02-17 20:31:38 +00:00
import contextlib
2023-03-31 21:19:58 +00:00
import copy
2023-06-13 20:05:26 +00:00
import inspect
2023-01-03 06:53:32 +00:00
2023-04-15 22:55:17 +00:00
from comfy import model_management
2023-03-07 16:00:35 +00:00
from . ldm . util import instantiate_from_config
from . ldm . models . autoencoder import AutoencoderKL
2023-03-13 18:49:18 +00:00
import yaml
2023-02-16 15:38:08 +00:00
from . cldm import cldm
2023-02-25 05:55:42 +00:00
from . t2i_adapter import adapter
2023-02-16 15:38:08 +00:00
from . import utils
2023-04-02 03:19:15 +00:00
from . import clip_vision
2023-04-19 13:36:19 +00:00
from . import gligen
2023-05-28 06:02:09 +00:00
from . import diffusers_convert
2023-06-09 16:24:24 +00:00
from . import model_base
2023-06-22 17:03:50 +00:00
from . import model_detection
2023-02-03 07:06:34 +00:00
2023-06-22 17:03:50 +00:00
from . import sd1_clip
from . import sd2_clip
2023-06-25 05:40:38 +00:00
from . import sdxl_clip
2023-06-09 16:24:24 +00:00
2023-06-22 17:03:50 +00:00
def load_model_weights(model, sd):
    """Load ``sd`` into ``model`` non-strictly and free the consumed entries.

    After ``load_state_dict(strict=False)``, every entry of ``sd`` except the
    ones the model reported as unexpected is removed from ``sd`` (in place).
    Missing keys, if any, are printed.

    Returns:
        The same ``model`` object.
    """
    missing_keys, unexpected = model.load_state_dict(sd, strict=False)
    missing_keys = set(missing_keys)
    keep = set(unexpected)

    # Drop everything the model consumed so those tensors can be freed.
    consumed = [name for name in list(sd.keys()) if name not in keep]
    for name in consumed:
        del sd[name]

    if missing_keys:
        print("missing", missing_keys)
    return model
def load_clip_weights(model, sd):
    """Normalise a CLIP text-encoder state dict's key names and load it into ``model``.

    Keys under ``cond_stage_model.transformer.`` that lack the ``text_model.``
    level are renamed to include it (in place in ``sd``). A float32
    ``position_ids`` entry is rounded. The dict is then passed through
    ``utils.transformers_convert`` and loaded with ``load_model_weights``.
    """
    short_prefix = "cond_stage_model.transformer."
    full_prefix = "cond_stage_model.transformer.text_model."
    for name in list(sd.keys()):
        needs_rename = name.startswith(short_prefix) and not name.startswith(full_prefix)
        if needs_rename:
            sd[name.replace(short_prefix, full_prefix)] = sd.pop(name)

    ids_key = 'cond_stage_model.transformer.text_model.embeddings.position_ids'
    if ids_key in sd:
        position_ids = sd[ids_key]
        # Float position ids can carry tiny rounding errors; snap them to integers.
        if position_ids.dtype == torch.float32:
            sd[ids_key] = position_ids.round()

    # NOTE(review): the trailing 24 is presumably a layer count for the converter — confirm in utils.
    sd = utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
    return load_model_weights(model, sd)
2023-01-03 06:53:32 +00:00
2023-02-03 07:06:34 +00:00
# CLIP text-encoder sub-module path -> fragment used in LoRA file key names
# (the LoRA fragment is the module path with dots replaced by underscores).
LORA_CLIP_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "mlp.fc1",
        "mlp.fc2",
        "self_attn.k_proj",
        "self_attn.q_proj",
        "self_attn.v_proj",
        "self_attn.out_proj",
    )
}

# UNet spatial-transformer sub-module path -> LoRA key fragment.
LORA_UNET_MAP_ATTENTIONS = {
    "proj_in": "proj_in",
    "proj_out": "proj_out",
}

# Per-transformer-block templates; "{}" is the transformer block index.
transformer_lora_blocks = {
    template: template.replace(".", "_")
    for template in (
        "transformer_blocks.{}.attn1.to_q",
        "transformer_blocks.{}.attn1.to_k",
        "transformer_blocks.{}.attn1.to_v",
        "transformer_blocks.{}.attn1.to_out.0",
        "transformer_blocks.{}.attn2.to_q",
        "transformer_blocks.{}.attn2.to_k",
        "transformer_blocks.{}.attn2.to_v",
        "transformer_blocks.{}.attn2.to_out.0",
        "transformer_blocks.{}.ff.net.0.proj",
        "transformer_blocks.{}.ff.net.2",
    )
}
# Expand the templates for transformer block indices 0..9.
LORA_UNET_MAP_ATTENTIONS.update(
    (template.format(block_idx), lora_name.format(block_idx))
    for block_idx in range(10)
    for template, lora_name in transformer_lora_blocks.items()
)

# UNet ResBlock sub-module -> LoRA key fragment template ("{}" is the resnet index).
LORA_UNET_MAP_RESNET = {
    "in_layers.2": "resnets_{}_conv1",
    "emb_layers.1": "resnets_{}_time_emb_proj",
    "out_layers.3": "resnets_{}_conv2",
    "skip_connection": "resnets_{}_conv_shortcut",
}
2023-02-03 07:06:34 +00:00
def load_lora(path, to_load):
    """Read a LoRA-family file and build per-target patch tuples.

    Args:
        path: file handed to ``utils.load_torch_file(path, safe_load=True)``.
        to_load: mapping of key prefix inside the file -> target model key.

    Returns:
        dict of target model key -> patch tuple. The tuple length encodes the kind:
          * 4: LoRA/LoCon ``(up, down, alpha, mid)``
          * 7: LoHa ``(w1_a, w1_b, alpha, w2_a, w2_b, t1, t2)``
          * 8: LoKr ``(w1, w2, alpha, w1_a, w1_b, w2_a, w2_b, t2)``
        If several formats are present for one prefix, the later one (LoHa, then
        LoKr) overwrites the earlier. Every file key that was not consumed is
        printed as "lora key not loaded".
    """
    lora = utils.load_torch_file(path, safe_load=True)
    patch_dict = {}
    loaded_keys = set()

    def take(name):
        # Optional tensor lookup that records the key as consumed when present.
        if name not in lora.keys():
            return None
        loaded_keys.add(name)
        return lora[name]

    for prefix in to_load:
        target = to_load[prefix]

        alpha_tensor = take("{}.alpha".format(prefix))
        alpha = None if alpha_tensor is None else alpha_tensor.item()

        # -- LoRA / LoCon --
        up_name = "{}.lora_up.weight".format(prefix)
        if up_name in lora.keys():
            down_name = "{}.lora_down.weight".format(prefix)
            mid = take("{}.lora_mid.weight".format(prefix))
            patch_dict[target] = (lora[up_name], lora[down_name], alpha, mid)
            loaded_keys.add(up_name)
            loaded_keys.add(down_name)

        # -- LoHa --
        if "{}.hada_w1_a".format(prefix) in lora.keys():
            t1 = t2 = None
            t1_name = "{}.hada_t1".format(prefix)
            if t1_name in lora.keys():
                # CP-decomposed LoHa: both core tensors are expected together.
                t2_name = "{}.hada_t2".format(prefix)
                t1 = lora[t1_name]
                t2 = lora[t2_name]
                loaded_keys.add(t1_name)
                loaded_keys.add(t2_name)
            hada_names = ["{}.hada_{}".format(prefix, part) for part in ("w1_a", "w1_b", "w2_a", "w2_b")]
            w1_a, w1_b, w2_a, w2_b = [lora[name] for name in hada_names]
            patch_dict[target] = (w1_a, w1_b, alpha, w2_a, w2_b, t1, t2)
            loaded_keys.update(hada_names)

        # -- LoKr --
        lokr = {part: take("{}.lokr_{}".format(prefix, part))
                for part in ("w1", "w2", "w1_a", "w1_b", "t2", "w2_a", "w2_b")}
        if any(lokr[part] is not None for part in ("w1", "w2", "w1_a", "w2_a")):
            patch_dict[target] = (lokr["w1"], lokr["w2"], alpha, lokr["w1_a"], lokr["w1_b"],
                                  lokr["w2_a"], lokr["w2_b"], lokr["t2"])

    for name in lora.keys():
        if name not in loaded_keys:
            print("lora key not loaded", name)
    return patch_dict
def model_lora_keys(model, key_map=None):
    """Map LoRA-file key prefixes to state-dict keys of ``model``.

    Args:
        model: object exposing ``state_dict()``; its keys are probed for the UNet
            (``diffusion_model.*``) and CLIP text-encoder layouts below.
        key_map: dict to extend in place. A fresh dict is created when omitted.
            (Previously the default was a single shared ``{}``, so mappings from
            every earlier call leaked into later ones.)

    Returns:
        ``key_map``: LoRA key prefix (e.g. ``lora_unet_...`` / ``lora_te_...``)
        -> model state-dict key ending in ``.weight``.
    """
    if key_map is None:
        key_map = {}
    sdk = model.state_dict().keys()

    # UNet attention blocks, input (down) side. A block that matched at least
    # 4 keys advances the diffusers-style attention counter.
    counter = 0
    for b in range(12):
        tk = "diffusion_model.input_blocks.{}.1".format(b)
        up_counter = 0
        for c in LORA_UNET_MAP_ATTENTIONS:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP_ATTENTIONS[c])
                key_map[lora_key] = k
                up_counter += 1
        if up_counter >= 4:
            counter += 1

    # UNet middle-block attention.
    for c in LORA_UNET_MAP_ATTENTIONS:
        k = "diffusion_model.middle_block.1.{}.weight".format(c)
        if k in sdk:
            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
            key_map[lora_key] = k

    # UNet attention blocks, output (up) side; numbering starts at 3.
    counter = 3
    for b in range(12):
        tk = "diffusion_model.output_blocks.{}.1".format(b)
        up_counter = 0
        for c in LORA_UNET_MAP_ATTENTIONS:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP_ATTENTIONS[c])
                key_map[lora_key] = k
                up_counter += 1
        if up_counter >= 4:
            counter += 1

    # CLIP text encoder(s): single encoder, or SDXL's clip_l / clip_g pair.
    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    for b in range(32):
        for c in LORA_CLIP_MAP:
            k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # SDXL base
                key_map[lora_key] = k
                clip_l_present = True

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                if clip_l_present:
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # SDXL base
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])  # TODO: test if this is correct for SDXL-Refiner
                key_map[lora_key] = k

    # LoCon: UNet resnets and down-samplers, input side.
    ds_counter = 0
    counter = 0
    for b in range(12):
        tk = "diffusion_model.input_blocks.{}.0".format(b)
        key_in = False
        for c in LORA_UNET_MAP_RESNET:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_down_blocks_{}_{}".format(counter // 2, LORA_UNET_MAP_RESNET[c].format(counter % 2))
                key_map[lora_key] = k
                key_in = True
        for bb in range(3):
            # tk[:-2] strips the trailing ".0" to probe sibling sub-modules.
            k = "{}.{}.op.weight".format(tk[:-2], bb)
            if k in sdk:
                lora_key = "lora_unet_down_blocks_{}_downsamplers_0_conv".format(ds_counter)
                key_map[lora_key] = k
                ds_counter += 1
        if key_in:
            counter += 1

    # LoCon: middle-block resnets.
    counter = 0
    for b in range(3):
        tk = "diffusion_model.middle_block.{}".format(b)
        key_in = False
        for c in LORA_UNET_MAP_RESNET:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_mid_block_{}".format(LORA_UNET_MAP_RESNET[c].format(counter))
                key_map[lora_key] = k
                key_in = True
        if key_in:
            counter += 1

    # LoCon: UNet resnets and up-samplers, output side.
    counter = 0
    us_counter = 0
    for b in range(12):
        tk = "diffusion_model.output_blocks.{}.0".format(b)
        key_in = False
        for c in LORA_UNET_MAP_RESNET:
            k = "{}.{}.weight".format(tk, c)
            if k in sdk:
                lora_key = "lora_unet_up_blocks_{}_{}".format(counter // 3, LORA_UNET_MAP_RESNET[c].format(counter % 3))
                key_map[lora_key] = k
                key_in = True
        for bb in range(3):
            k = "{}.{}.conv.weight".format(tk[:-2], bb)
            if k in sdk:
                lora_key = "lora_unet_up_blocks_{}_upsamplers_0_conv".format(us_counter)
                key_map[lora_key] = k
                us_counter += 1
        if key_in:
            counter += 1

    # Generic fallback: "diffusion_model.a.b.weight" -> "lora_unet_a_b".
    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
    return key_map
2023-05-30 16:36:41 +00:00
2023-02-03 07:06:34 +00:00
class ModelPatcher:
    """Holds a model plus a list of weight patches applied in place and reversibly.

    ``patches`` is a list of ``(strength_patch, {state_dict_key: patch_tuple}, strength_model)``.
    ``patch_model`` adds the patches into the model's weights (backing up the
    originals in ``backup``); ``unpatch_model`` restores them.
    """

    def __init__(self, model, size=0):
        self.size = size
        self.model = model
        self.patches = []
        self.backup = {}
        self.model_options = {"transformer_options": {}}
        # Filled by model_size() when it measures the model. When a size is
        # supplied up front it stays None and add_patches() computes it lazily
        # (previously this attribute was simply missing in that case).
        self.model_keys = None
        self.model_size()

    def model_size(self):
        """Return the total byte size of the model's state dict (cached in ``self.size``)."""
        if self.size > 0:
            return self.size
        model_sd = self.model.state_dict()
        self.size = sum(t.nelement() * t.element_size() for t in model_sd.values())
        self.model_keys = set(model_sd.keys())
        return self.size

    def clone(self):
        """Return a new patcher sharing the same model, with copied patch list and deep-copied options."""
        n = ModelPatcher(self.model, self.size)
        n.patches = self.patches[:]
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        return n

    def set_model_sampler_cfg_function(self, sampler_cfg_function):
        """Install a CFG function; 3-argument callables use the old (cond, uncond, cond_scale) form."""
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])  # Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function

    def set_model_patch(self, patch, name):
        """Append ``patch`` to the ``transformer_options["patches"][name]`` list."""
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        # Build a new list rather than appending, so lists shared with other holders are not mutated.
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number):
        """Register ``patch`` under ``transformer_options["patches_replace"][name][(block_name, number)]``."""
        to = self.model_options["transformer_options"]
        if "patches_replace" not in to:
            to["patches_replace"] = {}
        if name not in to["patches_replace"]:
            to["patches_replace"][name] = {}
        to["patches_replace"][name][(block_name, number)] = patch

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number):
        self.set_model_patch_replace(patch, "attn1", block_name, number)

    def set_model_attn2_replace(self, patch, block_name, number):
        self.set_model_patch_replace(patch, "attn2", block_name, number)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def model_patches_to(self, device):
        """Move every registered transformer patch that has a ``.to`` method onto ``device``."""
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)

    def model_dtype(self):
        return self.model.get_dtype()

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        """Queue the entries of ``patches`` whose keys exist in the model.

        Returns:
            The keys that were accepted (a dict keys view).
        """
        if self.model_keys is None:
            self.model_keys = set(self.model.state_dict().keys())
        p = {k: patches[k] for k in patches if k in self.model_keys}
        self.patches += [(strength_patch, p, strength_model)]
        return p.keys()

    def model_state_dict(self, filter_prefix=None):
        """Return ``model.state_dict()``, optionally keeping only keys starting with ``filter_prefix``."""
        sd = self.model.state_dict()
        if filter_prefix is not None:
            for k in list(sd.keys()):
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

    def patch_model(self):
        """Add every queued patch into the model weights in place and return the model.

        The patch tuple's length selects the merge rule: 1 = full weight delta,
        4 = LoRA/LoCon, 8 = LoKr, anything else = LoHa. The first time a key is
        touched, its original value is cloned into ``self.backup``.
        """
        model_sd = self.model_state_dict()
        for p in self.patches:
            for k in p[1]:
                v = p[1][k]
                key = k
                if key not in model_sd:
                    print("could not patch. key doesn't exist in model:", k)
                    continue

                weight = model_sd[key]
                if key not in self.backup:
                    self.backup[key] = weight.clone()

                alpha = p[0]
                strength_model = p[2]

                if strength_model != 1.0:
                    weight *= strength_model

                if len(v) == 1:
                    w1 = v[0]
                    if w1.shape != weight.shape:
                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                    else:
                        weight += alpha * w1.type(weight.dtype).to(weight.device)
                elif len(v) == 4:  # lora/locon
                    mat1 = v[0]
                    mat2 = v[1]
                    if v[2] is not None:
                        alpha *= v[2] / mat2.shape[0]
                    if v[3] is not None:
                        # locon mid weights, hopefully the math is fine because I didn't properly test it
                        final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
                        mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
                elif len(v) == 8:  # lokr
                    w1 = v[0]
                    w2 = v[1]
                    w1_a = v[3]
                    w1_b = v[4]
                    w2_a = v[5]
                    w2_b = v[6]
                    t2 = v[7]
                    dim = None

                    if w1 is None:
                        dim = w1_b.shape[0]
                        w1 = torch.mm(w1_a.float(), w1_b.float())

                    if w2 is None:
                        dim = w2_b.shape[0]
                        if t2 is None:
                            w2 = torch.mm(w2_a.float(), w2_b.float())
                        else:
                            w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())

                    if len(w2.shape) == 4:
                        w1 = w1.unsqueeze(2).unsqueeze(2)
                    if v[2] is not None and dim is not None:
                        alpha *= v[2] / dim

                    weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
                else:  # loha
                    w1a = v[0]
                    w1b = v[1]
                    if v[2] is not None:
                        alpha *= v[2] / w1b.shape[0]
                    w2a = v[3]
                    w2b = v[4]
                    if v[5] is not None:  # cp decomposition
                        t1 = v[5]
                        t2 = v[6]
                        m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
                        m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
                    else:
                        m1 = torch.mm(w1a.float(), w1b.float())
                        m2 = torch.mm(w2a.float(), w2b.float())

                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
        return self.model

    def unpatch_model(self):
        """Copy every backed-up weight back into the model and clear the backup."""
        model_sd = self.model_state_dict()
        for k in list(self.backup.keys()):
            # copy_ (rather than t[:] = ...) also works for 0-dim tensors.
            model_sd[k].copy_(self.backup[k])
            del self.backup[k]
        self.backup = {}
def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
    """Apply a LoRA file to clones of a UNet ModelPatcher and a CLIP wrapper.

    Args:
        model: ModelPatcher whose ``.model`` provides the UNet state dict.
        clip: CLIP wrapper whose ``cond_stage_model`` provides the text-encoder state dict.
        lora_path: path of the LoRA file.
        strength_model / strength_clip: patch strengths for the two clones.

    Returns:
        ``(new_modelpatcher, new_clip)``. The inputs are not modified. Patch keys
        that matched neither model are printed as "NOT LOADED".
    """
    # Pass a fresh dict explicitly so no key-mapping state is shared with other calls.
    key_map = model_lora_keys(model.model, {})
    key_map = model_lora_keys(clip.cond_stage_model, key_map)
    loaded = load_lora(lora_path, key_map)

    new_modelpatcher = model.clone()
    k = set(new_modelpatcher.add_patches(loaded, strength_model))
    new_clip = clip.clone()
    k1 = set(new_clip.add_patches(loaded, strength_clip))

    for x in loaded:
        if (x not in k) and (x not in k1):
            print("NOT LOADED", x)
    return (new_modelpatcher, new_clip)
2023-01-03 06:53:32 +00:00
class CLIP:
    """Text-encoder wrapper bundling a tokenizer, a conditioning model and its ModelPatcher."""

    def __init__(self, target=None, embedding_directory=None, no_init=False):
        """Build the encoder from ``target``'s ``params``, ``clip`` and ``tokenizer`` attributes.

        With ``no_init=True`` nothing is initialised; ``clone`` uses this and
        fills the attributes itself.
        """
        if no_init:
            return
        params = target.params
        clip = target.clip
        tokenizer = target.tokenizer

        self.device = model_management.text_encoder_device()
        params["device"] = self.device
        self.cond_stage_model = clip(**(params))
        self.cond_stage_model = self.cond_stage_model.to(self.device)
        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
        self.patcher = ModelPatcher(self.cond_stage_model)
        self.layer_idx = None

    def clone(self):
        """Return a CLIP sharing the encoder and tokenizer but with a cloned patcher."""
        n = CLIP(no_init=True)
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
        n.device = self.device
        return n

    def load_from_state_dict(self, sd):
        self.cond_stage_model.load_sd(sd)

    def add_patches(self, patches, strength=1.0):
        return self.patcher.add_patches(patches, strength)

    def clip_layer(self, layer_idx):
        """Remember the layer index applied to the encoder on the next encode."""
        self.layer_idx = layer_idx

    def tokenize(self, text, return_word_ids=False):
        return self.tokenizer.tokenize_with_weights(text, return_word_ids)

    def encode_from_tokens(self, tokens, return_pooled=False):
        """Encode tokens with patches applied; return ``cond`` or ``(cond, pooled)``.

        The patches are always removed afterwards (``finally``), including when
        encoding is aborted by a non-``Exception`` such as KeyboardInterrupt.
        Otherwise the patched weights would be left in place.
        """
        if self.layer_idx is not None:
            self.cond_stage_model.clip_layer(self.layer_idx)
        try:
            self.patch_model()
            cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
        finally:
            self.unpatch_model()
        if return_pooled:
            return cond, pooled
        return cond

    def encode(self, text):
        tokens = self.tokenize(text)
        return self.encode_from_tokens(tokens)

    def load_sd(self, sd):
        return self.cond_stage_model.load_sd(sd)

    def get_sd(self):
        return self.cond_stage_model.state_dict()

    def patch_model(self):
        self.patcher.patch_model()

    def unpatch_model(self):
        self.patcher.unpatch_model()
2023-01-03 06:53:32 +00:00
class VAE :
2023-06-23 06:14:12 +00:00
def __init__ ( self , ckpt_path = None , device = None , config = None ) :
2023-01-03 06:53:32 +00:00
if config is None :
#default SD1.x/SD2.x VAE parameters
ddconfig = { ' double_z ' : True , ' z_channels ' : 4 , ' resolution ' : 256 , ' in_channels ' : 3 , ' out_ch ' : 3 , ' ch ' : 128 , ' ch_mult ' : [ 1 , 2 , 4 , 4 ] , ' num_res_blocks ' : 2 , ' attn_resolutions ' : [ ] , ' dropout ' : 0.0 }
2023-05-28 06:02:09 +00:00
self . first_stage_model = AutoencoderKL ( ddconfig , { ' target ' : ' torch.nn.Identity ' } , 4 , monitor = " val/rec_loss " )
2023-01-03 06:53:32 +00:00
else :
2023-05-28 06:02:09 +00:00
self . first_stage_model = AutoencoderKL ( * * ( config [ ' params ' ] ) )
2023-01-03 06:53:32 +00:00
self . first_stage_model = self . first_stage_model . eval ( )
2023-05-28 06:02:09 +00:00
if ckpt_path is not None :
sd = utils . load_torch_file ( ckpt_path )
if ' decoder.up_blocks.0.resnets.0.norm1.weight ' in sd . keys ( ) : #diffusers format
sd = diffusers_convert . convert_vae_state_dict ( sd )
self . first_stage_model . load_state_dict ( sd , strict = False )
2023-03-06 15:50:50 +00:00
if device is None :
device = model_management . get_torch_device ( )
2023-01-03 06:53:32 +00:00
self . device = device
2023-03-22 18:49:00 +00:00
def decode_tiled_ ( self , samples , tile_x = 64 , tile_y = 64 , overlap = 16 ) :
2023-05-03 16:33:19 +00:00
steps = samples . shape [ 0 ] * utils . get_tiled_scale_steps ( samples . shape [ 3 ] , samples . shape [ 2 ] , tile_x , tile_y , overlap )
2023-05-03 21:48:35 +00:00
steps + = samples . shape [ 0 ] * utils . get_tiled_scale_steps ( samples . shape [ 3 ] , samples . shape [ 2 ] , tile_x / / 2 , tile_y * 2 , overlap )
steps + = samples . shape [ 0 ] * utils . get_tiled_scale_steps ( samples . shape [ 3 ] , samples . shape [ 2 ] , tile_x * 2 , tile_y / / 2 , overlap )
2023-05-03 17:19:22 +00:00
pbar = utils . ProgressBar ( steps )
2023-04-24 10:55:44 +00:00
2023-06-23 06:14:12 +00:00
decode_fn = lambda a : ( self . first_stage_model . decode ( a . to ( self . device ) ) + 1.0 )
2023-03-22 18:49:00 +00:00
output = torch . clamp ( (
2023-04-24 10:55:44 +00:00
( utils . tiled_scale ( samples , decode_fn , tile_x / / 2 , tile_y * 2 , overlap , upscale_amount = 8 , pbar = pbar ) +
utils . tiled_scale ( samples , decode_fn , tile_x * 2 , tile_y / / 2 , overlap , upscale_amount = 8 , pbar = pbar ) +
utils . tiled_scale ( samples , decode_fn , tile_x , tile_y , overlap , upscale_amount = 8 , pbar = pbar ) )
2023-03-22 18:49:00 +00:00
/ 3.0 ) / 2.0 , min = 0.0 , max = 1.0 )
return output
2023-06-12 03:25:39 +00:00
def encode_tiled_ ( self , pixel_samples , tile_x = 512 , tile_y = 512 , overlap = 64 ) :
steps = pixel_samples . shape [ 0 ] * utils . get_tiled_scale_steps ( pixel_samples . shape [ 3 ] , pixel_samples . shape [ 2 ] , tile_x , tile_y , overlap )
steps + = pixel_samples . shape [ 0 ] * utils . get_tiled_scale_steps ( pixel_samples . shape [ 3 ] , pixel_samples . shape [ 2 ] , tile_x / / 2 , tile_y * 2 , overlap )
steps + = pixel_samples . shape [ 0 ] * utils . get_tiled_scale_steps ( pixel_samples . shape [ 3 ] , pixel_samples . shape [ 2 ] , tile_x * 2 , tile_y / / 2 , overlap )
pbar = utils . ProgressBar ( steps )
2023-06-23 06:14:12 +00:00
encode_fn = lambda a : self . first_stage_model . encode ( 2. * a . to ( self . device ) - 1. ) . sample ( )
2023-06-12 03:25:39 +00:00
samples = utils . tiled_scale ( pixel_samples , encode_fn , tile_x , tile_y , overlap , upscale_amount = ( 1 / 8 ) , out_channels = 4 , pbar = pbar )
samples + = utils . tiled_scale ( pixel_samples , encode_fn , tile_x * 2 , tile_y / / 2 , overlap , upscale_amount = ( 1 / 8 ) , out_channels = 4 , pbar = pbar )
samples + = utils . tiled_scale ( pixel_samples , encode_fn , tile_x / / 2 , tile_y * 2 , overlap , upscale_amount = ( 1 / 8 ) , out_channels = 4 , pbar = pbar )
samples / = 3.0
return samples
2023-03-22 18:49:00 +00:00
def decode ( self , samples_in ) :
2023-02-08 08:17:54 +00:00
model_management . unload_model ( )
2023-01-03 06:53:32 +00:00
self . first_stage_model = self . first_stage_model . to ( self . device )
2023-03-22 18:49:00 +00:00
try :
2023-03-29 06:24:37 +00:00
free_memory = model_management . get_free_memory ( self . device )
batch_number = int ( ( free_memory * 0.7 ) / ( 2562 * samples_in . shape [ 2 ] * samples_in . shape [ 3 ] * 64 ) )
batch_number = max ( 1 , batch_number )
pixel_samples = torch . empty ( ( samples_in . shape [ 0 ] , 3 , round ( samples_in . shape [ 2 ] * 8 ) , round ( samples_in . shape [ 3 ] * 8 ) ) , device = " cpu " )
for x in range ( 0 , samples_in . shape [ 0 ] , batch_number ) :
samples = samples_in [ x : x + batch_number ] . to ( self . device )
2023-06-23 06:14:12 +00:00
pixel_samples [ x : x + batch_number ] = torch . clamp ( ( self . first_stage_model . decode ( samples ) + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 ) . cpu ( )
2023-03-22 18:49:00 +00:00
except model_management . OOM_EXCEPTION as e :
print ( " Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding. " )
pixel_samples = self . decode_tiled_ ( samples_in )
2023-01-03 06:53:32 +00:00
self . first_stage_model = self . first_stage_model . cpu ( )
pixel_samples = pixel_samples . cpu ( ) . movedim ( 1 , - 1 )
return pixel_samples
2023-03-22 07:29:09 +00:00
def decode_tiled ( self , samples , tile_x = 64 , tile_y = 64 , overlap = 16 ) :
2023-02-24 07:10:10 +00:00
model_management . unload_model ( )
self . first_stage_model = self . first_stage_model . to ( self . device )
2023-03-22 18:49:00 +00:00
output = self . decode_tiled_ ( samples , tile_x , tile_y , overlap )
2023-02-24 07:10:10 +00:00
self . first_stage_model = self . first_stage_model . cpu ( )
return output . movedim ( 1 , - 1 )
2023-01-03 06:53:32 +00:00
def encode ( self , pixel_samples ) :
2023-02-08 08:17:54 +00:00
model_management . unload_model ( )
2023-01-03 06:53:32 +00:00
self . first_stage_model = self . first_stage_model . to ( self . device )
2023-06-12 03:25:39 +00:00
pixel_samples = pixel_samples . movedim ( - 1 , 1 )
try :
2023-06-12 04:21:50 +00:00
free_memory = model_management . get_free_memory ( self . device )
batch_number = int ( ( free_memory * 0.7 ) / ( 2078 * pixel_samples . shape [ 2 ] * pixel_samples . shape [ 3 ] ) ) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
batch_number = max ( 1 , batch_number )
2023-06-12 03:25:39 +00:00
samples = torch . empty ( ( pixel_samples . shape [ 0 ] , 4 , round ( pixel_samples . shape [ 2 ] / / 8 ) , round ( pixel_samples . shape [ 3 ] / / 8 ) ) , device = " cpu " )
for x in range ( 0 , pixel_samples . shape [ 0 ] , batch_number ) :
pixels_in = ( 2. * pixel_samples [ x : x + batch_number ] - 1. ) . to ( self . device )
2023-06-23 06:14:12 +00:00
samples [ x : x + batch_number ] = self . first_stage_model . encode ( pixels_in ) . sample ( ) . cpu ( )
2023-06-12 04:21:50 +00:00
2023-06-12 03:25:39 +00:00
except model_management . OOM_EXCEPTION as e :
print ( " Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding. " )
samples = self . encode_tiled_ ( pixel_samples )
2023-01-03 06:53:32 +00:00
self . first_stage_model = self . first_stage_model . cpu ( )
return samples
2023-03-11 20:28:15 +00:00
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
    """Encode an image batch tile by tile via ``self.encode_tiled_``.

    The last dim of ``pixel_samples`` is moved to position 1 first (presumably
    channels-last -> channels-first; confirm with callers). The VAE is moved
    to ``self.device`` for the duration of the call. It is moved back to the
    CPU in a ``finally``, so it is offloaded even when encoding raises.

    Args:
        pixel_samples: image batch to encode.
        tile_x, tile_y: tile size forwarded to ``encode_tiled_``.
        overlap: tile overlap forwarded to ``encode_tiled_``.

    Returns:
        Whatever ``encode_tiled_`` returns.
    """
    model_management.unload_model()
    self.first_stage_model = self.first_stage_model.to(self.device)
    try:
        pixel_samples = pixel_samples.movedim(-1, 1)
        samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
    finally:
        # Offload even on failure so device memory is released.
        self.first_stage_model = self.first_stage_model.cpu()
    return samples
2023-02-25 19:57:28 +00:00
2023-06-26 16:21:07 +00:00
def get_sd(self):
    """Return the state dict of the wrapped first-stage (VAE) model."""
    vae_model = self.first_stage_model
    return vae_model.state_dict()
2023-05-12 21:49:09 +00:00
def broadcast_image_to(tensor, target_batch_size, batched_number):
    """Repeat ``tensor`` along dim 0 so it lines up with a target batch size.

    A tensor whose dim 0 is 1 is returned unchanged. Otherwise the tensor is
    truncated or cyclically extended to ``target_batch_size // batched_number``
    entries. If that length differs from ``target_batch_size``, the result is
    tiled ``batched_number`` times along dim 0.
    """
    if tensor.shape[0] == 1:
        return tensor

    per_batch = target_batch_size // batched_number
    chunk = tensor[:per_batch]
    have = chunk.shape[0]
    if have < per_batch:
        # Cycle through the available entries until per_batch is reached.
        full_repeats, remainder = divmod(per_batch, have)
        chunk = torch.cat([chunk] * full_repeats + [chunk[:remainder]], dim=0)

    if chunk.shape[0] == target_batch_size:
        return chunk
    return torch.cat([chunk] * batched_number, dim=0)
2023-02-16 15:38:08 +00:00
class ControlNet:
    """Wraps a ControlNet model together with the hint image it is driven by.

    Instances can be chained through ``previous_controlnet``. ``get_control``
    adds this net's residuals to the ones produced by the previous net.
    """
    def __init__(self, control_model, global_average_pooling=False, device=None):
        self.control_model = control_model
        self.cond_hint_original = None  # hint exactly as given to set_cond_hint
        self.cond_hint = None  # resized/cast hint, cached while the latent size is unchanged
        self.strength = 1.0
        if device is None:
            device = model_management.get_torch_device()
        self.device = device
        self.previous_controlnet = None
        self.global_average_pooling = global_average_pooling

    def get_control(self, x_noisy, t, cond, batched_number):
        """Run the control model and return its residuals.

        The hint is resized with 'nearest-exact' to 8x the spatial size of
        ``x_noisy`` and cast to the control model's dtype and device. It is
        cached until that size changes and broadcast when the batch sizes differ.
        ``cond['c_crossattn']`` (a sequence concatenated on dim 1) is the
        context and the optional ``cond['c_adm']`` is passed as ``y``.

        Returns:
            dict with ``'middle'`` (the last control output) and ``'output'``
            (all other outputs), each scaled by ``strength`` and summed with
            the previous net's entries. A previous net's ``'input'`` entry is
            passed through unchanged.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        output_dtype = x_noisy.dtype
        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = None
            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        if self.control_model.dtype == torch.float16:
            precision_scope = torch.autocast
        else:
            precision_scope = contextlib.nullcontext

        with precision_scope(model_management.get_autocast_device(self.device)):
            self.control_model = model_management.load_if_low_vram(self.control_model)
            context = torch.cat(cond['c_crossattn'], 1)
            y = cond.get('c_adm', None)
            control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
            self.control_model = model_management.unload_if_low_vram(self.control_model)
        out = {'middle': [], 'output': []}
        autocast_enabled = torch.is_autocast_enabled()

        for i in range(len(control)):
            # The last control output feeds the middle block, the rest the output blocks.
            if i == (len(control) - 1):
                key = 'middle'
                index = 0
            else:
                key = 'output'
                index = i
            x = control[i]
            if self.global_average_pooling:
                x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

            x *= self.strength
            if x.dtype != output_dtype and not autocast_enabled:
                x = x.to(output_dtype)

            if control_prev is not None and key in control_prev:
                prev = control_prev[key][index]
                if prev is not None:
                    x += prev
            out[key].append(x)
        if control_prev is not None and 'input' in control_prev:
            out['input'] = control_prev['input']
        return out

    def set_cond_hint(self, cond_hint, strength=1.0):
        """Set the hint image and strength; returns self for chaining."""
        self.cond_hint_original = cond_hint
        self.strength = strength
        return self

    def set_previous_controlnet(self, controlnet):
        """Chain another control net before this one; returns self."""
        self.previous_controlnet = controlnet
        return self

    def cleanup(self):
        """Drop cached hints on this net and every previous net in the chain."""
        if self.previous_controlnet is not None:
            self.previous_controlnet.cleanup()
        if self.cond_hint is not None:
            del self.cond_hint
            self.cond_hint = None

    def copy(self):
        """Return a copy sharing the control model, hint, strength and device.

        The cached resized hint and the previous-net link are not copied.
        """
        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, device=self.device)
        c.cond_hint_original = self.cond_hint_original
        c.strength = self.strength
        return c

    def get_models(self):
        """Return the control models of the whole chain, earliest first."""
        out = []
        if self.previous_controlnet is not None:
            out += self.previous_controlnet.get_models()
        out.append(self.control_model)
        return out
2023-02-23 04:22:03 +00:00
def load_controlnet(ckpt_path, model=None):
    """Load a ControlNet checkpoint, falling back to a T2I adapter.

    Args:
        ckpt_path: file passed to ``utils.load_torch_file``.
        model: optional model patcher. It is only used for "diff" controlnets,
            i.e. checkpoints that contain a ``'difference'`` entry. Their
            ``control_model.*`` weights are summed with the matching
            ``diffusion_model.*`` weights of the patched model.

    Returns:
        A ``ControlNet``, a ``T2IAdapter``, or None when the file contains
        neither controlnet nor T2I adapter weights.

    Raises:
        RuntimeError: if the UNet layout of the controlnet weights cannot be
            detected.
    """
    controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)

    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        net = load_t2i_adapter(controlnet_data)
        if net is None:
            print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
        return net

    use_fp16 = model_management.should_use_fp16()
    detected = model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16)
    if detected is None:
        raise RuntimeError("ERROR: Could not detect controlnet type of: {}".format(ckpt_path))
    controlnet_config = detected.unet_config
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = 3
    control_model = cldm.ControlNet(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                m = model.patch_model()
                try:
                    model_sd = m.state_dict()
                    c_m = "control_model."
                    for x in controlnet_data:
                        if x.startswith(c_m):
                            sd_key = "diffusion_model.{}".format(x[len(c_m):])
                            if sd_key in model_sd:
                                # Diff weights are stored relative to the base model.
                                cd = controlnet_data[x]
                                cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
                finally:
                    # Always restore the caller's model, even if merging fails.
                    model.unpatch_model()
            else:
                print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        # Wrap the net so "control_model."-prefixed keys line up with the state dict.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
    print(missing, unexpected)

    if use_fp16:
        control_model = control_model.half()

    global_average_pooling = False
    if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
        global_average_pooling = True
    control = ControlNet(control_model, global_average_pooling=global_average_pooling)
    return control
2023-02-25 05:55:42 +00:00
class T2IAdapter:
    """Wraps a T2I adapter model together with its hint image.

    The adapter output depends only on the hint. It is therefore computed once
    and cached in ``control_input`` until the hint is resized or re-broadcast
    to a different batch size.
    """
    def __init__(self, t2i_model, channels_in, device=None):
        self.t2i_model = t2i_model
        self.channels_in = channels_in  # 1 -> multi-channel hints are averaged to one channel
        self.strength = 1.0
        if device is None:
            device = model_management.get_torch_device()
        self.device = device
        self.previous_controlnet = None
        self.control_input = None  # cached adapter output for the current cond_hint
        self.cond_hint_original = None
        self.cond_hint = None

    def get_control(self, x_noisy, t, cond, batched_number):
        """Return adapter features (and any previous net's residuals).

        Returns:
            dict whose ``'input'`` list holds each adapter feature (scaled by
            ``strength``) followed by two ``None`` placeholders. Placeholders
            are filled from a previous net's ``'input'`` entries when present,
            and the previous net's ``'middle'``/``'output'`` entries are passed
            through. (The 1-in-3 layout presumably mirrors the UNet input
            blocks; confirm against the UNet.)
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.control_input = None
            self.cond_hint = None
            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            previous_batch = self.cond_hint.shape[0]
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
            if self.cond_hint.shape[0] != previous_batch:
                # The cached adapter output was computed from the old-batch hint
                # and no longer matches; recompute it below.
                self.control_input = None
        if self.control_input is None:
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint)
            self.t2i_model.cpu()
        output_dtype = x_noisy.dtype
        out = {'input': []}

        autocast_enabled = torch.is_autocast_enabled()
        for i in range(len(self.control_input)):
            key = 'input'
            x = self.control_input[i] * self.strength
            if x.dtype != output_dtype and not autocast_enabled:
                x = x.to(output_dtype)

            if control_prev is not None and key in control_prev:
                index = len(control_prev[key]) - i * 3 - 3
                prev = control_prev[key][index]
                if prev is not None:
                    x += prev
            out[key].insert(0, None)
            out[key].insert(0, None)
            out[key].insert(0, x)

        if control_prev is not None and 'input' in control_prev:
            for i in range(len(out['input'])):
                if out['input'][i] is None:
                    out['input'][i] = control_prev['input'][i]
        if control_prev is not None and 'middle' in control_prev:
            out['middle'] = control_prev['middle']
        if control_prev is not None and 'output' in control_prev:
            out['output'] = control_prev['output']
        return out

    def set_cond_hint(self, cond_hint, strength=1.0):
        """Set the hint image and strength; returns self for chaining."""
        self.cond_hint_original = cond_hint
        self.strength = strength
        return self

    def set_previous_controlnet(self, controlnet):
        """Chain another control net before this one; returns self."""
        self.previous_controlnet = controlnet
        return self

    def copy(self):
        """Return a copy sharing the adapter model, hint, strength and device.

        Cached hints/outputs and the previous-net link are not copied.
        """
        c = T2IAdapter(self.t2i_model, self.channels_in, device=self.device)
        c.cond_hint_original = self.cond_hint_original
        c.strength = self.strength
        return c

    def cleanup(self):
        """Drop the cached hint on this adapter and every previous net."""
        if self.previous_controlnet is not None:
            self.previous_controlnet.cleanup()
        if self.cond_hint is not None:
            del self.cond_hint
            self.cond_hint = None

    def get_models(self):
        """Return the models of previous nets in the chain.

        The adapter model itself is not included; get_control moves it to the
        device and back manually.
        """
        out = []
        if self.previous_controlnet is not None:
            out += self.previous_controlnet.get_models()
        return out
2023-03-17 22:17:59 +00:00
def load_t2i_adapter(t2i_data):
    """Build a ``T2IAdapter`` from a state dict, or return None if unrecognised.

    A state dict nested under an ``'adapter'`` key is unwrapped first. Two
    layouts are recognised:

    - the light adapter, identified by ``'body.0.in_conv.weight'``;
    - the full adapter, identified by ``'conv_in.weight'``. Its channel width,
      kernel size and conv-vs-pool down-sampling are read from the weights.
    """
    if 'adapter' in t2i_data.keys():
        t2i_data = t2i_data['adapter']
    keys = t2i_data.keys()

    if "body.0.in_conv.weight" in keys:
        cin = t2i_data['body.0.in_conv.weight'].shape[1]
        model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
    elif 'conv_in.weight' in keys:
        conv_in = t2i_data['conv_in.weight']
        cin = conv_in.shape[1]
        channel = conv_in.shape[0]
        ksize = t2i_data['body.0.block2.weight'].shape[2]
        # Conv down-sampling is used when any "...down_opt.op.weight" tensor exists.
        use_conv = any(k.endswith("down_opt.op.weight") for k in keys)
        widths = [channel, channel * 2, channel * 4, channel * 4]
        model_ad = adapter.Adapter(cin=cin, channels=widths, nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv)
    else:
        return None

    model_ad.load_state_dict(t2i_data)
    return T2IAdapter(model_ad, cin // 64)
2023-02-16 15:38:08 +00:00
2023-03-05 23:39:25 +00:00
class StyleModel:
    """Thin wrapper around a style adapter model."""
    def __init__(self, model, device="cpu"):
        # ``device`` is accepted but not stored or used.
        self.model = model

    def get_cond(self, input):
        """Run the wrapped model on ``input.last_hidden_state``.

        ``input`` is presumably a CLIP vision output object; confirm with callers.
        """
        hidden_state = input.last_hidden_state
        return self.model(hidden_state)
def load_style_model(ckpt_path):
    """Load a style adapter checkpoint and wrap it in a ``StyleModel``.

    Raises:
        Exception: if the checkpoint has no ``'style_embedding'`` entry.
    """
    model_data = utils.load_torch_file(ckpt_path, safe_load=True)
    if "style_embedding" not in model_data.keys():
        raise Exception("invalid style model {}".format(ckpt_path))
    style_net = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
    style_net.load_state_dict(model_data)
    return StyleModel(style_net)
2023-06-25 05:40:38 +00:00
def load_clip(ckpt_paths, embedding_directory=None):
    """Load one or more text-encoder checkpoints into a single ``CLIP`` object.

    The model class is chosen from the number of files and their keys:

    - more than one file (or none): SDXL;
    - a single file with a 31st encoder layer: the SDXL refiner encoder;
    - a single file with a 23rd encoder layer: SD2;
    - otherwise: SD1.

    Missing/unexpected keys are printed per checkpoint.
    """
    clip_data = [utils.load_torch_file(path, safe_load=True) for path in ckpt_paths]

    class EmptyClass:
        pass

    # Checkpoints using "transformer.resblocks.*" names are renamed to the
    # "text_model." layout before loading.
    for idx, data in enumerate(clip_data):
        if "transformer.resblocks.0.ln_1.weight" in data:
            clip_data[idx] = utils.transformers_convert(data, "", "text_model.", 32)

    clip_target = EmptyClass()
    clip_target.params = {}
    if len(clip_data) != 1:
        clip_target.clip = sdxl_clip.SDXLClipModel
        clip_target.tokenizer = sdxl_clip.SDXLTokenizer
    elif "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
        clip_target.clip = sdxl_clip.SDXLRefinerClipModel
        clip_target.tokenizer = sdxl_clip.SDXLTokenizer
    elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
        clip_target.clip = sd2_clip.SD2ClipModel
        clip_target.tokenizer = sd2_clip.SD2Tokenizer
    else:
        clip_target.clip = sd1_clip.SD1ClipModel
        clip_target.tokenizer = sd1_clip.SD1Tokenizer

    clip = CLIP(clip_target, embedding_directory=embedding_directory)
    for data in clip_data:
        missing, unexpected = clip.load_sd(data)
        if len(missing) > 0:
            print("clip missing:", missing)
        if len(unexpected) > 0:
            print("clip unexpected:", unexpected)
    return clip
2023-01-03 06:53:32 +00:00
2023-04-19 13:36:19 +00:00
def load_gligen(ckpt_path):
    """Load a GLIGEN checkpoint.

    The model is converted with ``.half()`` when
    ``model_management.should_use_fp16()`` is true.
    """
    gligen_model = gligen.load_gligen(utils.load_torch_file(ckpt_path, safe_load=True))
    return gligen_model.half() if model_management.should_use_fp16() else gligen_model
2023-06-09 16:24:24 +00:00
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
    """Load a checkpoint described by an LDM-style yaml config.

    Args:
        config_path: yaml file, read only when ``config`` is None.
        ckpt_path: weights file, read only when ``state_dict`` is None.
        output_vae: also build and load the VAE.
        output_clip: also build and load the text encoder.
        embedding_directory: forwarded to ``CLIP``.
        state_dict: pre-loaded weights. Keys consumed by the loaders are popped.
        config: pre-parsed config dict.

    Returns:
        ``(ModelPatcher, CLIP or None, VAE or None)``.

    Raises:
        RuntimeError: if the config has no ``model.params.unet_config.params``
            section. This is checked before any weights are loaded.
    """
    #TODO: this function is a mess and should be removed eventually
    if config is None:
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']

    fp16 = False
    unet_config = None
    if "unet_config" in model_config_params:
        if "params" in model_config_params["unet_config"]:
            unet_config = model_config_params["unet_config"]["params"]
            if "use_fp16" in unet_config:
                fp16 = unet_config["use_fp16"]
    if unet_config is None:
        # Previously this surfaced later as an UnboundLocalError.
        raise RuntimeError("ERROR: config has no model.params.unet_config.params section: {}".format(config_path))
    noise_aug_config = None
    if "noise_aug_config" in model_config_params:
        noise_aug_config = model_config_params["noise_aug_config"]
    v_prediction = False
    if "parameterization" in model_config_params:
        if model_config_params["parameterization"] == "v":
            v_prediction = True

    clip = None
    vae = None

    class WeightsLoader(torch.nn.Module):
        pass

    if state_dict is None:
        state_dict = utils.load_torch_file(ckpt_path)

    class EmptyClass:
        pass

    model_config = EmptyClass()
    model_config.unet_config = unet_config
    from . import latent_formats
    model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)

    if config['model']["target"].endswith("LatentInpaintDiffusion"):
        model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
    elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
        model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
    else:
        model = model_base.BaseModel(model_config, v_prediction=v_prediction)

    if fp16:
        model = model.half()

    model.load_model_weights(state_dict, "model.diffusion_model.")

    if output_vae:
        w = WeightsLoader()
        vae = VAE(config=vae_config)
        w.first_stage_model = vae.first_stage_model
        load_model_weights(w, state_dict)

    if output_clip:
        w = WeightsLoader()
        clip_target = EmptyClass()
        clip_target.params = clip_config.get("params", {})
        if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
            clip_target.clip = sd2_clip.SD2ClipModel
            clip_target.tokenizer = sd2_clip.SD2Tokenizer
        elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
            clip_target.clip = sd1_clip.SD1ClipModel
            clip_target.tokenizer = sd1_clip.SD1Tokenizer
        clip = CLIP(clip_target, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_clip_weights(w, state_dict)

    return (ModelPatcher(model), clip, vae)
2023-03-03 08:37:35 +00:00
2023-04-02 03:19:15 +00:00
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
    """Load a checkpoint, detecting the model type from its UNet weights.

    Args:
        ckpt_path: weights file passed to ``utils.load_torch_file``.
        output_vae: also build and load the VAE.
        output_clip: also build and load the text encoder.
        output_clipvision: also load a CLIP vision model, when the detected
            config declares a ``clip_vision_prefix``.
        embedding_directory: forwarded to ``CLIP``.

    Returns:
        ``(ModelPatcher, CLIP or None, VAE or None, clip vision or None)``.
        Keys not consumed by any loader are printed.

    Raises:
        RuntimeError: if the model type cannot be detected.
    """
    sd = utils.load_torch_file(ckpt_path)
    clip = None
    clipvision = None
    vae = None

    fp16 = model_management.should_use_fp16()

    class WeightsLoader(torch.nn.Module):
        pass

    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
    if model_config is None:
        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))

    if model_config.clip_vision_prefix is not None and output_clipvision:
        clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

    model = model_config.get_model(sd)
    model.load_model_weights(sd, "model.diffusion_model.")

    if output_vae:
        vae = VAE()
        vae_loader = WeightsLoader()
        vae_loader.first_stage_model = vae.first_stage_model
        load_model_weights(vae_loader, sd)

    if output_clip:
        clip_loader = WeightsLoader()
        clip = CLIP(model_config.clip_target(), embedding_directory=embedding_directory)
        clip_loader.cond_stage_model = clip.cond_stage_model
        sd = model_config.process_clip_state_dict(sd)
        load_model_weights(clip_loader, sd)

    left_over = sd.keys()
    if len(left_over) > 0:
        print("left over keys:", left_over)

    return (ModelPatcher(model), clip, vae, clipvision)
2023-06-26 16:21:07 +00:00
def save_checkpoint(output_path, model, clip, vae, metadata=None):
    """Save model, CLIP and VAE weights into a single checkpoint file.

    ``model`` and ``clip`` are patched before their weights are collected
    (presumably so the saved weights include applied patches) and unpatched
    afterwards, on success and on failure alike.

    Args:
        output_path: destination passed to ``utils.save_torch_file``.
        model: model patcher whose ``.model.state_dict_for_saving`` merges
            the CLIP and VAE state dicts.
        clip: CLIP object providing ``get_sd``.
        vae: VAE object providing ``get_sd``.
        metadata: optional metadata forwarded to ``utils.save_torch_file``.
    """
    try:
        model.patch_model()
        clip.patch_model()
        sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
        utils.save_torch_file(sd, output_path, metadata=metadata)
    finally:
        model.unpatch_model()
        clip.unpatch_model()