diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py new file mode 100644 index 00000000..f1b2d3ff --- /dev/null +++ b/comfy_extras/nodes_model_downscale.py @@ -0,0 +1,45 @@ +import torch + +class PatchModelAddDownscale: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), + "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, block_number, downscale_factor, start_percent, end_percent): + sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item() + sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item() + + def input_block_patch(h, transformer_options): + if transformer_options["block"][1] == block_number: + sigma = transformer_options["sigmas"][0].item() + if sigma <= sigma_start and sigma >= sigma_end: + h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False) + return h + + def output_block_patch(h, hsp, transformer_options): + if h.shape[2] != hsp.shape[2]: + h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False) + return h, hsp + + m = model.clone() + m.set_model_input_block_patch(input_block_patch) + m.set_model_output_block_patch(output_block_patch) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "PatchModelAddDownscale": PatchModelAddDownscale, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + # Sampling + "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", +} diff --git a/nodes.py b/nodes.py index e8cfb5e6..f9d2d7f6 100644 --- a/nodes.py +++ b/nodes.py @@ -1799,6 +1799,7 @@ def init_custom_nodes(): "nodes_custom_sampler.py", "nodes_hypertile.py", "nodes_model_advanced.py", + "nodes_model_downscale.py", ] for node_file in extras_files: