ModelPatcher Overhaul and Hook Support (#5583)
* Added hook_patches to ModelPatcher for weights (model)
* Initial changes to calc_cond_batch to eventually support hook_patches
* Added current_patcher property to BaseModel
* Consolidated add_hook_patches_as_diffs into add_hook_patches func, fixed fp8 support for model-as-lora feature
* Added call to initialize_timesteps on hooks in process_conds func, and added call to prepare current keyframe on hooks in calc_cond_batch
* Added default_conds support in calc_cond_batch func
* Added initial set of hook-related nodes, added code to register hooks for loras/model-as-loras, small renaming/refactoring
* Made CLIP work with hook patches
* Added initial hook scheduling nodes, small renaming/refactoring
* Fixed MaxSpeed and default conds implementations
* Added support for adding weight hooks that aren't registered on the ModelPatcher at sampling time
* Made Set Clip Hooks node work with hooks from Create Hook nodes, began work on better Create Hook Model As LoRA node
* Initial work on adding 'model_as_lora' lora type to calculate_weight
* Continued work on simpler Create Hook Model As LoRA node, started to implement ModelPatcher callbacks, attachments, and additional_models
* Fix incorrect ref to create_hook_patches_clone after moving function
* Added injections support to ModelPatcher + necessary bookkeeping, added additional_models support in ModelPatcher, conds, and hooks
* Added wrappers to ModelPatcher to facilitate standardized function wrapping
* Started scaffolding for other hook types, refactored get_hooks_from_cond to organize hooks by type
* Fix skip_until_exit logic bug breaking injection after first run of model
* Updated clone_has_same_weights function to account for new ModelPatcher properties, improved AutoPatcherEjector usage in partially_load
* Added WrapperExecutor for non-classbound functions, added calc_cond_batch wrappers
* Refactored callbacks+wrappers to allow storing lists by id
* Added forward_timestep_embed_patch type, added helper functions on ModelPatcher for emb_patch and forward_timestep_embed_patch, added helper functions for removing callbacks/wrappers/additional_models by key, added custom_should_register prop to hooks
* Added get_attachment func on ModelPatcher
* Implement basic MemoryCounter system for determining which cached weights due to hooks should be offloaded in hooks_backup
* Modified ControlNet/T2IAdapter get_control function to receive transformer_options as additional parameter, made the model_options stored in extra_args in inner_sample be a clone of the original model_options instead of same ref
* Added create_model_options_clone func, modified type annotations to use __future__ so that I can use the better type annotations
* Refactored WrapperExecutor code to remove need for WrapperClassExecutor (now gone), added sampler.sample wrapper (pending review, will likely keep but will see what hacks this could currently let me get rid of in ACN/ADE)
* Added Combine versions of Cond/Cond Pair Set Props nodes, renamed Pair Cond to Cond Pair, fixed default conds never applying hooks (due to hooks key typo)
* Renamed Create Hook Model As LoRA nodes to make the test node the main one (more changes pending)
* Added uuid to conds in CFGGuider and uuids to transformer_options to allow uniquely identifying conds in batches during sampling
* Fixed models not being unloaded properly due to current_patcher reference; the current ComfyUI model cleanup code requires that nothing else has a reference to the ModelPatcher instances
* Fixed default conds not respecting hook keyframes, made keyframes not reset cache when strength is unchanged, fixed Cond Set Default Combine throwing error, fixed model-as-lora throwing error during calculate_weight after a recent ComfyUI update, small refactoring/scaffolding changes for hooks
* Changed CreateHookModelAsLoraTest to be the new CreateHookModelAsLora, rename old ones as 'direct' and will be removed prior to merge
* Added initial support within CLIP Text Encode (Prompt) node for scheduling weight hook CLIP strength via clip_start_percent/clip_end_percent on conds, added schedule_clip toggle to Set CLIP Hooks node, small cleanup/fixes
* Fix range check in get_hooks_for_clip_schedule so that proper keyframes get assigned to corresponding ranges
* Optimized CLIP hook scheduling to treat same strength as same keyframe
* Less fragile memory management.
* Make encode_from_tokens_scheduled call cleaner, rollback change in model_patcher.py for hook_patches_backup dict
* Fix issue.
* Remove useless function.
* Prevent and detect some types of memory leaks.
* Run garbage collector when switching workflow if needed.
* Moved WrappersMP/CallbacksMP/WrapperExecutor to patcher_extension.py
* Refactored code to store wrappers and callbacks in transformer_options, added apply_model and diffusion_model.forward wrappers
* Fix issue.
* Refactored hooks in calc_cond_batch to be part of get_area_and_mult tuple, added extra_hooks to ControlBase to allow custom controlnets w/ hooks, small cleanup and renaming
* Fixed inconsistency of results when schedule_clip is set to False, small renaming/typo fixing, added initial support for ControlNet extra_hooks to work in tandem with normal cond hooks, initial work on calc_cond_batch merging all subdicts in returned transformer_options
* Modified callbacks and wrappers so that unregistered types can be used, allowing custom_nodes to have their own unique callbacks/wrappers if desired
* Updated different hook types to reflect actual progress of implementation, initial scaffolding for working WrapperHook functionality
* Fixed existing weight hook_patches (pre-registered) not working properly for CLIP
* Removed Register/Direct hook nodes since they were present only for testing, removed diff-related weight hook calculation as improved_memory removes unload_model_clones and using sample time registered hooks is less hacky
* Added clip scheduling support to all other native ComfyUI text encoding nodes (sdxl, flux, hunyuan, sd3)
* Made WrapperHook functional, added another wrapper/callback getter, added ON_DETACH callback to ModelPatcher
* Made opt_hooks append by default instead of replace, renamed comfy.hooks set functions to be more accurate
* Added apply_to_conds to Set CLIP Hooks, modified relevant code to allow text encoding to automatically apply hooks to output conds when apply_to_conds is set to True
* Fix cached_hook_patches not respecting target_device/memory_counter results
* Fixed issue with setting weights from hooks instead of copying them, added additional memory_counter check when caching hook patches
* Remove unnecessary torch.no_grad calls for hook patches
* Increased MemoryCounter minimum memory to leave free by *2 until a better way to get inference memory estimate of currently loaded models exists
* For encode_from_tokens_scheduled, allow start_percent and end_percent in add_dict to limit which scheduled conds get encoded for optimization purposes
* Removed a .to call on results of calculate_weight in patch_hook_weight_to_device that was screwing up the intermediate results for fp8 prior to being passed into stochastic_rounding call
* Made encode_from_tokens_scheduled work when no hooks are set on patcher
* Small cleanup of comments
* Turn off hook patch caching when only 1 hook present in sampling, replace some current_hook = None with calls to self.patch_hooks(None) instead to avoid a potential edge case
* On Cond/Cond Pair nodes, removed opt_ prefix from optional inputs
* Allow both FLOATS and FLOAT for floats_strength input
* Revert change, does not work
* Made patch_hook_weight_to_device respect set_func and convert_func
* Make discard_model_sampling True by default
* Add changes manually from 'master' so merge conflict resolution goes more smoothly
* Cleaned up text encode nodes with just a single clip.encode_from_tokens_scheduled call
* Make sure encode_from_tokens_scheduled will respect use_clip_schedule on clip
* Made nodes in nodes_hooks be marked as experimental (beta)
* Add get_nested_additional_models for cases where additional_models could have their own additional_models, and add robustness for circular additional_models references
* Made finalize_default_conds area math consistent with other sampling code
* Changed 'opt_hooks' input of Cond/Cond Pair Set Default Combine nodes to 'hooks'
* Remove a couple old TODO's and a no longer necessary workaround
2024-12-02 19:51:02 +00:00
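The commit message above mentions two additions: wrappers on ModelPatcher for standardized function wrapping, and a WrapperExecutor for non-classbound functions. Below is a minimal, self-contained sketch of that general pattern. It assumes that each wrapper receives an `executor` callable, which runs the rest of the chain and finally the original function. `WrapperChain`, `apply_model`, `log_wrapper`, and `scale_wrapper` are hypothetical names for illustration only. They are not ComfyUI's actual classes or signatures.

```python
from typing import Any, Callable, List

# Hypothetical wrapper signature: wrapper(executor, *args, **kwargs).
# "executor" runs the remaining wrappers and then the original function.
Wrapper = Callable[..., Any]


class WrapperChain:
    """Hypothetical sketch of standardized function wrapping (not ComfyUI's API)."""

    def __init__(self, original: Callable[..., Any], wrappers: List[Wrapper]):
        self.original = original
        self.wrappers = list(wrappers)

    def _call_at(self, idx: int, *args: Any, **kwargs: Any) -> Any:
        # Past the last wrapper: call the wrapped function itself.
        if idx >= len(self.wrappers):
            return self.original(*args, **kwargs)

        # Executor that continues the chain at the next wrapper.
        def executor(*a: Any, **kw: Any) -> Any:
            return self._call_at(idx + 1, *a, **kw)

        return self.wrappers[idx](executor, *args, **kwargs)

    def execute(self, *args: Any, **kwargs: Any) -> Any:
        return self._call_at(0, *args, **kwargs)


calls: List[str] = []


def apply_model(x: float) -> float:
    calls.append("apply_model")
    return x * 2


def log_wrapper(executor: Callable[..., Any], x: float) -> float:
    calls.append("log:before")
    out = executor(x)
    calls.append("log:after")
    return out


def scale_wrapper(executor: Callable[..., Any], x: float) -> float:
    # Change the input before the rest of the chain runs, and the output after it returns.
    return executor(x + 1) + 100


chain = WrapperChain(apply_model, [log_wrapper, scale_wrapper])
print(chain.execute(3))  # 108
print(calls)  # ['log:before', 'apply_model', 'log:after']
```

In this sketch, each wrapper can modify the arguments before it calls `executor` and modify the result after `executor` returns. The order in which wrappers are registered therefore determines the final output.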
from __future__ import annotations
2023-01-03 06:53:32 +00:00
import torch
2024-02-16 18:29:04 +00:00
from enum import Enum
2024-03-10 15:37:08 +00:00
import logging
2023-01-03 06:53:32 +00:00
2023-04-15 22:55:17 +00:00
from comfy import model_management
2024-12-02 19:51:02 +00:00
from comfy.utils import ProgressBar
2023-10-17 18:51:51 +00:00
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
2024-02-16 11:30:39 +00:00
from .ldm.cascade.stage_a import StageA
2024-02-19 09:06:49 +00:00
from .ldm.cascade.stage_c_coder import StageC_coder
2024-06-15 16:14:56 +00:00
from .ldm.audio.autoencoder import AudioOobleckVAE
2024-10-26 10:54:00 +00:00
import comfy.ldm.genmo.vae.model
2024-11-22 13:44:42 +00:00
import comfy.ldm.lightricks.vae.causal_video_autoencoder
2023-03-13 18:49:18 +00:00
import yaml
2023-02-16 15:38:08 +00:00
2023-08-25 21:25:39 +00:00
import comfy.utils
2023-04-02 03:19:15 +00:00
from . import clip_vision
2023-04-19 13:36:19 +00:00
from . import gligen
2023-05-28 06:02:09 +00:00
from . import diffusers_convert
2023-06-22 17:03:50 +00:00
from . import model_detection
2023-02-03 07:06:34 +00:00
2023-06-22 17:03:50 +00:00
from . import sd1_clip
2023-06-25 05:40:38 +00:00
from . import sdxl_clip
2024-07-28 05:19:20 +00:00
import comfy.text_encoders.sd2_clip
2024-07-15 21:36:24 +00:00
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
2024-07-11 20:51:06 +00:00
import comfy.text_encoders.aura_t5
2024-07-25 22:21:08 +00:00
import comfy.text_encoders.hydit
2024-08-01 08:03:59 +00:00
import comfy.text_encoders.flux
2024-08-20 14:42:40 +00:00
import comfy.text_encoders.long_clipl
2024-10-26 10:54:00 +00:00
import comfy.text_encoders.genmo
2024-11-22 13:44:42 +00:00
import comfy.text_encoders.lt
2024-12-17 00:35:40 +00:00
import comfy.text_encoders.hunyuan_video
2023-06-09 16:24:24 +00:00
2023-08-28 18:49:18 +00:00
import comfy.model_patcher
2023-08-25 21:11:51 +00:00
import comfy.lora
2024-11-21 13:38:23 +00:00
import comfy.lora_convert
2024-12-02 19:51:02 +00:00
import comfy.hooks
2023-08-25 21:25:39 +00:00
import comfy.t2i_adapter.adapter
2023-11-21 17:54:19 +00:00
import comfy.taesd.taesd
2023-08-25 21:11:51 +00:00
2024-11-21 13:38:23 +00:00
import comfy.ldm.flux.redux
2023-06-30 03:40:02 +00:00
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
2023-11-02 00:27:20 +00:00
    key_map = {}
    if model is not None:
        key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
2024-11-21 13:38:23 +00:00
    lora = comfy.lora_convert.convert_lora(lora)
2023-08-25 21:11:51 +00:00
    loaded = comfy.lora.load_lora(lora, key_map)
2023-11-02 00:27:20 +00:00
    if model is not None:
        new_modelpatcher = model.clone()
        k = new_modelpatcher.add_patches(loaded, strength_model)
    else:
        k = ()
        new_modelpatcher = None
    if clip is not None:
        new_clip = clip.clone()
        k1 = new_clip.add_patches(loaded, strength_clip)
    else:
        k1 = ()
        new_clip = None
2023-02-03 07:06:34 +00:00
    k = set(k)
    k1 = set(k1)
    for x in loaded:
        if (x not in k) and (x not in k1):
2024-03-10 15:37:08 +00:00
            logging.warning("NOT LOADED {}".format(x))
2023-02-03 07:06:34 +00:00
    return (new_modelpatcher, new_clip)
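`load_lora_for_models` builds a combined UNet/CLIP key map and lets each patcher accept the LoRA keys it recognizes. It then warns about every key that neither patcher took. The bookkeeping at the end can be sketched without ComfyUI, as shown below. `add_patches_stub` and `report_unloaded` are hypothetical stand-ins. The real `add_patches` also takes a strength argument and stores the patches, which this sketch omits.

```python
import logging
from typing import Dict, Iterable, List, Set, Tuple


def add_patches_stub(target_keys: Set[str], patches: Dict[str, object]) -> Tuple[str, ...]:
    # Hypothetical stand-in for add_patches: return only the keys this target owns.
    return tuple(key for key in patches if key in target_keys)


def report_unloaded(loaded: Iterable[str], k: Iterable[str], k1: Iterable[str]) -> List[str]:
    # Same check as the final loop of load_lora_for_models: a key is reported
    # only if neither the model patcher (k) nor the CLIP patcher (k1) accepted it.
    model_keys, clip_keys = set(k), set(k1)
    missing = [x for x in loaded if (x not in model_keys) and (x not in clip_keys)]
    for x in missing:
        logging.warning("NOT LOADED {}".format(x))
    return missing


loaded = {"unet.block.0": 0.1, "clip.layer.0": 0.2, "unknown.key": 0.3}
k = add_patches_stub({"unet.block.0"}, loaded)
k1 = add_patches_stub({"clip.layer.0"}, loaded)
print(report_unloaded(loaded, k, k1))  # ['unknown.key']

# With model=None the original code sets k = (), so UNet keys are reported as well.
print(report_unloaded(loaded, (), k1))  # ['unet.block.0', 'unknown.key']
```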
2023-01-03 06:53:32 +00:00
class CLIP:
2024-08-17 14:15:13 +00:00
    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
2023-02-03 07:06:34 +00:00
        if no_init:
            return
2023-07-03 20:09:02 +00:00
        params = target.params.copy()
2023-06-22 17:03:50 +00:00
        clip = target.clip
        tokenizer = target.tokenizer
2023-01-29 23:46:44 +00:00
2024-09-17 07:49:54 +00:00
        load_device = model_options.get("load_device", model_management.text_encoder_device())
        offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
2024-08-17 14:15:13 +00:00
        dtype = model_options.get("dtype", None)
        if dtype is None:
            dtype = model_management.text_encoder_dtype(load_device)
2024-06-11 21:03:26 +00:00
        params['dtype'] = dtype
2024-09-17 07:49:54 +00:00
        params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
2024-08-17 14:15:13 +00:00
        params['model_options'] = model_options
2023-08-24 01:01:15 +00:00
        self.cond_stage_model = clip(**(params))
2023-06-15 19:21:37 +00:00
2024-06-11 21:03:26 +00:00
        for dt in self.cond_stage_model.dtypes:
            if not model_management.supports_cast(load_device, dt):
                load_device = offload_device
2024-08-12 04:23:29 +00:00
                if params['device'] != offload_device:
                    self.cond_stage_model.to(offload_device)
                    logging.warning("Had to shift TE back.")
2024-06-11 21:03:26 +00:00
2024-07-24 20:43:53 +00:00
        self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
2023-08-28 18:49:18 +00:00
        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
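After building the text encoder, `CLIP.__init__` checks every dtype the encoder reports. If `model_management.supports_cast` rejects one on the load device, the code falls back to the offload device. If the weights were created on a different device, it also moves them to the offload device and logs a warning.

The function below is a hypothetical, dependency-free restatement of that decision. `resolve_load_device` and the `toy_supports_cast` rule are assumptions for illustration only. Unlike the original loop, this sketch returns at the first unsupported dtype. The original keeps iterating, but once it falls back, `load_device` can only be the offload device, so the final result is the same.

```python
from typing import Callable, Iterable, Tuple


def resolve_load_device(
    load_device: str,
    offload_device: str,
    initial_device: str,
    dtypes: Iterable[str],
    supports_cast: Callable[[str, str], bool],
) -> Tuple[str, bool]:
    """Return (device to load on, whether weights must be moved to the offload device)."""
    for dt in dtypes:
        if not supports_cast(load_device, dt):
            # Fall back to the offload device; weights created elsewhere need moving.
            return offload_device, initial_device != offload_device
    return load_device, False


def toy_supports_cast(device: str, dtype: str) -> bool:
    # Toy rule for this example only: pretend "gpu" cannot cast "fp8".
    return not (device == "gpu" and dtype == "fp8")


print(resolve_load_device("gpu", "cpu", "gpu", ["fp16"], toy_supports_cast))  # ('gpu', False)
print(resolve_load_device("gpu", "cpu", "gpu", ["fp16", "fp8"], toy_supports_cast))  # ('cpu', True)
```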
ModelPatcher Overhaul and Hook Support (#5583)
* Added hook_patches to ModelPatcher for weights (model)
* Initial changes to calc_cond_batch to eventually support hook_patches
* Added current_patcher property to BaseModel
* Consolidated add_hook_patches_as_diffs into add_hook_patches func, fixed fp8 support for model-as-lora feature
* Added call to initialize_timesteps on hooks in process_conds func, and added call prepare current keyframe on hooks in calc_cond_batch
* Added default_conds support in calc_cond_batch func
* Added initial set of hook-related nodes, added code to register hooks for loras/model-as-loras, small renaming/refactoring
* Made CLIP work with hook patches
* Added initial hook scheduling nodes, small renaming/refactoring
* Fixed MaxSpeed and default conds implementations
* Added support for adding weight hooks that aren't registered on the ModelPatcher at sampling time
* Made Set Clip Hooks node work with hooks from Create Hook nodes, began work on better Create Hook Model As LoRA node
* Initial work on adding 'model_as_lora' lora type to calculate_weight
* Continued work on simpler Create Hook Model As LoRA node, started to implement ModelPatcher callbacks, attachments, and additional_models
* Fix incorrect ref to create_hook_patches_clone after moving function
* Added injections support to ModelPatcher + necessary bookkeeping, added additional_models support in ModelPatcher, conds, and hooks
* Added wrappers to ModelPatcher to facilitate standardized function wrapping
* Started scaffolding for other hook types, refactored get_hooks_from_cond to organize hooks by type
* Fix skip_until_exit logic bug breaking injection after first run of model
* Updated clone_has_same_weights function to account for new ModelPatcher properties, improved AutoPatcherEjector usage in partially_load
* Added WrapperExecutor for non-classbound functions, added calc_cond_batch wrappers
* Refactored callbacks+wrappers to allow storing lists by id
* Added forward_timestep_embed_patch type, added helper functions on ModelPatcher for emb_patch and forward_timestep_embed_patch, added helper functions for removing callbacks/wrappers/additional_models by key, added custom_should_register prop to hooks
* Added get_attachment func on ModelPatcher
* Implement basic MemoryCounter system for determining which cached weights due to hooks should be offloaded in hooks_backup
* Modified ControlNet/T2IAdapter get_control function to receive transformer_options as additional parameter, made the model_options stored in extra_args in inner_sample be a clone of the original model_options instead of same ref
* Added create_model_options_clone func, modified type annotations to use __future__ so that I can use the better type annotations
* Refactored WrapperExecutor code to remove need for WrapperClassExecutor (now gone), added sampler.sample wrapper (pending review, will likely keep but will see what hacks this could currently let me get rid of in ACN/ADE)
* Added Combine versions of Cond/Cond Pair Set Props nodes, renamed Pair Cond to Cond Pair, fixed default conds never applying hooks (due to hooks key typo)
* Renamed Create Hook Model As LoRA nodes to make the test node the main one (more changes pending)
* Added uuid to conds in CFGGuider and uuids to transformer_options to allow uniquely identifying conds in batches during sampling
* Fixed models not being unloaded properly due to current_patcher reference; the current ComfyUI model cleanup code requires that nothing else has a reference to the ModelPatcher instances
* Fixed default conds not respecting hook keyframes, made keyframes not reset cache when strength is unchanged, fixed Cond Set Default Combine throwing error, fixed model-as-lora throwing error during calculate_weight after a recent ComfyUI update, small refactoring/scaffolding changes for hooks
* Changed CreateHookModelAsLoraTest to be the new CreateHookModelAsLora, rename old ones as 'direct' and will be removed prior to merge
* Added initial support within CLIP Text Encode (Prompt) node for scheduling weight hook CLIP strength via clip_start_percent/clip_end_percent on conds, added schedule_clip toggle to Set CLIP Hooks node, small cleanup/fixes
* Fix range check in get_hooks_for_clip_schedule so that proper keyframes get assigned to corresponding ranges
* Optimized CLIP hook scheduling to treat same strength as same keyframe
* Less fragile memory management.
* Make encode_from_tokens_scheduled call cleaner, rollback change in model_patcher.py for hook_patches_backup dict
* Fix issue.
* Remove useless function.
* Prevent and detect some types of memory leaks.
* Run garbage collector when switching workflow if needed.
* Moved WrappersMP/CallbacksMP/WrapperExecutor to patcher_extension.py
* Refactored code to store wrappers and callbacks in transformer_options, added apply_model and diffusion_model.forward wrappers
* Fix issue.
* Refactored hooks in calc_cond_batch to be part of get_area_and_mult tuple, added extra_hooks to ControlBase to allow custom controlnets w/ hooks, small cleanup and renaming
* Fixed inconsistency of results when schedule_clip is set to False, small renaming/typo fixing, added initial support for ControlNet extra_hooks to work in tandem with normal cond hooks, initial work on calc_cond_batch merging all subdicts in returned transformer_options
* Modified callbacks and wrappers so that unregistered types can be used, allowing custom_nodes to have their own unique callbacks/wrappers if desired
* Updated different hook types to reflect actual progress of implementation, initial scaffolding for working WrapperHook functionality
* Fixed existing weight hook_patches (pre-registered) not working properly for CLIP
* Removed Register/Direct hook nodes since they were present only for testing, removed diff-related weight hook calculation as improved_memory removes unload_model_clones and using sample time registered hooks is less hacky
* Added clip scheduling support to all other native ComfyUI text encoding nodes (sdxl, flux, hunyuan, sd3)
* Made WrapperHook functional, added another wrapper/callback getter, added ON_DETACH callback to ModelPatcher
* Made opt_hooks append by default instead of replace, renamed comfy.hooks set functions to be more accurate
* Added apply_to_conds to Set CLIP Hooks, modified relevant code to allow text encoding to automatically apply hooks to output conds when apply_to_conds is set to True
* Fix cached_hook_patches not respecting target_device/memory_counter results
* Fixed issue with setting weights from hooks instead of copying them, added additional memory_counter check when caching hook patches
* Remove unnecessary torch.no_grad calls for hook patches
* Increased MemoryCounter minimum memory to leave free by *2 until a better way to get inference memory estimate of currently loaded models exists
* For encode_from_tokens_scheduled, allow start_percent and end_percent in add_dict to limit which scheduled conds get encoded for optimization purposes
* Removed a .to call on results of calculate_weight in patch_hook_weight_to_device that was screwing up the intermediate results for fp8 prior to being passed into stochastic_rounding call
* Made encode_from_tokens_scheduled work when no hooks are set on patcher
* Small cleanup of comments
* Turn off hook patch caching when only 1 hook present in sampling, replace some current_hook = None with calls to self.patch_hooks(None) instead to avoid a potential edge case
* On Cond/Cond Pair nodes, removed opt_ prefix from optional inputs
* Allow both FLOATS and FLOAT for floats_strength input
* Revert change, does not work
* Made patch_hook_weight_to_device respect set_func and convert_func
* Make discard_model_sampling True by default
* Add changes manually from 'master' so merge conflict resolution goes more smoothly
* Cleaned up text encode nodes with just a single clip.encode_from_tokens_scheduled call
* Make sure encode_from_tokens_scheduled will respect use_clip_schedule on clip
* Made nodes in nodes_hooks be marked as experimental (beta)
* Add get_nested_additional_models for cases where additional_models could have their own additional_models, and add robustness for circular additional_models references
* Made finalize_default_conds area math consistent with other sampling code
* Changed 'opt_hooks' input of Cond/Cond Pair Set Default Combine nodes to 'hooks'
* Remove a couple old TODO's and a no longer necessary workaround
2024-12-02 19:51:02 +00:00
        self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
        self.patcher.is_clip = True
        self.apply_hooks_to_conds = None
2024-08-12 04:06:01 +00:00
        if params['device'] == load_device:
2024-08-13 03:42:21 +00:00
            model_management.load_models_gpu([self.patcher], force_full_load=True)
2023-03-06 16:34:02 +00:00
        self.layer_idx = None
2024-12-02 19:51:02 +00:00
        self.use_clip_schedule = False
2024-08-12 03:50:01 +00:00
        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
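The device logic in `__init__` is easy to misread because `load_device` can be reassigned inside the loop, while `params['device']` (the device the weights were initialised on) is not. The following is a minimal sketch of that decision as a pure function. `choose_text_encoder_devices` and the toy `can_cast` table are hypothetical stand-ins; the real code calls `model_management.supports_cast`, moves `cond_stage_model` itself, and then calls `load_models_gpu` when the devices match.

```python
from typing import Callable, Iterable, Tuple


def choose_text_encoder_devices(
    dtypes: Iterable[str],
    load_device: str,
    offload_device: str,
    initial_device: str,
    supports_cast: Callable[[str, str], bool],
) -> Tuple[str, bool, bool]:
    """Return (load_device, move_to_offload, preload_now).

    Follows the shape of the loop in CLIP.__init__: any dtype the load device
    cannot cast to forces a fallback to the offload device, and the weights
    must be moved if they were initialised on some other device.
    """
    move_to_offload = False
    for dt in dtypes:
        if not supports_cast(load_device, dt):
            load_device = offload_device
            if initial_device != offload_device:
                move_to_offload = True
    # The encoder is only loaded eagerly when it already sits on the device
    # it will run on (the `params['device'] == load_device` check).
    preload_now = initial_device == load_device
    return load_device, move_to_offload, preload_now


# Toy capability table (assumption): this "GPU" cannot cast to fp8.
unsupported = {("cuda:0", "fp8")}
can_cast = lambda device, dtype: (device, dtype) not in unsupported
print(choose_text_encoder_devices(["fp16", "fp8"], "cuda:0", "cpu", "cuda:0", can_cast))
# ('cpu', True, False)
```

Keeping the initial device separate from the chosen load device is what lets the eager-load check skip models that had to be shifted back.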
2023-02-03 07:06:34 +00:00
    def clone(self):
        n = CLIP(no_init=True)
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
2023-03-03 18:04:36 +00:00
        n.layer_idx = self.layer_idx
2024-12-02 19:51:02 +00:00
        n.use_clip_schedule = self.use_clip_schedule
        n.apply_hooks_to_conds = self.apply_hooks_to_conds
2023-02-03 07:06:34 +00:00
        return n
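`clone` builds a new `CLIP` without running `__init__`. It gives the copy its own patcher and copies the scalar settings, but it shares the underlying `cond_stage_model` and `tokenizer` by reference. The following is a minimal sketch of that split with hypothetical stand-in classes (`FakePatcher`, `MiniCLIP`). It assumes the patcher's `clone` copies its per-key patch lists so that patches added to one clone do not leak into another.

```python
class FakePatcher:
    """Hypothetical stand-in for ModelPatcher: same model, copied patch lists."""

    def __init__(self, model):
        self.model = model
        self.patches = {}

    def clone(self):
        n = FakePatcher(self.model)
        n.patches = {k: list(v) for k, v in self.patches.items()}
        return n


class MiniCLIP:
    def __init__(self, model=None, tokenizer=None, no_init=False):
        if no_init:  # clone() fills the fields in itself
            return
        self.cond_stage_model = model
        self.tokenizer = tokenizer
        self.patcher = FakePatcher(model)
        self.layer_idx = None
        self.use_clip_schedule = False
        self.apply_hooks_to_conds = None

    def clone(self):
        n = MiniCLIP(no_init=True)
        n.patcher = self.patcher.clone()             # independent patch state
        n.cond_stage_model = self.cond_stage_model   # shared weights, not copied
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
        n.use_clip_schedule = self.use_clip_schedule
        n.apply_hooks_to_conds = self.apply_hooks_to_conds
        return n


base = MiniCLIP(model=object(), tokenizer="tok")
base.patcher.patches["w"] = [("lora_a", 1.0)]
copy_ = base.clone()
copy_.patcher.patches["w"].append(("lora_b", 0.5))
print(len(base.patcher.patches["w"]), len(copy_.patcher.patches["w"]))  # 1 2
```

Sharing the model keeps cloning cheap. This is also why the new `use_clip_schedule` and `apply_hooks_to_conds` fields have to be copied explicitly: a clone that forgot them would silently fall back to `__init__` defaults only if it ran `__init__`, which it does not.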
2023-07-14 06:37:30 +00:00
    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        return self.patcher.add_patches(patches, strength_patch, strength_model)
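`add_patches` simply forwards to the patcher. As far as can be told from ComfyUI's weight code, each stored patch later scales the current weight by `strength_model` and then adds `strength_patch` times the patch delta. The following is a simplified sketch of that arithmetic for a plain additive "diff" patch only. `add_diff_patches` and `apply_diff_patches` are hypothetical names. The real patcher stores extra fields (such as offsets and conversion functions) and supports many more patch formats (LoRA, LoHa, and others), which are not modelled here.

```python
from typing import Dict, List, Tuple

# (strength_patch, diff, strength_model); a simplified patch record (assumption)
Patch = Tuple[float, List[float], float]


def add_diff_patches(store: Dict[str, List[Patch]], weights: Dict[str, List[float]],
                     patches: Dict[str, List[float]],
                     strength_patch: float = 1.0, strength_model: float = 1.0) -> List[str]:
    """Record patches for keys that exist on the model; return the matched keys."""
    applied = []
    for key, diff in patches.items():
        if key in weights:
            store.setdefault(key, []).append((strength_patch, diff, strength_model))
            applied.append(key)
    return applied


def apply_diff_patches(weight: List[float], patch_list: List[Patch]) -> List[float]:
    """w <- strength_model * w + strength_patch * diff, applied patch by patch."""
    w = list(weight)
    for strength, diff, strength_model in patch_list:
        if strength_model != 1.0:
            w = [x * strength_model for x in w]
        if strength != 0.0:
            w = [x + strength * d for x, d in zip(w, diff)]
    return w


weights = {"a": [1.0, 2.0]}
store: Dict[str, List[Patch]] = {}
print(add_diff_patches(store, weights, {"a": [0.5, 0.5], "missing": [1.0]}, strength_patch=2.0))  # ['a']
print(apply_diff_patches(weights["a"], store["a"]))  # [2.0, 3.0]
```

Returning the matched keys gives callers a cheap way to detect patches that did not line up with the model's state dict.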
2023-01-03 06:53:32 +00:00
2023-02-05 20:20:18 +00:00
    def clip_layer(self, layer_idx):
2023-03-03 18:04:36 +00:00
        self.layer_idx = layer_idx
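`clip_layer` only records `layer_idx`. Presumably it is consumed later in the encoding path, which is outside this excerpt, to choose which transformer layer's output becomes the conditioning. The following is a minimal sketch of that selection under stated assumptions. `select_hidden_state` and `default_idx` are hypothetical, and each real text encoder defines its own default layer. Negative indices count back from the last layer, so `-2` is the penultimate layer (often called "clip skip 2").

```python
from typing import List, Optional, Sequence


def select_hidden_state(
    hidden_states: Sequence[List[float]],
    layer_idx: Optional[int],
    default_idx: int = -1,
) -> List[float]:
    """Return one layer's output from a per-layer list of hidden states.

    layer_idx=None falls back to default_idx (hypothetical default).
    Negative indices count back from the final layer.
    """
    idx = default_idx if layer_idx is None else layer_idx
    if not -len(hidden_states) <= idx < len(hidden_states):
        raise IndexError(f"layer index {idx} out of range for {len(hidden_states)} layers")
    return list(hidden_states[idx])


# Four fake layers whose "output" is just their position.
layers = [[0.0], [1.0], [2.0], [3.0]]
print(select_hidden_state(layers, None))  # [3.0]
print(select_hidden_state(layers, -2))    # [2.0]
```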
2023-02-05 20:20:18 +00:00
2023-04-14 19:16:55 +00:00
    def tokenize(self, text, return_word_ids=False):
        return self.tokenizer.tokenize_with_weights(text, return_word_ids)
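`tokenize` delegates to `tokenize_with_weights`, which turns a prompt into tokens that each carry a weight. The following is a deliberately simplified sketch of only the weighting step for the flat `(words:1.2)` syntax. `split_weighted_prompt` is hypothetical and is not ComfyUI's parser. The real tokenizer also produces token ids, handles embeddings, and supports more syntax (for example, nested parentheses), none of which is modelled here.

```python
import re
from typing import List, Tuple

# Matches "(some words:1.25)"; no nesting, no escapes (simplifying assumption).
_WEIGHTED = re.compile(r"\(([^():]+):([0-9]*\.?[0-9]+)\)")


def split_weighted_prompt(text: str) -> List[Tuple[str, float]]:
    """Split a prompt into (segment, weight) pairs; plain text gets weight 1.0."""
    parts: List[Tuple[str, float]] = []
    pos = 0
    for m in _WEIGHTED.finditer(text):
        before = text[pos:m.start()].strip()
        if before:
            parts.append((before, 1.0))
        parts.append((m.group(1).strip(), float(m.group(2))))
        pos = m.end()
    rest = text[pos:].strip()
    if rest:
        parts.append((rest, 1.0))
    return parts


print(split_weighted_prompt("a photo of (a red fox:1.3) in snow"))
# [('a photo of', 1.0), ('a red fox', 1.3), ('in snow', 1.0)]
```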
2023-04-13 20:06:50 +00:00
* Changed 'opt_hooks' input of Cond/Cond Pair Set Default Combine nodes to 'hooks'
* Remove a couple old TODO's and a no longer necessary workaround
2024-12-02 19:51:02 +00:00
    def add_hooks_to_dict(self, pooled_dict: dict[str]):
        if self.apply_hooks_to_conds:
            pooled_dict["hooks"] = self.apply_hooks_to_conds
        return pooled_dict
    def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str] = {}, show_pbar=True):
        all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
        all_hooks = self.patcher.forced_hooks
        if all_hooks is None or not self.use_clip_schedule:
            # if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
            return_pooled = "unprojected" if unprojected else True
            pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
            cond = pooled_dict.pop("cond")
            # add/update any keys with the provided add_dict
            pooled_dict.update(add_dict)
            all_cond_pooled.append([cond, pooled_dict])
        else:
            scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
            self.cond_stage_model.reset_clip_options()
            if self.layer_idx is not None:
                self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
            if unprojected:
                self.cond_stage_model.set_clip_options({"projected_pooled": False})
            self.load_model()
            all_hooks.reset()
            self.patcher.patch_hooks(None)
            if show_pbar:
                pbar = ProgressBar(len(scheduled_keyframes))
            for scheduled_opts in scheduled_keyframes:
                t_range = scheduled_opts[0]
                # don't bother encoding any conds outside of start_percent and end_percent bounds
                if "start_percent" in add_dict:
                    if t_range[1] < add_dict["start_percent"]:
                        continue
                if "end_percent" in add_dict:
                    if t_range[0] > add_dict["end_percent"]:
                        continue
                hooks_keyframes = scheduled_opts[1]
                for hook, keyframe in hooks_keyframes:
                    hook.hook_keyframe._current_keyframe = keyframe
                # apply appropriate hooks with values that match new hook_keyframe
                self.patcher.patch_hooks(all_hooks)
                # perform encoding as normal
                o = self.cond_stage_model.encode_token_weights(tokens)
                cond, pooled = o[:2]
                pooled_dict = {"pooled_output": pooled}
                # add clip_start_percent and clip_end_percent in pooled
                pooled_dict["clip_start_percent"] = t_range[0]
                pooled_dict["clip_end_percent"] = t_range[1]
                # add/update any keys with the provided add_dict
                pooled_dict.update(add_dict)
                # add hooks stored on clip
                self.add_hooks_to_dict(pooled_dict)
                all_cond_pooled.append([cond, pooled_dict])
                if show_pbar:
                    pbar.update(1)
                model_management.throw_exception_if_processing_interrupted()
            all_hooks.reset()
        return all_cond_pooled
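The `start_percent`/`end_percent` skip test in `encode_from_tokens_scheduled` can be tried in isolation. The sketch below is a minimal, torch-free reimplementation of just that test; `filter_scheduled_ranges` is a hypothetical helper, and the `(start, end)` tuples stand in for `scheduled_opts[0]` (the `t_range` of each scheduled keyframe).

```python
def filter_scheduled_ranges(ranges, add_dict):
    """Hypothetical helper: keep only the (start, end) ranges that
    encode_from_tokens_scheduled would still encode for this add_dict."""
    kept = []
    for start, end in ranges:
        # Skip a range that ends before the requested window opens.
        if "start_percent" in add_dict and end < add_dict["start_percent"]:
            continue
        # Skip a range that starts after the requested window closes.
        if "end_percent" in add_dict and start > add_dict["end_percent"]:
            continue
        kept.append((start, end))
    return kept


if __name__ == "__main__":
    schedule = [(0.0, 0.25), (0.25, 0.5), (0.5, 1.0)]
    print(filter_scheduled_ranges(schedule, {"start_percent": 0.3, "end_percent": 0.6}))
    # [(0.25, 0.5), (0.5, 1.0)]
```

Both comparisons are strict, so a range that only touches the window edge (for example `end == start_percent`) is still encoded rather than skipped.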
2024-07-11 00:06:50 +00:00
    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
2024-02-25 12:20:31 +00:00
        self.cond_stage_model.reset_clip_options()
2023-03-06 16:34:02 +00:00
        if self.layer_idx is not None:
2024-02-25 12:20:31 +00:00
            self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
        if return_pooled == "unprojected":
            self.cond_stage_model.set_clip_options({"projected_pooled": False})
2023-07-01 17:22:51 +00:00
2023-08-17 14:58:59 +00:00
        self.load_model()
2024-07-11 00:06:50 +00:00
        o = self.cond_stage_model.encode_token_weights(tokens)
        cond, pooled = o[:2]
        if return_dict:
            out = {"cond": cond, "pooled_output": pooled}
            if len(o) > 2:
                for k in o[2]:
                    out[k] = o[2][k]
2024-12-02 19:51:02 +00:00
            self.add_hooks_to_dict(out)
2024-07-11 00:06:50 +00:00
            return out
2023-04-19 13:36:19 +00:00
        if return_pooled:
2023-07-01 17:22:51 +00:00
            return cond, pooled
        return cond
2023-01-03 06:53:32 +00:00
2023-04-15 22:46:58 +00:00
    def encode(self, text):
2023-04-15 22:55:17 +00:00
        tokens = self.tokenize(text)
2023-04-15 22:46:58 +00:00
        return self.encode_from_tokens(tokens)
2024-02-19 15:29:18 +00:00
    def load_sd(self, sd, full_model=False):
        if full_model:
            return self.cond_stage_model.load_state_dict(sd, strict=False)
        else:
            return self.cond_stage_model.load_sd(sd)
2023-06-22 17:03:50 +00:00
2023-06-26 16:21:07 +00:00
    def get_sd(self):
2024-07-25 14:52:09 +00:00
        sd_clip = self.cond_stage_model.state_dict()
        sd_tokenizer = self.tokenizer.state_dict()
        for k in sd_tokenizer:
            sd_clip[k] = sd_tokenizer[k]
        return sd_clip
2023-06-26 16:21:07 +00:00
2023-08-17 14:58:59 +00:00
    def load_model(self):
        model_management.load_model_gpu(self.patcher)
        return self.patcher
2023-06-26 16:21:07 +00:00
2023-07-14 06:37:30 +00:00
    def get_key_patches(self):
        return self.patcher.get_key_patches()
2023-01-03 06:53:32 +00:00
class VAE:
2024-06-16 06:04:24 +00:00
    def __init__(self, sd=None, device=None, config=None, dtype=None):
2023-10-17 18:51:51 +00:00
        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
            sd = diffusers_convert.convert_vae_state_dict(sd)
2023-11-22 23:16:02 +00:00
        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
        self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
2024-01-02 18:24:34 +00:00
        self.downscale_ratio = 8
2024-02-19 09:06:49 +00:00
        self.upscale_ratio = 8
2024-06-16 06:04:24 +00:00
        self.latent_channels = 4
2024-11-01 21:33:09 +00:00
        self.latent_dim = 2
2024-06-15 16:14:56 +00:00
        self.output_channels = 3
2024-02-16 11:30:39 +00:00
        self.process_input = lambda image: image * 2.0 - 1.0
        self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
2024-06-16 17:12:54 +00:00
        self.working_dtypes = [torch.bfloat16, torch.float32]
2023-11-21 17:54:19 +00:00
2023-01-03 06:53:32 +00:00
        if config is None:
2023-11-24 00:41:33 +00:00
            if "decoder.mid.block_1.mix_factor" in sd:
                encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                decoder_config = encoder_config.copy()
                decoder_config["video_kernel_size"] = [3, 1, 1]
                decoder_config["alpha"] = 0.0
                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
                                                            decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
            elif "taesd_decoder.1.weight" in sd:
2024-06-16 07:10:04 +00:00
                self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
                self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
2024-02-16 11:30:39 +00:00
            elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
                self.first_stage_model = StageA()
                self.downscale_ratio = 4
2024-02-19 09:06:49 +00:00
                self.upscale_ratio = 4
2024-02-16 11:30:39 +00:00
                #TODO
                #self.memory_used_encode
                #self.memory_used_decode
                self.process_input = lambda image: image
                self.process_output = lambda image: image
2024-02-19 09:06:49 +00:00
            elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
                self.first_stage_model = StageC_coder()
                self.downscale_ratio = 32
                self.latent_channels = 16
                new_sd = {}
                for k in sd:
                    new_sd["encoder.{}".format(k)] = sd[k]
                sd = new_sd
            elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
                self.first_stage_model = StageC_coder()
                self.latent_channels = 16
                new_sd = {}
                for k in sd:
                    new_sd["previewer.{}".format(k)] = sd[k]
                sd = new_sd
            elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
                self.first_stage_model = StageC_coder()
                self.downscale_ratio = 32
                self.latent_channels = 16
2024-04-24 13:20:31 +00:00
            elif "decoder.conv_in.weight" in sd:
2023-11-21 17:54:19 +00:00
                #default SD1.x/SD2.x VAE parameters
                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
2024-01-03 08:30:39 +00:00
2024-04-19 01:05:33 +00:00
                if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
2024-01-03 08:30:39 +00:00
                    ddconfig['ch_mult'] = [1, 2, 4]
                    self.downscale_ratio = 4
2024-02-19 09:06:49 +00:00
                    self.upscale_ratio = 4
2024-01-03 08:30:39 +00:00
2024-04-19 01:05:33 +00:00
                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
2024-12-17 00:35:40 +00:00
                if 'post_quant_conv.weight' in sd:
                    self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
2024-04-19 01:05:33 +00:00
                else:
                    self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
                                                                encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
                                                                decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
2024-12-17 00:35:40 +00:00
            elif "decoder.conv_in.conv.weight" in sd:
                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                ddconfig["conv3d"] = True
                ddconfig["time_compress"] = 4
                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
                self.latent_dim = 3
                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
                self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
                self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
2024-12-17 04:31:10 +00:00
                self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
2024-12-17 00:35:40 +00:00
2024-06-27 15:06:52 +00:00
            elif "decoder.layers.1.layers.0.beta" in sd:
2024-06-15 16:14:56 +00:00
                self.first_stage_model = AudioOobleckVAE()
2024-06-16 15:47:32 +00:00
                self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
                self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
2024-06-15 16:14:56 +00:00
                self.latent_channels = 64
                self.output_channels = 2
                self.upscale_ratio = 2048
                self.downscale_ratio = 2048
2024-11-01 21:33:09 +00:00
                self.latent_dim = 1
2024-06-15 16:14:56 +00:00
                self.process_output = lambda audio: audio
                self.process_input = lambda audio: audio
2024-06-16 17:12:54 +00:00
                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
2024-11-05 08:42:58 +00:00
            elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
2024-10-26 10:54:00 +00:00
                if "blocks.2.blocks.3.stack.5.weight" in sd:
                    sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
2024-11-01 21:33:09 +00:00
                if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
                    sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
2024-10-26 10:54:00 +00:00
                self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
                self.latent_channels = 12
2024-11-01 21:33:09 +00:00
                self.latent_dim = 3
2024-10-26 10:54:00 +00:00
                self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
2024-11-01 21:33:09 +00:00
                self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
2024-10-26 10:54:00 +00:00
                self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
2024-11-01 21:33:09 +00:00
                self.working_dtypes = [torch.float16, torch.float32]
2024-11-22 13:44:42 +00:00
            elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
                self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
                self.latent_channels = 128
                self.latent_dim = 3
                self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
2024-11-22 23:00:34 +00:00
                self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
2024-11-22 13:44:42 +00:00
                self.working_dtypes = [torch.bfloat16, torch.float32]
2024-04-24 13:20:31 +00:00
            else:
                logging.warning("WARNING: No VAE weights detected, VAE not initialized.")
                self.first_stage_model = None
                return
2023-01-03 06:53:32 +00:00
        else:
2023-05-28 06:02:09 +00:00
            self.first_stage_model = AutoencoderKL(**(config['params']))
2023-01-03 06:53:32 +00:00
        self.first_stage_model = self.first_stage_model.eval()
2023-10-17 18:51:51 +00:00
        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
        if len(m) > 0:
2024-03-10 15:37:08 +00:00
            logging.warning("Missing VAE keys {}".format(m))
2023-10-17 18:51:51 +00:00
        if len(u) > 0:
2024-03-11 17:54:56 +00:00
            logging.debug("Leftover VAE keys {}".format(u))
2023-05-28 06:02:09 +00:00
2023-03-06 15:50:50 +00:00
        if device is None:
2023-07-01 19:22:40 +00:00
            device = model_management.vae_device()
2023-01-03 06:53:32 +00:00
        self.device = device
2023-11-28 09:58:32 +00:00
        offload_device = model_management.vae_offload_device()
2023-12-12 17:03:29 +00:00
        if dtype is None:
2024-06-16 17:12:54 +00:00
            dtype = model_management.vae_dtype(self.device, self.working_dtypes)
2023-12-12 17:03:29 +00:00
        self.vae_dtype = dtype
2023-07-06 22:04:28 +00:00
        self.first_stage_model.to(self.vae_dtype)
2023-12-08 07:35:45 +00:00
        self.output_device = model_management.intermediate_device()
2023-01-03 06:53:32 +00:00
2023-11-28 09:58:32 +00:00
        self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
2024-06-16 17:12:54 +00:00
        logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
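The video VAE branches above store `upscale_ratio` as a `(frames, height, width)` tuple whose first entry is a callable. Since `k * a - (k - 1)` equals `1 + (a - 1) * k`, one way to read it is that the first latent frame decodes to a single frame and every later latent frame to `k` frames, with `max(0, ...)` keeping an empty input at zero. The sketch below only evaluates those tuples; `apply_upscale_ratio` is a hypothetical helper, not the code ComfyUI uses to consume them.

```python
def apply_upscale_ratio(ratio, latent_dims):
    """Hypothetical helper: map latent sizes to output sizes, where each entry
    of `ratio` is either a plain multiplier or a callable."""
    out = []
    for r, size in zip(ratio, latent_dims):
        out.append(r(size) if callable(r) else size * r)
    return tuple(out)


# (frames, height, width) ratio tuples copied from the VAE branches above.
conv3d_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
mochi_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
ltxv_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)

if __name__ == "__main__":
    print(apply_upscale_ratio(conv3d_ratio, (13, 60, 104)))  # (49, 480, 832)
    print(apply_upscale_ratio(ltxv_ratio, (8, 16, 16)))      # (57, 512, 512)
```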
2023-11-28 09:58:32 +00:00
2024-02-19 09:06:49 +00:00
    def vae_encode_crop_pixels(self, pixels):
2024-06-15 16:14:56 +00:00
        dims = pixels.shape[1:-1]
        for d in range(len(dims)):
            x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
            x_offset = (dims[d] % self.downscale_ratio) // 2
            if x != dims[d]:
                pixels = pixels.narrow(d + 1, x_offset, x)
2024-02-19 09:06:49 +00:00
        return pixels
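`vae_encode_crop_pixels` trims each spatial dimension (everything between the batch and channel axes) down to the largest multiple of `downscale_ratio`, splitting the leftover pixels so the crop stays roughly centered. The sketch below reproduces only the offset/length arithmetic in plain Python so it runs without torch; `crop_plan` is a hypothetical helper, and in the real method the pair is passed to `pixels.narrow(d + 1, offset, length)`.

```python
def crop_plan(dims, downscale_ratio):
    """Hypothetical helper: (offset, length) per spatial dimension, using the
    same arithmetic as vae_encode_crop_pixels."""
    plan = []
    for size in dims:
        length = (size // downscale_ratio) * downscale_ratio  # largest multiple that fits
        offset = (size % downscale_ratio) // 2                # half the remainder, rounded down
        plan.append((offset, length))
    return plan


if __name__ == "__main__":
    # A 513x770 image (H, W) with the default downscale_ratio of 8.
    print(crop_plan((513, 770), 8))  # [(0, 512), (1, 768)]
```

Because the offset rounds down, an odd remainder drops the extra pixel from the far edge; a dimension that is already a multiple yields `(0, size)`, which matches the method skipping `narrow` in that case.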
2023-03-22 18:49:00 +00:00
    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
2023-08-25 21:25:39 +00:00
        steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)
2023-04-24 10:55:44 +00:00
2024-02-16 11:30:39 +00:00
        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
        output = self.process_output(
2024-02-19 09:06:49 +00:00
            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
2024-02-16 11:30:39 +00:00
            / 3.0)
2023-03-22 18:49:00 +00:00
        return output
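`decode_tiled_` decodes the same latents three times with different tile shapes (tall, wide, and square) and averages the results, presumably so that tile seams fall in different places in each pass and are softened by the mean. The sketch below only lists those three shapes in the order the method uses them and shows the equal-weight average on flat lists; `tile_shapes` and `average_passes` are hypothetical helpers, not ComfyUI functions.

```python
def tile_shapes(tile_x, tile_y):
    """Hypothetical helper: the three (tile_x, tile_y) shapes decode_tiled_ passes
    to tiled_scale, in the same order."""
    return [(tile_x // 2, tile_y * 2), (tile_x * 2, tile_y // 2), (tile_x, tile_y)]


def average_passes(passes):
    """Hypothetical helper: element-wise mean of equally sized results (flat lists
    here), the analogue of summing three tensors and dividing by 3.0."""
    return [sum(values) / len(passes) for values in zip(*passes)]


if __name__ == "__main__":
    print(tile_shapes(64, 64))  # [(32, 128), (128, 32), (64, 64)]
    print(average_passes([[0.0, 3.0], [3.0, 3.0], [0.0, 3.0]]))  # [1.0, 3.0]
```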
2024-06-22 15:45:58 +00:00
    def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
        return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
2024-06-18 02:48:23 +00:00
2024-10-26 10:54:00 +00:00
    def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
2023-06-12 03:25:39 +00:00
    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
2023-08-25 21:25:39 +00:00
        steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
        pbar = comfy.utils.ProgressBar(steps)
2023-06-12 03:25:39 +00:00
2024-02-16 11:30:39 +00:00
        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
2024-01-02 18:24:34 +00:00
        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
2023-06-12 03:25:39 +00:00
        samples /= 3.0
        return samples
2024-06-22 15:45:58 +00:00
    def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
2023-03-22 18:49:00 +00:00
    def decode(self, samples_in):
2024-10-26 10:54:00 +00:00
        pixel_samples = None
2023-03-22 18:49:00 +00:00
        try:
2023-11-22 23:16:02 +00:00
            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
2023-11-28 09:58:32 +00:00
            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
2023-03-29 06:24:37 +00:00
            free_memory = model_management.get_free_memory(self.device)
2023-08-17 05:06:34 +00:00
            batch_number = int(free_memory / memory_used)
2023-03-29 06:24:37 +00:00
            batch_number = max(1, batch_number)
            for x in range(0, samples_in.shape[0], batch_number):
2023-07-06 22:04:28 +00:00
                samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
2024-10-26 10:54:00 +00:00
                out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
                if pixel_samples is None:
                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                pixel_samples[x:x + batch_number] = out
2024-12-12 22:59:16 +00:00
        except model_management.OOM_EXCEPTION:
2024-03-10 15:37:08 +00:00
            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
2024-10-26 10:54:00 +00:00
            dims = samples_in.ndim - 2
            if dims == 1:
2024-06-18 02:48:23 +00:00
                pixel_samples = self.decode_tiled_1d(samples_in)
2024-10-26 10:54:00 +00:00
            elif dims == 2:
2024-06-18 02:48:23 +00:00
                pixel_samples = self.decode_tiled_(samples_in)
2024-10-26 10:54:00 +00:00
            elif dims == 3:
2024-11-22 23:00:34 +00:00
                tile = 256 // self.spacial_compression_decode()
                overlap = tile // 4
                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
2023-03-22 18:49:00 +00:00
2023-12-08 07:35:45 +00:00
        pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
2023-01-03 06:53:32 +00:00
        return pixel_samples

    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
        memory_used = self.memory_used_decode(samples.shape, self.vae_dtype)  #TODO: calculate mem required for tile
        model_management.load_models_gpu([self.patcher], memory_required=memory_used)
        dims = samples.ndim - 2
        # Forward only the tiling parameters the caller set; the per-dimension helpers have their own defaults.
        args = {}
        if tile_x is not None:
            args["tile_x"] = tile_x
        if tile_y is not None:
            args["tile_y"] = tile_y
        if overlap is not None:
            args["overlap"] = overlap

        if dims == 1:
            # 1D data has no y axis; the default avoids a KeyError when tile_y was never set.
            args.pop("tile_y", None)
            output = self.decode_tiled_1d(samples, **args)
        elif dims == 2:
            output = self.decode_tiled_(samples, **args)
        elif dims == 3:
            output = self.decode_tiled_3d(samples, **args)
        return output.movedim(1, -1)

    def encode(self, pixel_samples):
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        pixel_samples = pixel_samples.movedim(-1, 1)
        if self.latent_dim == 3:
            # Video: frames arrive as the batch dim [T, C, H, W]; reshape to [1, C, T, H, W].
            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
        try:
            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / max(1, memory_used))
            batch_number = max(1, batch_number)
            samples = None
            for x in range(0, pixel_samples.shape[0], batch_number):
                pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
                out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
                if samples is None:
                    samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                samples[x:x+batch_number] = out

        except model_management.OOM_EXCEPTION:
            logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
            if len(pixel_samples.shape) == 3:
                samples = self.encode_tiled_1d(pixel_samples)
            else:
                samples = self.encode_tiled_(pixel_samples)

        return samples

    def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
        model_management.load_model_gpu(self.patcher)
        pixel_samples = pixel_samples.movedim(-1, 1)
        samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
        return samples

    def get_sd(self):
        return self.first_stage_model.state_dict()

    def spacial_compression_decode(self):
        # upscale_ratio is either a per-axis sequence (the last entry is the spatial factor) or a single number.
        try:
            return self.upscale_ratio[-1]
        except:
            return self.upscale_ratio
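
# Sketch (hypothetical helper, not part of ComfyUI): the batch splitting used by
# VAE.decode and VAE.encode above, with plain integers standing in for the memory
# estimates. It assumes, as the division there implies, that the memory_used_*
# estimate is per sample. At least one sample is processed per batch even when
# the estimate exceeds the free memory.
def _sketch_batch_ranges(total_items, free_memory, memory_per_item):
    # max(1, ...) on the divisor guards against a zero estimate; the outer max keeps batches non-empty.
    batch_number = max(1, int(free_memory / max(1, memory_per_item)))
    # Same slicing pattern as samples_in[x:x+batch_number]; the last batch may be shorter.
    return [(x, min(x + batch_number, total_items)) for x in range(0, total_items, batch_number)]

# Example: 10 samples with room for 3 at a time gives (0, 3), (3, 6), (6, 9), (9, 10).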

class StyleModel:
    def __init__(self, model, device="cpu"):
        self.model = model

    def get_cond(self, input):
        return self.model(input.last_hidden_state)

def load_style_model(ckpt_path):
    model_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    keys = model_data.keys()
    if "style_embedding" in keys:
        model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
    elif "redux_down.weight" in keys:
        model = comfy.ldm.flux.redux.ReduxImageEncoder()
    else:
        raise Exception("invalid style model {}".format(ckpt_path))
    model.load_state_dict(model_data)
    return StyleModel(model)

class CLIPType(Enum):
    STABLE_DIFFUSION = 1
    STABLE_CASCADE = 2
    SD3 = 3
    STABLE_AUDIO = 4
    HUNYUAN_DIT = 5
    FLUX = 6
    MOCHI = 7
    LTXV = 8
    HUNYUAN_VIDEO = 9

def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
    clip_data = []
    for p in ckpt_paths:
        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
    return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)

class TEModel(Enum):
    CLIP_L = 1
    CLIP_H = 2
    CLIP_G = 3
    T5_XXL = 4
    T5_XL = 5
    T5_BASE = 6
    LLAMA3_8 = 7

def detect_te_model(sd):
    # Probe the deepest CLIP layer first: a CLIP-G state dict also contains layer 22 and layer 0 keys.
    if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
        return TEModel.CLIP_G
    if "text_model.encoder.layers.22.mlp.fc1.weight" in sd:
        return TEModel.CLIP_H
    if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
        return TEModel.CLIP_L
    if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
        # Both T5-XXL and T5-XL have a block 23; the hidden width (last weight dim) tells them apart.
        weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
        if weight.shape[-1] == 4096:
            return TEModel.T5_XXL
        elif weight.shape[-1] == 2048:
            return TEModel.T5_XL
    if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
        return TEModel.T5_BASE
    if "model.layers.0.post_attention_layernorm.weight" in sd:
        return TEModel.LLAMA3_8
    return None
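
# Sketch (hypothetical helper, not part of ComfyUI): why detect_te_model probes the
# deepest CLIP layer first. A CLIP-G state dict also contains the layer 22 and
# layer 0 keys, so checking in the order G, H, L returns the most specific match.
# The probe keys are copied from detect_te_model; the string labels are illustrative.
_SKETCH_CLIP_PROBES = (
    ("text_model.encoder.layers.30.mlp.fc1.weight", "clip_g"),
    ("text_model.encoder.layers.22.mlp.fc1.weight", "clip_h"),
    ("text_model.encoder.layers.0.mlp.fc1.weight", "clip_l"),
)

def _sketch_detect_clip(keys):
    # The first matching probe wins, so order encodes specificity.
    for probe_key, label in _SKETCH_CLIP_PROBES:
        if probe_key in keys:
            return label
    return None

# A 32-layer text model has all three probe keys, but the layer 30 probe is checked first.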

def t5xxl_detect(clip_data):
    weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"

    for sd in clip_data:
        if weight_name in sd:
            return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
    return {}

def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
    clip_data = state_dicts

    class EmptyClass:
        pass

    for i in range(len(clip_data)):
        if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
            clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "")
        else:
            if "text_projection" in clip_data[i]:
                clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node

    clip_target = EmptyClass()
    clip_target.params = {}
    if len(clip_data) == 1:
        te_model = detect_te_model(clip_data[0])
        if te_model == TEModel.CLIP_G:
            if clip_type == CLIPType.STABLE_CASCADE:
                clip_target.clip = sdxl_clip.StableCascadeClipModel
                clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
            elif clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
            else:
                clip_target.clip = sdxl_clip.SDXLRefinerClipModel
                clip_target.tokenizer = sdxl_clip.SDXLTokenizer
        elif te_model == TEModel.CLIP_H:
            clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
            clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
        elif te_model == TEModel.T5_XXL:
            if clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
            elif clip_type == CLIPType.LTXV:
                clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
            else: #CLIPType.MOCHI
                clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
        elif te_model == TEModel.T5_XL:
            clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
            clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
        elif te_model == TEModel.T5_BASE:
            clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
            clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
        else:
            if clip_type == CLIPType.SD3:
                clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
                clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
            else:
                clip_target.clip = sd1_clip.SD1ClipModel
                clip_target.tokenizer = sd1_clip.SD1Tokenizer
    elif len(clip_data) == 2:
        if clip_type == CLIPType.SD3:
            te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
        elif clip_type == CLIPType.HUNYUAN_DIT:
            clip_target.clip = comfy.text_encoders.hydit.HyditModel
            clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
        elif clip_type == CLIPType.FLUX:
            clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
        elif clip_type == CLIPType.HUNYUAN_VIDEO:
            clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip() #TODO
            clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
        else:
            clip_target.clip = sdxl_clip.SDXLClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
    elif len(clip_data) == 3:
        clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
        clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

    parameters = 0
    tokenizer_data = {}
    for c in clip_data:
        parameters += comfy.utils.calculate_parameters(c)
        tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
    for c in clip_data:
        m, u = clip.load_sd(c)
        if len(m) > 0:
            logging.warning("clip missing: {}".format(m))

        if len(u) > 0:
            logging.debug("clip unexpected: {}".format(u))
    return clip
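
# Sketch (hypothetical helper, not part of ComfyUI): how the two-file SD3 branch of
# load_text_encoder_state_dicts turns the detected encoder kinds into the
# clip_l / clip_g / t5 flags it passes to sd3_clip. String labels stand in for the
# TEModel members here; a file that was not recognized (None) enables nothing.
def _sketch_sd3_flags(detected):
    # Each flag is simply "was this encoder kind among the loaded files".
    return {
        "clip_l": "clip_l" in detected,
        "clip_g": "clip_g" in detected,
        "t5": "t5_xxl" in detected,
    }

# Example: a CLIP-L file plus a T5-XXL file enables clip_l and t5 but not clip_g.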

def load_gligen(ckpt_path):
    data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    model = gligen.load_gligen(data)
    if model_management.should_use_fp16():
        model = model.half()
    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())

def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
    logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use load_checkpoint_guess_config instead.")
    model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
    #TODO: this function is a mess and should be removed eventually
    if config is None:
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']

    if "parameterization" in model_config_params:
        if model_config_params["parameterization"] == "v":
            m = model.clone()
            class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.model_sampling.V_PREDICTION):
                pass
            m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
            model = m

    layer_idx = clip_config.get("params", {}).get("layer_idx", None)
    if layer_idx is not None:
        clip.clip_layer(layer_idx)

    return (model, clip, vae)

def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
    sd = comfy.utils.load_torch_file(ckpt_path)
    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
    if out is None:
        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
    return out

def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
    clip = None
    clipvision = None
    vae = None
    model = None
    model_patcher = None

    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
    parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
    weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
    load_device = model_management.get_torch_device()

    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
    if model_config is None:
        return None

    # Candidate dtypes: the model's supported inference dtypes plus the dtype stored
    # in the checkpoint (skipped for scaled-fp8 checkpoints).
    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    if weight_dtype is not None and model_config.scaled_fp8 is None:
        unet_weight_dtype.append(weight_dtype)

    model_config.custom_operations = model_options.get("custom_operations", None)
    # An explicit "dtype" (or "weight_dtype") in model_options overrides automatic selection.
    unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))

    if unet_dtype is None:
        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)

    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

    if model_config.clip_vision_prefix is not None:
        if output_clipvision:
            clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

    if output_model:
        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
        model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
        model.load_model_weights(sd, diffusion_model_prefix)

    if output_vae:
        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
        vae_sd = model_config.process_vae_state_dict(vae_sd)
        vae = VAE(sd=vae_sd)

    if output_clip:
        clip_target = model_config.clip_target(state_dict=sd)
        if clip_target is not None:
            clip_sd = model_config.process_clip_state_dict(sd)
            if len(clip_sd) > 0:
                parameters = comfy.utils.calculate_parameters(clip_sd)
                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
                m, u = clip.load_sd(clip_sd, full_model=True)
                if len(m) > 0:
                    m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
                    if len(m_filter) > 0:
                        logging.warning("clip missing: {}".format(m))
                    else:
                        logging.debug("clip missing: {}".format(m))

                if len(u) > 0:
                    logging.debug("clip unexpected: {}".format(u))
            else:
                logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

    left_over = sd.keys()
    if len(left_over) > 0:
        logging.debug("left over keys: {}".format(left_over))

    if output_model:
        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
        if inital_load_device != torch.device("cpu"):
            logging.info("loaded straight to GPU")
            model_management.load_models_gpu([model_patcher], force_full_load=True)

    return (model_patcher, clip, vae, clipvision)
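
# Sketch (hypothetical helper, not part of ComfyUI): the order in which
# load_state_dict_guess_config settles the diffusion model dtype. An explicit
# "dtype" (or, failing that, "weight_dtype") in model_options wins; otherwise the
# candidates are the supported inference dtypes plus the checkpoint's stored
# dtype, which is skipped for scaled-fp8 checkpoints. The real choice among the
# candidates is made by model_management.unet_dtype; pick_fn is a stand-in for it.
def _sketch_resolve_unet_dtype(model_options, supported_dtypes, weight_dtype, scaled_fp8, pick_fn):
    override = model_options.get("dtype", model_options.get("weight_dtype", None))
    if override is not None:
        return override
    candidates = list(supported_dtypes)
    if weight_dtype is not None and scaled_fp8 is None:
        candidates.append(weight_dtype)
    return pick_fn(candidates)

# Example: with pick_fn = lambda c: c[-1], a plain fp8 checkpoint resolves to its
# stored dtype, while a scaled-fp8 one falls back to the last supported dtype.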

def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
    dtype = model_options.get("dtype", None)

    #Allow loading unets from checkpoint files
    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
    temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
    if len(temp_sd) > 0:
        sd = temp_sd

    parameters = comfy.utils.calculate_parameters(sd)
    weight_dtype = comfy.utils.weight_dtype(sd)

    load_device = model_management.get_torch_device()
    model_config = model_detection.model_config_from_unet(sd, "")

    if model_config is not None:
        new_sd = sd
    else:
        new_sd = model_detection.convert_diffusers_mmdit(sd, "")
        if new_sd is not None: #diffusers mmdit
            model_config = model_detection.model_config_from_unet(new_sd, "")
            if model_config is None:
                return None
        else: #diffusers unet
            model_config = model_detection.model_config_from_diffusers_unet(sd)
            if model_config is None:
                return None

            diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)

            new_sd = {}
            for k in diffusers_keys:
                if k in sd:
                    new_sd[diffusers_keys[k]] = sd.pop(k)
                else:
                    logging.warning("{} {}".format(diffusers_keys[k], k))

    offload_device = model_management.unet_offload_device()
    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    if weight_dtype is not None and model_config.scaled_fp8 is None:
        unet_weight_dtype.append(weight_dtype)

    if dtype is None:
        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
    else:
        unet_dtype = dtype

    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
    model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
    if model_options.get("fp8_optimizations", False):
        model_config.optimizations["fp8"] = True

    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model.load_model_weights(new_sd, "")
    left_over = sd.keys()
    if len(left_over) > 0:
        logging.info("left over keys in unet: {}".format(left_over))
    return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
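
# Sketch (hypothetical helper, not part of ComfyUI): the prefix handling at the top
# of load_diffusion_model_state_dict. It assumes state_dict_prefix_replace with
# filter_keys=True keeps only the keys that start with the prefix, with the prefix
# removed; when nothing matches (a bare diffusion model file) the original dict is
# used unchanged. Unlike the real helper, this sketch does not modify its input.
def _sketch_strip_prefix(sd, prefix):
    stripped = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
    # Fall back to the untouched dict, mirroring the "if len(temp_sd) > 0" check.
    return stripped if len(stripped) > 0 else sd

# Example: a full checkpoint keeps only its diffusion model weights, renamed without the prefix.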

def load_diffusion_model(unet_path, model_options={}):
    sd = comfy.utils.load_torch_file(unet_path)
    model = load_diffusion_model_state_dict(sd, model_options=model_options)
    if model is None:
        logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
        raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
    return model

def load_unet(unet_path, dtype=None):
    logging.warning("WARNING: the load_unet function has been deprecated and will be removed, please switch to: load_diffusion_model")
    return load_diffusion_model(unet_path, model_options={"dtype": dtype})

def load_unet_state_dict(sd, dtype=None):
    logging.warning("WARNING: the load_unet_state_dict function has been deprecated and will be removed, please switch to: load_diffusion_model_state_dict")
    return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})

def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
    clip_sd = None
    load_models = [model]
    if clip is not None:
        load_models.append(clip.load_model())
        clip_sd = clip.get_sd()
    vae_sd = None
    if vae is not None:
        vae_sd = vae.get_sd()

    model_management.load_models_gpu(load_models, force_patch_weights=True)
    clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
    sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
    for k in extra_keys:
        sd[k] = extra_keys[k]

    # safetensors refuses to serialize non-contiguous tensors, so make every tensor contiguous first.
    for k in sd:
        t = sd[k]
        if not t.is_contiguous():
            sd[k] = t.contiguous()

    comfy.utils.save_torch_file(sd, output_path, metadata=metadata)