mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Fix SAG not working with cfg 1.0
This commit is contained in:
parent
8cf1daa108
commit
571ea8cdcc
@ -55,14 +55,18 @@ class ModelPatcher:
|
|||||||
def memory_required(self, input_shape):
|
def memory_required(self, input_shape):
|
||||||
return self.model.memory_required(input_shape=input_shape)
|
return self.model.memory_required(input_shape=input_shape)
|
||||||
|
|
||||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||||
else:
|
else:
|
||||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||||
|
if disable_cfg1_optimization:
|
||||||
|
self.model_options["disable_cfg1_optimization"] = True
|
||||||
|
|
||||||
def set_model_sampler_post_cfg_function(self, post_cfg_function):
|
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||||
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
|
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
|
||||||
|
if disable_cfg1_optimization:
|
||||||
|
self.model_options["disable_cfg1_optimization"] = True
|
||||||
|
|
||||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||||
|
@ -244,7 +244,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
|||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns denoised
|
#Returns denoised
|
||||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||||
if math.isclose(cond_scale, 1.0):
|
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||||
uncond_ = None
|
uncond_ = None
|
||||||
else:
|
else:
|
||||||
uncond_ = uncond
|
uncond_ = uncond
|
||||||
|
@ -151,7 +151,7 @@ class SelfAttentionGuidance:
|
|||||||
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
|
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
|
||||||
return cfg_result + (degraded - sag) * sag_scale
|
return cfg_result + (degraded - sag) * sag_scale
|
||||||
|
|
||||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
# from diffusers:
|
# from diffusers:
|
||||||
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
|
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
|
||||||
|
Loading…
Reference in New Issue
Block a user