mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-16 08:33:29 +00:00
Allow controlling downscale and upscale methods in PatchModelAddDownscale.
This commit is contained in:
parent
72741105a6
commit
c3ae99a749
@ -318,7 +318,9 @@ def bislerp(samples, width, height):
|
|||||||
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||||
coords_2 = coords_2.to(torch.int64)
|
coords_2 = coords_2.to(torch.int64)
|
||||||
return ratios, coords_1, coords_2
|
return ratios, coords_1, coords_2
|
||||||
|
|
||||||
|
orig_dtype = samples.dtype
|
||||||
|
samples = samples.float()
|
||||||
n,c,h,w = samples.shape
|
n,c,h,w = samples.shape
|
||||||
h_new, w_new = (height, width)
|
h_new, w_new = (height, width)
|
||||||
|
|
||||||
@ -347,7 +349,7 @@ def bislerp(samples, width, height):
|
|||||||
|
|
||||||
result = slerp(pass_1, pass_2, ratios)
|
result = slerp(pass_1, pass_2, ratios)
|
||||||
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||||
return result
|
return result.to(orig_dtype)
|
||||||
|
|
||||||
def lanczos(samples, width, height):
|
def lanczos(samples, width, height):
|
||||||
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
class PatchModelAddDownscale:
|
class PatchModelAddDownscale:
|
||||||
|
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "model": ("MODEL",),
|
return {"required": { "model": ("MODEL",),
|
||||||
@ -9,13 +11,15 @@ class PatchModelAddDownscale:
|
|||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
|
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
"downscale_after_skip": ("BOOLEAN", {"default": True}),
|
"downscale_after_skip": ("BOOLEAN", {"default": True}),
|
||||||
|
"downscale_method": (s.upscale_methods,),
|
||||||
|
"upscale_method": (s.upscale_methods,),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "patch"
|
FUNCTION = "patch"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip):
|
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
||||||
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
|
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
|
||||||
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
|
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
|
||||||
|
|
||||||
@ -23,12 +27,12 @@ class PatchModelAddDownscale:
|
|||||||
if transformer_options["block"][1] == block_number:
|
if transformer_options["block"][1] == block_number:
|
||||||
sigma = transformer_options["sigmas"][0].item()
|
sigma = transformer_options["sigmas"][0].item()
|
||||||
if sigma <= sigma_start and sigma >= sigma_end:
|
if sigma <= sigma_start and sigma >= sigma_end:
|
||||||
h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False)
|
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
|
||||||
return h
|
return h
|
||||||
|
|
||||||
def output_block_patch(h, hsp, transformer_options):
|
def output_block_patch(h, hsp, transformer_options):
|
||||||
if h.shape[2] != hsp.shape[2]:
|
if h.shape[2] != hsp.shape[2]:
|
||||||
h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False)
|
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
|
||||||
return h, hsp
|
return h, hsp
|
||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
Loading…
Reference in New Issue
Block a user