diff --git a/comfy_extras/nodes_slg.py b/comfy_extras/nodes_slg.py index 8a1181fc..2fa09e25 100644 --- a/comfy_extras/nodes_slg.py +++ b/comfy_extras/nodes_slg.py @@ -16,7 +16,8 @@ class SkipLayerGuidanceDiT: "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) + "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), + "rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "skip_guidance" @@ -26,7 +27,7 @@ class SkipLayerGuidanceDiT: CATEGORY = "advanced/guidance" - def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""): + def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0): # check if layer is comma separated integers def skip(args, extra_args): return args @@ -65,6 +66,11 @@ class SkipLayerGuidanceDiT: if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start: (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) cfg_result = cfg_result + (cond_pred - slg) * scale + if rescaling_scale != 0: + factor = cond_pred.std() / cfg_result.std() + factor = rescaling_scale * factor + (1 - rescaling_scale) + cfg_result *= factor + return cfg_result m = model.clone()