From 0eea47d58086d31695f3e8e9d7ef36c6a6986faa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Aug 2024 03:54:38 -0400 Subject: [PATCH] Add ModelSamplingFlux to experiment with the shift value. Default shift on Flux Schnell is 0.0 --- comfy_extras/nodes_model_advanced.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 22ba9547..fef8a487 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -170,6 +170,33 @@ class ModelSamplingAuraFlow(ModelSamplingSD3): def patch_aura(self, model, shift): return self.patch(model, shift, multiplier=1.0) +class ModelSamplingFlux: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = comfy.model_sampling.ModelSamplingFlux + sampling_type = comfy.model_sampling.CONST + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift=shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -284,5 +311,6 @@ NODE_CLASS_MAPPINGS = { "ModelSamplingStableCascade": ModelSamplingStableCascade, "ModelSamplingSD3": ModelSamplingSD3, "ModelSamplingAuraFlow": ModelSamplingAuraFlow, + "ModelSamplingFlux": ModelSamplingFlux, "RescaleCFG": RescaleCFG, }