From dd5b57e3d7371cab793beedfdc20e60890b05a0e Mon Sep 17 00:00:00 2001 From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com> Date: Fri, 8 Nov 2024 23:16:29 +0000 Subject: [PATCH] fix for SAG with Kohya HRFix/ Deep Shrink (#5546) now works with arbitrary downscale factors --- comfy_extras/nodes_sag.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 5e15b99e..3f03533e 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -57,12 +57,17 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): attn = attn.reshape(b, -1, hw1, hw2) # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length() - mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] + + f = float(lh) / float(lw) + fh = f ** 0.5 + fw = (1/f) ** 0.5 + S = mask.size(1) ** 0.5 + w = int(0.5 + S * fw) + h = int(0.5 + S * fh) # Reshape mask = ( - mask.reshape(b, *mid_shape) + mask.reshape(b, h, w) .unsqueeze(1) .type(attn.dtype) )