From 19373aee759be2f0868a69603c5d967e5e63e1c5 Mon Sep 17 00:00:00 2001 From: BVH <82035780+bvhari@users.noreply.github.com> Date: Fri, 18 Apr 2025 00:54:33 +0530 Subject: [PATCH] Add FreSca node (#7631) --- comfy_extras/nodes_fresca.py | 102 +++++++++++++++++++++++++++++++++++ nodes.py | 3 +- 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_fresca.py diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py new file mode 100644 index 00000000..b0b86f23 --- /dev/null +++ b/comfy_extras/nodes_fresca.py @@ -0,0 +1,102 @@ +# Code based on https://github.com/WikiChao/FreSca (MIT License) +import torch +import torch.fft as fft + + +def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): + """ + Apply frequency-dependent scaling to an image tensor using Fourier transforms. + + Parameters: + x: Input tensor of shape (B, C, H, W) + scale_low: Scaling factor for low-frequency components (default: 1.0) + scale_high: Scaling factor for high-frequency components (default: 1.5) + freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20) + + Returns: + x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied. + """ + # Preserve input dtype and device + dtype, device = x.dtype, x.device + + # Convert to float32 for FFT computations + x = x.to(torch.float32) + + # 1) Apply FFT and shift low frequencies to center + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + # 2) Create a mask to scale frequencies differently + B, C, H, W = x_freq.shape + crow, ccol = H // 2, W // 2 + + # Initialize mask with high-frequency scaling factor + mask = torch.ones((B, C, H, W), device=device) * scale_high + + # Apply low-frequency scaling factor to center region + mask[ + ..., + crow - freq_cutoff : crow + freq_cutoff, + ccol - freq_cutoff : ccol + freq_cutoff, + ] = scale_low + + # 3) Apply frequency-specific scaling + x_freq = x_freq * mask + + # 4) Convert back to spatial domain + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + # 5) Restore original dtype + x_filtered = x_filtered.to(dtype) + + return x_filtered + + +class FreSca: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for low-frequency components"}), + "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for high-frequency components"}), + "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 100, "step": 1, + "tooltip": "Number of frequency indices around center to consider as low-frequency"}), + } + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "_for_testing" + DESCRIPTION = "Applies frequency-dependent scaling to the guidance" + def patch(self, model, scale_low, scale_high, freq_cutoff): + def custom_cfg_function(args): + cond = args["conds_out"][0] + uncond = args["conds_out"][1] + + guidance = cond - uncond + filtered_guidance = Fourier_filter( + guidance, + scale_low=scale_low, + scale_high=scale_high, + freq_cutoff=freq_cutoff, + ) + filtered_cond = filtered_guidance + uncond + + return [filtered_cond, uncond] + + m = model.clone() + m.set_model_sampler_pre_cfg_function(custom_cfg_function) + + return (m,) + + +NODE_CLASS_MAPPINGS = { + "FreSca": FreSca, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "FreSca": "FreSca", +} diff --git a/nodes.py b/nodes.py index ae0a2e18..fce3dcb3 100644 --- a/nodes.py +++ b/nodes.py @@ -2281,7 +2281,8 @@ def init_builtin_extra_nodes(): "nodes_primitive.py", "nodes_cfg.py", "nodes_optimalsteps.py", - "nodes_hidream.py" + "nodes_hidream.py", + "nodes_fresca.py", ] import_failed = []