mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 10:53:29 +00:00
Add FreSca node (#7631)
This commit is contained in:
parent
93292bc450
commit
19373aee75
102
comfy_extras/nodes_fresca.py
Normal file
102
comfy_extras/nodes_fresca.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Code based on https://github.com/WikiChao/FreSca (MIT License)
|
||||
import torch
|
||||
import torch.fft as fft
|
||||
|
||||
|
||||
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
|
||||
"""
|
||||
Apply frequency-dependent scaling to an image tensor using Fourier transforms.
|
||||
|
||||
Parameters:
|
||||
x: Input tensor of shape (B, C, H, W)
|
||||
scale_low: Scaling factor for low-frequency components (default: 1.0)
|
||||
scale_high: Scaling factor for high-frequency components (default: 1.5)
|
||||
freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
|
||||
|
||||
Returns:
|
||||
x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
|
||||
"""
|
||||
# Preserve input dtype and device
|
||||
dtype, device = x.dtype, x.device
|
||||
|
||||
# Convert to float32 for FFT computations
|
||||
x = x.to(torch.float32)
|
||||
|
||||
# 1) Apply FFT and shift low frequencies to center
|
||||
x_freq = fft.fftn(x, dim=(-2, -1))
|
||||
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
|
||||
|
||||
# 2) Create a mask to scale frequencies differently
|
||||
B, C, H, W = x_freq.shape
|
||||
crow, ccol = H // 2, W // 2
|
||||
|
||||
# Initialize mask with high-frequency scaling factor
|
||||
mask = torch.ones((B, C, H, W), device=device) * scale_high
|
||||
|
||||
# Apply low-frequency scaling factor to center region
|
||||
mask[
|
||||
...,
|
||||
crow - freq_cutoff : crow + freq_cutoff,
|
||||
ccol - freq_cutoff : ccol + freq_cutoff,
|
||||
] = scale_low
|
||||
|
||||
# 3) Apply frequency-specific scaling
|
||||
x_freq = x_freq * mask
|
||||
|
||||
# 4) Convert back to spatial domain
|
||||
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
|
||||
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
|
||||
|
||||
# 5) Restore original dtype
|
||||
x_filtered = x_filtered.to(dtype)
|
||||
|
||||
return x_filtered
|
||||
|
||||
|
||||
class FreSca:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
|
||||
"tooltip": "Scaling factor for low-frequency components"}),
|
||||
"scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
|
||||
"tooltip": "Scaling factor for high-frequency components"}),
|
||||
"freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 100, "step": 1,
|
||||
"tooltip": "Number of frequency indices around center to consider as low-frequency"}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "_for_testing"
|
||||
DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
|
||||
def patch(self, model, scale_low, scale_high, freq_cutoff):
|
||||
def custom_cfg_function(args):
|
||||
cond = args["conds_out"][0]
|
||||
uncond = args["conds_out"][1]
|
||||
|
||||
guidance = cond - uncond
|
||||
filtered_guidance = Fourier_filter(
|
||||
guidance,
|
||||
scale_low=scale_low,
|
||||
scale_high=scale_high,
|
||||
freq_cutoff=freq_cutoff,
|
||||
)
|
||||
filtered_cond = filtered_guidance + uncond
|
||||
|
||||
return [filtered_cond, uncond]
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
|
||||
|
||||
return (m,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"FreSca": FreSca,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"FreSca": "FreSca",
|
||||
}
|
Loading…
Reference in New Issue
Block a user